devirtualize service Data traits

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-27 03:17:20 +00:00
parent a6edaad6fc
commit 7ad7badd60
64 changed files with 1190 additions and 1176 deletions

View file

@ -14,7 +14,6 @@ pub(super) async fn account_data(subcommand: AccountData) -> Result<RoomMessageE
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services() let results = services()
.account_data .account_data
.db
.changes_since(room_id.as_deref(), &user_id, since)?; .changes_since(room_id.as_deref(), &user_id, since)?;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
@ -30,7 +29,6 @@ pub(super) async fn account_data(subcommand: AccountData) -> Result<RoomMessageE
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services() let results = services()
.account_data .account_data
.db
.get(room_id.as_deref(), &user_id, kind)?; .get(room_id.as_deref(), &user_id, kind)?;
let query_time = timer.elapsed(); let query_time = timer.elapsed();

View file

@ -1,4 +1,4 @@
use std::collections::HashMap; use std::{collections::HashMap, sync::Arc};
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
@ -8,33 +8,25 @@ use ruma::{
}; };
use tracing::warn; use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, KvTree, Result};
pub trait Data: Send + Sync { pub(super) struct Data {
/// Places one event in the account data of the user and removes the roomuserdataid_accountdata: Arc<dyn KvTree>,
/// previous entry. roomusertype_roomuserdataid: Arc<dyn KvTree>,
fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()>;
/// Searches the account data for a specific kind.
fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>>;
/// Returns all changes to the account data that happened after `since`.
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
} }
impl Data for KeyValueDatabase { impl Data {
pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
roomuserdataid_accountdata: db.roomuserdataid_accountdata.clone(),
roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(),
}
}
/// Places one event in the account data of the user and removes the /// Places one event in the account data of the user and removes the
/// previous entry. /// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] pub(super) fn update(
fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType,
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value, data: &serde_json::Value,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id let mut prefix = room_id
@ -80,9 +72,8 @@ impl Data for KeyValueDatabase {
} }
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))] pub(super) fn get(
fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType,
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> { ) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id let mut key = room_id
.map(ToString::to_string) .map(ToString::to_string)
@ -107,8 +98,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns all changes to the account data that happened after `since`. /// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))] pub(super) fn changes_since(
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new(); let mut userdata = HashMap::new();

View file

@ -2,40 +2,46 @@ mod data;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
pub use data::Data; use conduit::{Result, Server};
use data::Data;
use database::KeyValueDatabase;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Places one event in the account data of the user and removes the /// Places one event in the account data of the user and removes the
/// previous entry. /// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] #[allow(clippy::needless_pass_by_value)]
pub fn update( pub fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value, data: &serde_json::Value,
) -> Result<()> { ) -> Result<()> {
self.db.update(room_id, user_id, event_type, data) self.db.update(room_id, user_id, &event_type, data)
} }
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, event_type))] #[allow(clippy::needless_pass_by_value)]
pub fn get( pub fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> { ) -> Result<Option<Box<serde_json::value::RawValue>>> {
self.db.get(room_id, user_id, event_type) self.db.get(room_id, user_id, &event_type)
} }
/// Returns all changes to the account data that happened after `since`. /// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))] #[tracing::instrument(skip_all, name = "since")]
pub fn changes_since( pub fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
pub mod console; pub mod console;
mod create; mod create;
mod grant; mod grant;
@ -44,17 +47,16 @@ pub struct Command {
} }
impl Service { impl Service {
#[must_use] pub fn build(_server: &Arc<Server>, _db: &Arc<KeyValueDatabase>) -> Result<Arc<Self>> {
pub fn build() -> Arc<Self> {
let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT);
Arc::new(Self { Ok(Arc::new(Self {
sender, sender,
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
handler_join: Mutex::new(None), handler_join: Mutex::new(None),
handle: Mutex::new(None), handle: Mutex::new(None),
#[cfg(feature = "console")] #[cfg(feature = "console")]
console: console::Console::new(), console: console::Console::new(),
}) }))
} }
pub async fn start_handler(self: &Arc<Self>) { pub async fn start_handler(self: &Arc<Self>) {

View file

@ -1,28 +1,23 @@
use std::sync::Arc;
use database::KvTree;
use ruma::api::appservice::Registration; use ruma::api::appservice::Registration;
use crate::{utils, Error, KeyValueDatabase, Result}; use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
/// Registers an appservice and returns the ID to the caller id_appserviceregistrations: Arc<dyn KvTree>,
fn register_appservice(&self, yaml: Registration) -> Result<String>;
/// Remove an appservice registration
///
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()>;
fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
fn all(&self) -> Result<Vec<(String, Registration)>>;
} }
impl Data for KeyValueDatabase { impl Data {
pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
id_appserviceregistrations: db.id_appserviceregistrations.clone(),
}
}
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> { pub(super) fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str(); let id = yaml.id.as_str();
self.id_appserviceregistrations self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
@ -35,13 +30,13 @@ impl Data for KeyValueDatabase {
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> { pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations self.id_appserviceregistrations
.remove(service_name.as_bytes())?; .remove(service_name.as_bytes())?;
Ok(()) Ok(())
} }
fn get_registration(&self, id: &str) -> Result<Option<Registration>> { pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.id_appserviceregistrations self.id_appserviceregistrations
.get(id.as_bytes())? .get(id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -51,14 +46,14 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { pub(super) fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id) utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
}))) })))
} }
fn all(&self) -> Result<Vec<(String, Registration)>> { pub fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()? self.iter_ids()?
.filter_map(Result::ok) .filter_map(Result::ok)
.map(move |id| { .map(move |id| {

View file

@ -113,13 +113,17 @@ impl TryFrom<Registration> for RegistrationInfo {
} }
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>, registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
} }
use conduit::Server;
use database::KeyValueDatabase;
impl Service { impl Service {
pub fn build(db: Arc<dyn Data>) -> Result<Self> { pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
let mut registration_info = BTreeMap::new(); let mut registration_info = BTreeMap::new();
let db = Data::new(db);
// Inserting registrations into cache // Inserting registrations into cache
for appservice in db.all()? { for appservice in db.all()? {
registration_info.insert( registration_info.insert(

View file

@ -1,6 +1,9 @@
use std::collections::{BTreeMap, HashMap}; use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
};
use async_trait::async_trait; use database::{Cork, KeyValueDatabase, KvTree};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
@ -10,56 +13,60 @@ use ruma::{
}; };
use tracing::trace; use tracing::trace;
use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, Result};
const COUNTER: &[u8] = b"c"; const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait] pub struct Data {
pub trait Data: Send + Sync { global: Arc<dyn KvTree>,
fn next_count(&self) -> Result<u64>; todeviceid_events: Arc<dyn KvTree>,
fn current_count(&self) -> Result<u64>; userroomid_joined: Arc<dyn KvTree>,
fn last_check_for_updates_id(&self) -> Result<u64>; userroomid_invitestate: Arc<dyn KvTree>,
fn update_check_for_updates_id(&self, id: u64) -> Result<()>; userroomid_leftstate: Arc<dyn KvTree>,
#[allow(unused_qualifications)] // async traits userroomid_notificationcount: Arc<dyn KvTree>,
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; userroomid_highlightcount: Arc<dyn KvTree>,
fn cleanup(&self) -> Result<()>; pduid_pdu: Arc<dyn KvTree>,
keychangeid_userid: Arc<dyn KvTree>,
fn cork(&self) -> Cork; roomusertype_roomuserdataid: Arc<dyn KvTree>,
fn cork_and_flush(&self) -> Cork; server_signingkeys: Arc<dyn KvTree>,
readreceiptid_readreceipt: Arc<dyn KvTree>,
fn memory_usage(&self) -> String; userid_lastonetimekeyupdate: Arc<dyn KvTree>,
fn clear_caches(&self, amount: u32); pub(super) db: Arc<KeyValueDatabase>,
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
fn remove_keypair(&self) -> Result<()>;
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
fn database_version(&self) -> Result<u64>;
fn bump_database_version(&self, new_version: u64) -> Result<()>;
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { unimplemented!() }
fn backup_list(&self) -> Result<String> { Ok(String::new()) }
fn file_list(&self) -> Result<String> { Ok(String::new()) }
} }
#[async_trait] impl Data {
impl Data for KeyValueDatabase { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
fn next_count(&self) -> Result<u64> { Self {
global: db.global.clone(),
todeviceid_events: db.todeviceid_events.clone(),
userroomid_joined: db.userroomid_joined.clone(),
userroomid_invitestate: db.userroomid_invitestate.clone(),
userroomid_leftstate: db.userroomid_leftstate.clone(),
userroomid_notificationcount: db.userroomid_notificationcount.clone(),
userroomid_highlightcount: db.userroomid_highlightcount.clone(),
pduid_pdu: db.pduid_pdu.clone(),
keychangeid_userid: db.keychangeid_userid.clone(),
roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(),
server_signingkeys: db.server_signingkeys.clone(),
readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(),
userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(),
db: db.clone(),
}
}
pub fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?) utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes.")) .map_err(|_| Error::bad_database("Count has invalid bytes."))
} }
fn current_count(&self) -> Result<u64> { pub fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
}) })
} }
fn last_check_for_updates_id(&self) -> Result<u64> { pub fn last_check_for_updates_id(&self) -> Result<u64> {
self.global self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)? .get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| { .map_or(Ok(0_u64), |bytes| {
@ -68,16 +75,15 @@ impl Data for KeyValueDatabase {
}) })
} }
fn update_check_for_updates_id(&self, id: u64) -> Result<()> { pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(()) Ok(())
} }
#[allow(unused_qualifications)] // async traits
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec(); let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone(); let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF); userid_prefix.push(0xFF);
@ -177,20 +183,20 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn cleanup(&self) -> Result<()> { self.db.cleanup() } pub fn cleanup(&self) -> Result<()> { self.db.db.cleanup() }
fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } pub fn cork(&self) -> Cork { Cork::new(&self.db.db, false, false) }
fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db.db, true, false) }
fn memory_usage(&self) -> String { pub fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); let auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); let appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); let lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity(); let max_auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity(); let max_appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity(); let max_lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().capacity();
format!( format!(
"\ "\
@ -198,26 +204,26 @@ auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache} appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}", {}",
self.db.memory_usage().unwrap_or_default() self.db.db.memory_usage().unwrap_or_default()
) )
} }
fn clear_caches(&self, amount: u32) { pub fn clear_caches(&self, amount: u32) {
if amount > 1 { if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap(); let c = &mut *self.db.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 2 { if amount > 2 {
let c = &mut *self.appservice_in_room_cache.write().unwrap(); let c = &mut *self.db.appservice_in_room_cache.write().unwrap();
*c = HashMap::new(); *c = HashMap::new();
} }
if amount > 3 { if amount > 3 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); let c = &mut *self.db.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new(); *c = HashMap::new();
} }
} }
fn load_keypair(&self) -> Result<Ed25519KeyPair> { pub fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else( let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| { || {
let keypair = utils::generate_keypair(); let keypair = utils::generate_keypair();
@ -249,9 +255,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
}) })
} }
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key( pub fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys, &self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical // Not atomic, but this is not critical
@ -290,8 +296,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server. /// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self let signingkeys = self
.db
.server_signingkeys .server_signingkeys
.get(origin.as_bytes())? .get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok()) .and_then(|bytes| serde_json::from_slice(&bytes).ok())
@ -308,20 +315,20 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
Ok(signingkeys) Ok(signingkeys)
} }
fn database_version(&self) -> Result<u64> { pub fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| { self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
}) })
} }
fn bump_database_version(&self, new_version: u64) -> Result<()> { pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?; self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(()) Ok(())
} }
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() } pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() } pub fn backup_list(&self) -> Result<String> { self.db.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() } pub fn file_list(&self) -> Result<String> { self.db.db.file_list() }
} }

View file

@ -1,3 +1,5 @@
use conduit::Server;
mod client; mod client;
mod data; mod data;
pub(super) mod emerg_access; pub(super) mod emerg_access;
@ -33,12 +35,12 @@ use tracing::{error, trace};
use url::Url; use url::Url;
use utils::MutexMap; use utils::MutexMap;
use crate::{services, Config, Result}; use crate::{services, Config, KeyValueDatabase, Result};
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub config: Config, pub config: Config,
pub cidr_range_denylist: Vec<IPAddress>, pub cidr_range_denylist: Vec<IPAddress>,
@ -62,7 +64,9 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn load(db: Arc<dyn Data>, config: &Config) -> Result<Self> { pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
let config = &server.config;
let db = Data::new(db);
let keypair = db.load_keypair(); let keypair = db.load_keypair();
let keypair = match keypair { let keypair = match keypair {

View file

@ -1,4 +1,4 @@
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -9,48 +9,24 @@ use ruma::{
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, KvTree, Result};
pub(crate) trait Data: Send + Sync { pub(crate) struct Data {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>; backupid_algorithm: Arc<dyn KvTree>,
backupid_etag: Arc<dyn KvTree>,
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; backupkeyid_backup: Arc<dyn KvTree>,
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()>;
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>>;
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
backupid_algorithm: db.backupid_algorithm.clone(),
backupid_etag: db.backupid_etag.clone(),
backupkeyid_backup: db.backupkeyid_backup.clone(),
}
}
pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string(); let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -66,7 +42,7 @@ impl Data for KeyValueDatabase {
Ok(version) Ok(version)
} }
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -83,7 +59,9 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { pub(super) fn update_backup(
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -99,7 +77,7 @@ impl Data for KeyValueDatabase {
Ok(version.to_owned()) Ok(version.to_owned())
} }
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
@ -120,7 +98,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
@ -147,7 +125,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> { pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -160,7 +138,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn add_key( pub(super) fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -185,7 +163,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
@ -193,7 +171,7 @@ impl Data for KeyValueDatabase {
Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
} }
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -208,7 +186,7 @@ impl Data for KeyValueDatabase {
.to_string()) .to_string())
} }
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
@ -257,7 +235,7 @@ impl Data for KeyValueDatabase {
Ok(rooms) Ok(rooms)
} }
fn get_room( pub(super) fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId, &self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> { ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
@ -289,7 +267,7 @@ impl Data for KeyValueDatabase {
.collect()) .collect())
} }
fn get_session( pub(super) fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> { ) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -309,7 +287,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -322,7 +300,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
@ -337,7 +315,9 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { pub(super) fn delete_room_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());

View file

@ -1,20 +1,28 @@
use conduit::Server;
mod data; mod data;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use conduit::Result;
use data::Data; use data::Data;
use database::KeyValueDatabase;
use ruma::{ use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw, serde::Raw,
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::Result;
pub struct Service { pub struct Service {
pub(super) db: Arc<dyn Data>, pub(super) db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
self.db.create_backup(user_id, backup_metadata) self.db.create_backup(user_id, backup_metadata)
} }

View file

@ -1,37 +1,28 @@
use std::sync::Arc;
use conduit::debug_info; use conduit::debug_info;
use database::{KeyValueDatabase, KvTree};
use ruma::api::client::error::ErrorKind; use ruma::api::client::error::ErrorKind;
use tracing::debug; use tracing::debug;
use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result}; use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, Result};
pub(crate) trait Data: Send + Sync { pub struct Data {
fn create_file_metadata( mediaid_file: Arc<dyn KvTree>,
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, mediaid_user: Arc<dyn KvTree>,
content_type: Option<&str>, url_previews: Arc<dyn KvTree>,
) -> Result<Vec<u8>>;
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
/// Returns content_disposition, content_type and the metadata key.
fn search_file_metadata(
&self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
fn get_all_media_keys(&self) -> Vec<Vec<u8>>;
// TODO: use this
#[allow(dead_code)]
fn remove_url_preview(&self, url: &str) -> Result<()>;
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData>;
} }
impl Data for KeyValueDatabase { impl Data {
fn create_file_metadata( pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
mediaid_file: db.mediaid_file.clone(),
mediaid_user: db.mediaid_user.clone(),
url_previews: db.url_previews.clone(),
}
}
pub(super) fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>, content_type: Option<&str>,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
@ -65,7 +56,7 @@ impl Data for KeyValueDatabase {
Ok(key) Ok(key)
} }
fn delete_file_mxc(&self, mxc: String) -> Result<()> { pub(super) fn delete_file_mxc(&self, mxc: String) -> Result<()> {
debug!("MXC URI: {:?}", mxc); debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec(); let mut prefix = mxc.as_bytes().to_vec();
@ -91,7 +82,7 @@ impl Data for KeyValueDatabase {
} }
/// Searches for all files with the given MXC /// Searches for all files with the given MXC
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> { pub(super) fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {:?}", mxc); debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec(); let mut prefix = mxc.as_bytes().to_vec();
@ -114,7 +105,7 @@ impl Data for KeyValueDatabase {
Ok(keys) Ok(keys)
} }
fn search_file_metadata( pub(super) fn search_file_metadata(
&self, mxc: String, width: u32, height: u32, &self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> { ) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec(); let mut prefix = mxc.as_bytes().to_vec();
@ -156,11 +147,13 @@ impl Data for KeyValueDatabase {
/// Gets all the media keys in our database (this includes all the metadata /// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc) /// associated with it such as width, height, content-type, etc)
fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() } pub(crate) fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() }
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> { pub(super) fn set_url_preview(
&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration,
) -> Result<()> {
let mut value = Vec::<u8>::new(); let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes()); value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF); value.push(0xFF);
@ -194,7 +187,7 @@ impl Data for KeyValueDatabase {
self.url_previews.insert(url.as_bytes(), &value) self.url_previews.insert(url.as_bytes(), &value)
} }
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { pub(super) fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??; let values = self.url_previews.get(url.as_bytes()).ok()??;
let mut values = values.split(|&b| b == 0xFF); let mut values = values.split(|&b| b == 0xFF);

View file

@ -44,17 +44,17 @@ pub struct UrlPreviewData {
pub struct Service { pub struct Service {
server: Arc<Server>, server: Arc<Server>,
pub(super) db: Arc<dyn Data>, pub db: Data,
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>, pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
} }
impl Service { impl Service {
pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Self { pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Self { Ok(Self {
server: server.clone(), server: server.clone(),
db: db.clone(), db: Data::new(db),
url_preview_mutex: RwLock::new(HashMap::new()), url_preview_mutex: RwLock::new(HashMap::new()),
} })
} }
/// Uploads a file. /// Uploads a file.

View file

@ -20,7 +20,7 @@ extern crate conduit_database as database;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result, Server}; pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result, Server};
pub(crate) use database::KeyValueDatabase; pub(crate) use database::{KeyValueDatabase, KvTree};
pub use globals::{server_is_ours, user_is_local}; pub use globals::{server_is_ours, user_is_local};
pub use pdu::PduEvent; pub use pdu::PduEvent;
pub use services::Services; pub use services::Services;

View file

@ -1,4 +1,7 @@
use std::sync::Arc;
use conduit::debug_warn; use conduit::debug_warn;
use database::KvTree;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{ use crate::{
@ -8,26 +11,20 @@ use crate::{
Error, KeyValueDatabase, Result, Error, KeyValueDatabase, Result,
}; };
pub trait Data: Send + Sync { pub struct Data {
/// Returns the latest presence event for the given user. presenceid_presence: Arc<dyn KvTree>,
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>>; userid_presenceid: Arc<dyn KvTree>,
/// Adds a presence event which will be saved until a new event replaces it.
fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()>;
/// Removes the presence record for the given user from the database.
fn remove_presence(&self, user_id: &UserId) -> Result<()>;
/// Returns the most recent presence updates that happened after the event
/// with id `since`.
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a>;
} }
impl Data for KeyValueDatabase { impl Data {
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
presenceid_presence: db.presenceid_presence.clone(),
userid_presenceid: db.userid_presenceid.clone(),
}
}
pub fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes) let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
@ -44,7 +41,7 @@ impl Data for KeyValueDatabase {
} }
} }
fn set_presence( pub(super) fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>, &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>, last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> { ) -> Result<()> {
@ -105,7 +102,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn remove_presence(&self, user_id: &UserId) -> Result<()> { pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes) let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
@ -117,7 +114,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> { pub fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new( Box::new(
self.presenceid_presence self.presenceid_presence
.iter() .iter()

View file

@ -1,3 +1,5 @@
use conduit::Server;
mod data; mod data;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
@ -14,9 +16,10 @@ use tokio::{sync::Mutex, task::JoinHandle, time::sleep};
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{ use crate::{
database::KeyValueDatabase,
services, user_is_local, services, user_is_local,
utils::{self}, utils::{self},
Config, Error, Result, Error, Result,
}; };
/// Represents data required to be kept in order to implement the presence /// Represents data required to be kept in order to implement the presence
@ -77,7 +80,7 @@ impl Presence {
} }
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>,
timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>, timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>,
handler_join: Mutex<Option<JoinHandle<()>>>, handler_join: Mutex<Option<JoinHandle<()>>>,
@ -85,15 +88,16 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn build(db: Arc<dyn Data>, config: &Config) -> Arc<Self> { pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Arc<Self>> {
let config = &server.config;
let (timer_sender, timer_receiver) = loole::unbounded(); let (timer_sender, timer_receiver) = loole::unbounded();
Arc::new(Self { Ok(Arc::new(Self {
db, db: Data::new(db),
timer_sender, timer_sender,
timer_receiver: Mutex::new(timer_receiver), timer_receiver: Mutex::new(timer_receiver),
handler_join: Mutex::new(None), handler_join: Mutex::new(None),
timeout_remote_users: config.presence_timeout_remote_users, timeout_remote_users: config.presence_timeout_remote_users,
}) }))
} }
pub async fn start_handler(self: &Arc<Self>) { pub async fn start_handler(self: &Arc<Self>) {

View file

@ -1,22 +1,25 @@
use std::sync::Arc;
use database::{KeyValueDatabase, KvTree};
use ruma::{ use ruma::{
api::client::push::{set_pusher, Pusher}, api::client::push::{set_pusher, Pusher},
UserId, UserId,
}; };
use crate::{utils, Error, KeyValueDatabase, Result}; use crate::{utils, Error, Result};
pub(crate) trait Data: Send + Sync { pub struct Data {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; senderkey_pusher: Arc<dyn KvTree>,
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
} }
impl Data for KeyValueDatabase { impl Data {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
senderkey_pusher: db.senderkey_pusher.clone(),
}
}
pub(super) fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
match &pusher { match &pusher {
set_pusher::v3::PusherAction::Post(data) => { set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec(); let mut key = sender.as_bytes().to_vec();
@ -35,7 +38,7 @@ impl Data for KeyValueDatabase {
} }
} }
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec(); let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xFF); senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes()); senderkey.extend_from_slice(pushkey.as_bytes());
@ -46,7 +49,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { pub(super) fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec(); let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -56,7 +59,7 @@ impl Data for KeyValueDatabase {
.collect() .collect()
} }
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> { pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec(); let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{fmt::Debug, mem, sync::Arc}; use std::{fmt::Debug, mem, sync::Arc};
@ -25,10 +28,16 @@ use tracing::{info, trace, warn};
use crate::{debug_info, services, Error, PduEvent, Result}; use crate::{debug_info, services, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub(super) db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
self.db.set_pusher(sender, pusher) self.db.set_pusher(sender, pusher)
} }

View file

@ -1,31 +1,26 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
/// Creates or updates the alias to the given room id. alias_userid: Arc<dyn KvTree>,
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()>; alias_roomid: Arc<dyn KvTree>,
aliasid_alias: Arc<dyn KvTree>,
/// Forgets about an alias. Returns an error if the alias did not exist.
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
/// Looks up the roomid for the given alias.
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>;
/// Finds the user who assigned the given alias to a room
fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>>;
/// Returns all local aliases that point to the given room
fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
/// Returns all local aliases on the server
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
} }
impl Data for KeyValueDatabase { impl Data {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
alias_userid: db.alias_userid.clone(),
alias_roomid: db.alias_roomid.clone(),
aliasid_alias: db.aliasid_alias.clone(),
}
}
pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
// Comes first as we don't want a stuck alias // Comes first as we don't want a stuck alias
self.alias_userid self.alias_userid
.insert(alias.alias().as_bytes(), user_id.as_bytes())?; .insert(alias.alias().as_bytes(), user_id.as_bytes())?;
@ -41,7 +36,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id; let mut prefix = room_id;
prefix.push(0xFF); prefix.push(0xFF);
@ -60,7 +55,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid self.alias_roomid
.get(alias.alias().as_bytes())? .get(alias.alias().as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -73,7 +68,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>> { pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>> {
self.alias_userid self.alias_userid
.get(alias.alias().as_bytes())? .get(alias.alias().as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -86,7 +81,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn local_aliases_for_room<'a>( pub fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId, &'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
@ -100,7 +95,7 @@ impl Data for KeyValueDatabase {
})) }))
} }
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new( Box::new(
self.alias_roomid self.alias_roomid
.iter() .iter()

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -15,10 +18,16 @@ use ruma::{
use crate::{services, Error, Result}; use crate::{services, Error, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == services().globals.admin_alias && user_id != services().globals.server_user { if alias == services().globals.admin_alias && user_id != services().globals.server_user {

View file

@ -1,16 +1,25 @@
use std::{mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use database::KvTree;
use crate::{utils, KeyValueDatabase, Result}; use crate::{utils, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub(super) struct Data {
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<[u64]>>>; shorteventid_authchain: Arc<dyn KvTree>,
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()>; db: Arc<KeyValueDatabase>,
} }
impl Data for KeyValueDatabase { impl Data {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
shorteventid_authchain: db.shorteventid_authchain.clone(),
db: db.clone(),
}
}
pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache // Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { if let Some(result) = self.db.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result))); return Ok(Some(Arc::clone(result)));
} }
@ -29,7 +38,8 @@ impl Data for KeyValueDatabase {
if let Some(chain) = chain { if let Some(chain) = chain {
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.db
.auth_chain_cache
.lock() .lock()
.unwrap() .unwrap()
.insert(vec![key[0]], Arc::clone(&chain)); .insert(vec![key[0]], Arc::clone(&chain));
@ -41,7 +51,7 @@ impl Data for KeyValueDatabase {
Ok(None) Ok(None)
} }
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> { pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db // Only persist single events in db
if key.len() == 1 { if key.len() == 1 {
self.shorteventid_authchain.insert( self.shorteventid_authchain.insert(
@ -54,7 +64,8 @@ impl Data for KeyValueDatabase {
} }
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.db
.auth_chain_cache
.lock() .lock()
.unwrap() .unwrap()
.insert(key, auth_chain); .insert(key, auth_chain);

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
collections::{BTreeSet, HashSet}, collections::{BTreeSet, HashSet},
@ -11,10 +14,16 @@ use tracing::{debug, error, trace, warn};
use crate::{services, Error, Result}; use crate::{services, Error, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub async fn event_ids_iter<'a>( pub async fn event_ids_iter<'a>(
&self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>, &self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {

View file

@ -1,31 +1,34 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::{utils, Error, KeyValueDatabase, Result}; use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub(super) struct Data {
/// Adds the room to the public room directory publicroomids: Arc<dyn KvTree>,
fn set_public(&self, room_id: &RoomId) -> Result<()>;
/// Removes the room from the public room directory.
fn set_not_public(&self, room_id: &RoomId) -> Result<()>;
/// Returns true if the room is in the public room directory.
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
/// Returns the unsorted public room directory
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
} }
impl Data for KeyValueDatabase { impl Data {
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
publicroomids: db.publicroomids.clone(),
}
}
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.remove(room_id.as_bytes())
}
pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
} }
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { pub(super) fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| { Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, RoomId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }

View file

@ -1,6 +1,8 @@
use conduit::Server;
use database::KeyValueDatabase;
mod parse_incoming_pdu; mod parse_incoming_pdu;
mod signing_keys; mod signing_keys;
pub struct Service;
use std::{ use std::{
cmp, cmp,
@ -32,6 +34,8 @@ use tracing::{debug, error, info, trace, warn};
use super::state_compressor::CompressedStateEvent; use super::state_compressor::CompressedStateEvent;
use crate::{debug_error, debug_info, pdu, services, Error, PduEvent, Result}; use crate::{debug_error, debug_info, pdu, services, Error, PduEvent, Result};
pub struct Service;
// We use some AsyncRecursiveType hacks here so we can call async funtion // We use some AsyncRecursiveType hacks here so we can call async funtion
// recursively. // recursively.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>; type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
@ -41,6 +45,8 @@ type AsyncRecursiveCanonicalJsonResult<'a> =
AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>; AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>;
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, _db: &Arc<KeyValueDatabase>) -> Result<Self> { Ok(Self {}) }
/// When receiving an event one needs to: /// When receiving an event one needs to:
/// 0. Check the server is in the room /// 0. Check the server is in the room
/// 1. Skip the PDU if we already know about it /// 1. Skip the PDU if we already know about it

View file

@ -1,22 +1,22 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{DeviceId, RoomId, UserId}; use ruma::{DeviceId, RoomId, UserId};
use crate::{KeyValueDatabase, Result}; use crate::{KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
fn lazy_load_was_sent_before( lazyloadedids: Arc<dyn KvTree>,
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool>;
fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()>;
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn lazy_load_was_sent_before( pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
lazyloadedids: db.lazyloadedids.clone(),
}
}
pub(super) fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> { ) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -29,7 +29,7 @@ impl Data for KeyValueDatabase {
Ok(self.lazyloadedids.get(&key)?.is_some()) Ok(self.lazyloadedids.get(&key)?.is_some())
} }
fn lazy_load_confirm_delivery( pub(super) fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> { ) -> Result<()> {
@ -49,7 +49,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@ -11,13 +14,20 @@ use tokio::sync::Mutex;
use crate::{PduCount, Result}; use crate::{PduCount, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub lazy_load_waiting: Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>, pub lazy_load_waiting: Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
lazy_load_waiting: Mutex::new(HashMap::new()),
})
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn lazy_load_was_sent_before( pub fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,

View file

@ -1,20 +1,29 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use tracing::error; use tracing::error;
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
fn exists(&self, room_id: &RoomId) -> Result<bool>; disabledroomids: Arc<dyn KvTree>,
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; bannedroomids: Arc<dyn KvTree>,
fn is_disabled(&self, room_id: &RoomId) -> Result<bool>; roomid_shortroomid: Arc<dyn KvTree>,
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; pduid_pdu: Arc<dyn KvTree>,
fn is_banned(&self, room_id: &RoomId) -> Result<bool>;
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()>;
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
} }
impl Data for KeyValueDatabase { impl Data {
fn exists(&self, room_id: &RoomId) -> Result<bool> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
disabledroomids: db.disabledroomids.clone(),
bannedroomids: db.bannedroomids.clone(),
roomid_shortroomid: db.roomid_shortroomid.clone(),
pduid_pdu: db.pduid_pdu.clone(),
}
}
pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? { let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(), Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false), None => return Ok(false),
@ -29,7 +38,7 @@ impl Data for KeyValueDatabase {
.is_some()) .is_some())
} }
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { pub(super) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)
@ -39,11 +48,11 @@ impl Data for KeyValueDatabase {
})) }))
} }
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
} }
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled { if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?; self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else { } else {
@ -53,9 +62,11 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } pub(super) fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
}
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned { if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?; self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else { } else {
@ -65,7 +76,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { pub(super) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map( Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes) let room_id = utils::string_from_bytes(&room_id_bytes)

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, RoomId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Checks if a room exists. /// Checks if a room exists.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) } pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }

View file

@ -1,15 +1,22 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{CanonicalJsonObject, EventId}; use ruma::{CanonicalJsonObject, EventId};
use crate::{Error, KeyValueDatabase, PduEvent, Result}; use crate::{Error, KeyValueDatabase, PduEvent, Result};
pub trait Data: Send + Sync { pub struct Data {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; eventid_outlierpdu: Arc<dyn KvTree>,
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
eventid_outlierpdu: db.eventid_outlierpdu.clone(),
}
}
pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| { .map_or(Ok(None), |pdu| {
@ -17,7 +24,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| { .map_or(Ok(None), |pdu| {
@ -25,7 +32,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert( self.eventid_outlierpdu.insert(
event_id.as_bytes(), event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{CanonicalJsonObject, EventId};
use crate::{PduEvent, Result}; use crate::{PduEvent, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Returns the pdu from the outlier tree. /// Returns the pdu from the outlier tree.
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_outlier_pdu_json(event_id) self.db.get_outlier_pdu_json(event_id)

View file

@ -1,30 +1,33 @@
use std::{mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use database::KvTree;
use ruma::{EventId, RoomId, UserId}; use ruma::{EventId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
pub trait Data: Send + Sync { pub(super) struct Data {
fn add_relation(&self, from: u64, to: u64) -> Result<()>; tofrom_relation: Arc<dyn KvTree>,
#[allow(clippy::type_complexity)] referencedevents: Arc<dyn KvTree>,
fn relations_until<'a>( softfailedeventids: Arc<dyn KvTree>,
&'a self, user_id: &'a UserId, room_id: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>;
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>;
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>;
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>;
} }
impl Data for KeyValueDatabase { impl Data {
fn add_relation(&self, from: u64, to: u64) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
tofrom_relation: db.tofrom_relation.clone(),
referencedevents: db.referencedevents.clone(),
softfailedeventids: db.softfailedeventids.clone(),
}
}
pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec(); let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes()); key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?; self.tofrom_relation.insert(&key, &[])?;
Ok(()) Ok(())
} }
fn relations_until<'a>( pub(super) fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec(); let prefix = target.to_be_bytes().to_vec();
@ -63,7 +66,7 @@ impl Data for KeyValueDatabase {
)) ))
} }
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids { for prev in event_ids {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes()); key.extend_from_slice(prev.as_bytes());
@ -73,17 +76,17 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes()); key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some()) Ok(self.referencedevents.get(&key)?.is_some())
} }
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[]) self.softfailedeventids.insert(event_id.as_bytes(), &[])
} }
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids self.softfailedeventids
.get(event_id.as_bytes()) .get(event_id.as_bytes())
.map(|o| o.is_some()) .map(|o| o.is_some())

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -13,7 +16,7 @@ use serde::Deserialize;
use crate::{services, PduCount, PduEvent, Result}; use crate::{services, PduCount, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
@ -27,6 +30,12 @@ struct ExtractRelatesToEventId {
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
#[tracing::instrument(skip(self, from, to))] #[tracing::instrument(skip(self, from, to))]
pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
match (from, to) { match (from, to) {

View file

@ -1,5 +1,6 @@
use std::mem::size_of; use std::{mem::size_of, sync::Arc};
use database::KvTree;
use ruma::{ use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw, serde::Raw,
@ -11,32 +12,22 @@ use crate::{services, utils, Error, KeyValueDatabase, Result};
type AnySyncEphemeralRoomEventIter<'a> = type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>; Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
pub trait Data: Send + Sync { pub(super) struct Data {
/// Replaces the previous read receipt. roomuserid_privateread: Arc<dyn KvTree>,
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>; roomuserid_lastprivatereadupdate: Arc<dyn KvTree>,
readreceiptid_readreceipt: Arc<dyn KvTree>,
/// Returns an iterator over the most recent read_receipts in a room that
/// happened after the event with id `since`.
fn readreceipts_since(&self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'_>;
/// Sets a private read marker at `count`.
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;
/// Returns the private read marker.
///
/// TODO: use this?
#[allow(dead_code)]
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>;
/// Returns the count of the last typing update in this room.
///
/// TODO: use this?
#[allow(dead_code)]
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
} }
impl Data for KeyValueDatabase { impl Data {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
roomuserid_privateread: db.roomuserid_privateread.clone(),
roomuserid_lastprivatereadupdate: db.roomuserid_lastprivatereadupdate.clone(),
readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(),
}
}
pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -71,9 +62,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn readreceipts_since<'a>( pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> {
&'a self, room_id: &RoomId, since: u64,
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
let prefix2 = prefix.clone(); let prefix2 = prefix.clone();
@ -107,7 +96,7 @@ impl Data for KeyValueDatabase {
) )
} }
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -119,7 +108,7 @@ impl Data for KeyValueDatabase {
.insert(&key, &services().globals.next_count()?.to_be_bytes()) .insert(&key, &services().globals.next_count()?.to_be_bytes())
} }
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -133,7 +122,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserI
use crate::{services, Result}; use crate::{services, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event)?; self.db.readreceipt_update(user_id, room_id, event)?;

View file

@ -1,30 +1,24 @@
use std::sync::Arc;
use database::KvTree;
use ruma::RoomId; use ruma::RoomId;
use crate::{services, utils, KeyValueDatabase, Result}; use crate::{services, utils, KeyValueDatabase, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
pub trait Data: Send + Sync { pub struct Data {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; tokenids: Arc<dyn KvTree>,
fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>;
} }
/// Splits a string into tokens used as keys in the search inverted index impl Data {
/// pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
/// This may be used to tokenize both message bodies (for indexing) or search Self {
/// queries (for querying). tokenids: db.tokenids.clone(),
fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ { }
body.split_terminator(|c: char| !c.is_alphanumeric()) }
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
}
impl Data for KeyValueDatabase { pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = tokenize(message_body).map(|word| { let mut batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes()); key.extend_from_slice(word.as_bytes());
@ -36,7 +30,7 @@ impl Data for KeyValueDatabase {
self.tokenids.insert_batch(&mut batch) self.tokenids.insert_batch(&mut batch)
} }
fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let batch = tokenize(message_body).map(|word| { let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes()); key.extend_from_slice(word.as_bytes());
@ -52,7 +46,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services() let prefix = services()
.rooms .rooms
.short .short
@ -88,3 +82,14 @@ impl Data for KeyValueDatabase {
Ok(Some((Box::new(common_elements), words))) Ok(Some((Box::new(common_elements), words)))
} }
} }
/// Splits a string into tokens used as keys in the search inverted index
///
/// This may be used to tokenize both message bodies (for indexing) or search
/// queries (for querying).
fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
body.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
}

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::RoomId;
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.index_pdu(shortroomid, pdu_id, message_body) self.db.index_pdu(shortroomid, pdu_id, message_body)

View file

@ -1,33 +1,33 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::warn;
use database::{KeyValueDatabase, KvTree};
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, Result};
pub trait Data: Send + Sync { pub(super) struct Data {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>; eventid_shorteventid: Arc<dyn KvTree>,
shorteventid_eventid: Arc<dyn KvTree>,
fn multi_get_or_create_shorteventid(&self, event_id: &[&EventId]) -> Result<Vec<u64>>; statekey_shortstatekey: Arc<dyn KvTree>,
shortstatekey_statekey: Arc<dyn KvTree>,
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>>; roomid_shortroomid: Arc<dyn KvTree>,
statehash_shortstatehash: Arc<dyn KvTree>,
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64>;
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>;
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>;
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>;
} }
impl Data for KeyValueDatabase { impl Data {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
eventid_shorteventid: db.eventid_shorteventid.clone(),
shorteventid_eventid: db.shorteventid_eventid.clone(),
statekey_shortstatekey: db.statekey_shortstatekey.clone(),
shortstatekey_statekey: db.shortstatekey_statekey.clone(),
roomid_shortroomid: db.roomid_shortroomid.clone(),
statehash_shortstatehash: db.statehash_shortstatehash.clone(),
}
}
pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else { } else {
@ -42,7 +42,7 @@ impl Data for KeyValueDatabase {
Ok(short) Ok(short)
} }
fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> { pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len()); let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids let keys = event_ids
.iter() .iter()
@ -74,7 +74,7 @@ impl Data for KeyValueDatabase {
Ok(ret) Ok(ret)
} }
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> { pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF); statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes()); statekey_vec.extend_from_slice(state_key.as_bytes());
@ -90,7 +90,7 @@ impl Data for KeyValueDatabase {
Ok(short) Ok(short)
} }
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> { pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF); statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes()); statekey_vec.extend_from_slice(state_key.as_bytes());
@ -109,7 +109,7 @@ impl Data for KeyValueDatabase {
Ok(short) Ok(short)
} }
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self let bytes = self
.shorteventid_eventid .shorteventid_eventid
.get(&shorteventid.to_be_bytes())? .get(&shorteventid.to_be_bytes())?
@ -124,7 +124,7 @@ impl Data for KeyValueDatabase {
Ok(event_id) Ok(event_id)
} }
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self let bytes = self
.shortstatekey_statekey .shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())? .get(&shortstatekey.to_be_bytes())?
@ -150,7 +150,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns (shortstatehash, already_existed) /// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
( (
utils::u64_from_bytes(&shortstatehash) utils::u64_from_bytes(&shortstatehash)
@ -165,14 +165,14 @@ impl Data for KeyValueDatabase {
}) })
} }
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid self.roomid_shortroomid
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose() .transpose()
} }
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else { } else {

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -7,10 +10,16 @@ use ruma::{events::StateEventType, EventId, RoomId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
self.db.get_or_create_shorteventid(event_id) self.db.get_or_create_shorteventid(event_id)
} }

View file

@ -1,8 +1,11 @@
use std::{ use std::{
fmt::{Display, Formatter}, fmt::{Display, Formatter},
str::FromStr, str::FromStr,
sync::Arc,
}; };
use conduit::Server;
use database::KeyValueDatabase;
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
api::{ api::{
@ -330,6 +333,16 @@ impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
} }
impl Service { impl Service {
pub fn build(server: &Arc<Server>, _db: &Arc<KeyValueDatabase>) -> Result<Self> {
let config = &server.config;
Ok(Self {
roomid_spacehierarchy_cache: Mutex::new(LruCache::new(
(f64::from(config.roomid_spacehierarchy_cache_capacity) * config.conduit_cache_capacity_modifier)
as usize,
)),
})
}
///Gets the response for the space hierarchy over federation request ///Gets the response for the space hierarchy over federation request
/// ///
///Panics if the room does not exist, so a check if the room exists should ///Panics if the room does not exist, so a check if the room exists should

View file

@ -1,39 +1,27 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use conduit::utils::mutex_map; use conduit::utils::mutex_map;
use database::KvTree;
use ruma::{EventId, OwnedEventId, RoomId}; use ruma::{EventId, OwnedEventId, RoomId};
use crate::{utils, Error, KeyValueDatabase, Result}; use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
/// Returns the last state hash key added to the db for the given room. shorteventid_shortstatehash: Arc<dyn KvTree>,
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>>; roomid_pduleaves: Arc<dyn KvTree>,
roomid_shortstatehash: Arc<dyn KvTree>,
/// Set the state hash to a new version, but does not update `state_cache`.
fn set_room_state(
&self,
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()>;
/// Associates a state with an event.
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>;
/// Returns all events we would send as the `prev_events` of the next event.
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>>;
/// Replace the forward extremities of the room.
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
shorteventid_shortstatehash: db.shorteventid_shortstatehash.clone(),
roomid_pduleaves: db.roomid_pduleaves.clone(),
roomid_shortstatehash: db.roomid_shortstatehash.clone(),
}
}
pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash self.roomid_shortstatehash
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -43,7 +31,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn set_room_state( pub(super) fn set_room_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
new_shortstatehash: u64, new_shortstatehash: u64,
@ -54,13 +42,13 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -76,7 +64,7 @@ impl Data for KeyValueDatabase {
.collect() .collect()
} }
fn set_forward_extremities( pub(super) fn set_forward_extremities(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_ids: Vec<OwnedEventId>, event_ids: Vec<OwnedEventId>,

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@ -22,10 +25,16 @@ use super::state_compressor::CompressedStateEvent;
use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; use crate::{services, utils::calculate_hash, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Set the room to the given statehash and update caches. /// Set the room to the given statehash and update caches.
pub async fn force_state( pub async fn force_state(
&self, &self,

View file

@ -1,56 +1,25 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use database::KvTree;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
#[async_trait] pub struct Data {
pub trait Data: Send + Sync { eventid_shorteventid: Arc<dyn KvTree>,
/// Builds a StateMap by iterating over all keys that start shorteventid_shortstatehash: Arc<dyn KvTree>,
/// with state_hash, this gives the full state for the given state_hash.
#[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>>;
#[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>;
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>>;
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>>;
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>>;
/// Returns the full room state.
#[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>;
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>>;
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>>;
} }
#[async_trait] impl Data {
impl Data for KeyValueDatabase { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
eventid_shorteventid: db.eventid_shorteventid.clone(),
shorteventid_shortstatehash: db.shorteventid_shortstatehash.clone(),
}
}
#[allow(unused_qualifications)] // async traits #[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services() let full_state = services()
.rooms .rooms
.state_compressor .state_compressor
@ -76,7 +45,9 @@ impl Data for KeyValueDatabase {
} }
#[allow(unused_qualifications)] // async traits #[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { pub(super) async fn state_full(
&self, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services() let full_state = services()
.rooms .rooms
.state_compressor .state_compressor
@ -116,7 +87,8 @@ impl Data for KeyValueDatabase {
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
fn state_get_id( #[allow(clippy::unused_self)]
pub(super) fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services() let Some(shortstatekey) = services()
@ -148,7 +120,7 @@ impl Data for KeyValueDatabase {
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
fn state_get( pub(super) fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)? self.state_get_id(shortstatehash, event_type, state_key)?
@ -156,7 +128,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid self.eventid_shorteventid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| { .map_or(Ok(None), |shorteventid| {
@ -173,7 +145,9 @@ impl Data for KeyValueDatabase {
/// Returns the full room state. /// Returns the full room state.
#[allow(unused_qualifications)] // async traits #[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { pub(super) async fn room_state_full(
&self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await self.state_full(current_shortstatehash).await
} else { } else {
@ -183,7 +157,7 @@ impl Data for KeyValueDatabase {
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
fn room_state_get_id( pub(super) fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
@ -195,7 +169,7 @@ impl Data for KeyValueDatabase {
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
fn room_state_get( pub(super) fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {

View file

@ -1,3 +1,8 @@
use std::sync::Mutex as StdMutex;
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -29,12 +34,25 @@ use tracing::{error, warn};
use crate::{service::pdu::PduBuilder, services, Error, PduEvent, Result}; use crate::{service::pdu::PduBuilder, services, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>, pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>, pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
} }
impl Service { impl Service {
pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
let config = &server.config;
Ok(Self {
db: Data::new(db),
server_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
user_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
})
}
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]

View file

@ -17,101 +17,53 @@ use crate::{
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
pub trait Data: Send + Sync { use std::sync::Arc;
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()>;
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; use database::KvTree;
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool>; pub struct Data {
userroomid_joined: Arc<dyn KvTree>,
/// Makes a user forget a room. roomuserid_joined: Arc<dyn KvTree>,
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; userroomid_invitestate: Arc<dyn KvTree>,
roomuserid_invitecount: Arc<dyn KvTree>,
/// Returns an iterator of all servers participating in this room. userroomid_leftstate: Arc<dyn KvTree>,
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>; roomuserid_leftcount: Arc<dyn KvTree>,
roomid_inviteviaservers: Arc<dyn KvTree>,
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool>; roomuseroncejoinedids: Arc<dyn KvTree>,
roomid_joinedcount: Arc<dyn KvTree>,
/// Returns an iterator of all rooms a server participates in (as far as we roomid_invitedcount: Arc<dyn KvTree>,
/// know). roomserverids: Arc<dyn KvTree>,
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; serverroomids: Arc<dyn KvTree>,
db: Arc<KeyValueDatabase>,
/// Returns an iterator of all joined members of a room.
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>;
/// Returns an iterator of all our local users
/// in the room, even if they're deactivated/guests
fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a>;
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a>;
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>>;
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>>;
/// Returns an iterator over all User IDs who ever joined a room.
///
/// TODO: use this?
#[allow(dead_code)]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>;
/// Returns an iterator over all invited members of a room.
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>;
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>;
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>;
/// Returns an iterator over all rooms this user joined.
fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_>;
/// Returns an iterator over all rooms a user was invited to.
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a>;
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>>;
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>>;
/// Returns an iterator over all rooms a user left.
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a>;
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
/// Gets the servers to either accept or decline invites via for a given
/// room.
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>;
/// Add the given servers the list to accept or decline invites via for a
/// given room.
///
/// TODO: use this?
#[allow(dead_code)]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
userroomid_joined: db.userroomid_joined.clone(),
roomuserid_joined: db.roomuserid_joined.clone(),
userroomid_invitestate: db.userroomid_invitestate.clone(),
roomuserid_invitecount: db.roomuserid_invitecount.clone(),
userroomid_leftstate: db.userroomid_leftstate.clone(),
roomuserid_leftcount: db.roomuserid_leftcount.clone(),
roomid_inviteviaservers: db.roomid_inviteviaservers.clone(),
roomuseroncejoinedids: db.roomuseroncejoinedids.clone(),
roomid_joinedcount: db.roomid_joinedcount.clone(),
roomid_invitedcount: db.roomid_invitedcount.clone(),
roomserverids: db.roomserverids.clone(),
serverroomids: db.serverroomids.clone(),
db: db.clone(),
}
}
pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[]) self.roomuseroncejoinedids.insert(&userroom_id, &[])
} }
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec(); let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone(); let mut roomuser_id = roomid.clone();
@ -134,7 +86,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn mark_as_invited( pub(super) fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, &self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>, invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> { ) -> Result<()> {
@ -179,7 +131,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec(); let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone(); let mut roomuser_id = roomid.clone();
@ -206,7 +158,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64; let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64; let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new(); let mut joined_servers = HashSet::new();
@ -256,7 +208,8 @@ impl Data for KeyValueDatabase {
self.serverroomids.insert(&serverroom_id, &[])?; self.serverroomids.insert(&serverroom_id, &[])?;
} }
self.appservice_in_room_cache self.db
.appservice_in_room_cache
.write() .write()
.unwrap() .unwrap()
.remove(room_id); .remove(room_id);
@ -265,8 +218,9 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self, room_id, appservice))] #[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self let maybe = self
.db
.appservice_in_room_cache .appservice_in_room_cache
.read() .read()
.unwrap() .unwrap()
@ -288,7 +242,8 @@ impl Data for KeyValueDatabase {
.room_members(room_id) .room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache self.db
.appservice_in_room_cache
.write() .write()
.unwrap() .unwrap()
.entry(room_id.to_owned()) .entry(room_id.to_owned())
@ -301,7 +256,7 @@ impl Data for KeyValueDatabase {
/// Makes a user forget a room. /// Makes a user forget a room.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -318,7 +273,9 @@ impl Data for KeyValueDatabase {
/// Returns an iterator of all servers participating in this room. /// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { pub(super) fn room_servers<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -336,7 +293,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec(); let mut key = server.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
@ -347,7 +304,9 @@ impl Data for KeyValueDatabase {
/// Returns an iterator of all rooms a server participates in (as far as we /// Returns an iterator of all rooms a server participates in (as far as we
/// know). /// know).
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { pub(super) fn server_rooms<'a>(
&'a self, server: &ServerName,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec(); let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -366,7 +325,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator of all joined members of a room. /// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { pub(super) fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -385,7 +344,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator of all our local users in the room, even if they're /// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests /// deactivated/guests
fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new( Box::new(
self.room_members(room_id) self.room_members(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -396,7 +355,9 @@ impl Data for KeyValueDatabase {
/// Returns an iterator of all our local joined users in a room who are /// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest) /// active (not deactivated, not guest)
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { pub(super) fn active_local_users_in_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new( Box::new(
self.local_users_in_room(room_id) self.local_users_in_room(room_id)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)),
@ -405,7 +366,7 @@ impl Data for KeyValueDatabase {
/// Returns the number of users which are currently in a room /// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount self.roomid_joinedcount
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
@ -414,7 +375,7 @@ impl Data for KeyValueDatabase {
/// Returns the number of users which are currently invited to a room /// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount self.roomid_invitedcount
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
@ -423,7 +384,9 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all User IDs who ever joined a room. /// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { pub(super) fn room_useroncejoined<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -446,7 +409,9 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all invited members of a room. /// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { pub(super) fn room_members_invited<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -468,7 +433,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -483,7 +448,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -496,7 +461,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all rooms this user joined. /// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> { pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> {
Box::new( Box::new(
self.userroomid_joined self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec()) .scan_prefix(user_id.as_bytes().to_vec())
@ -516,7 +481,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all rooms a user was invited to. /// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -543,7 +508,9 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { pub(super) fn invite_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
@ -560,7 +527,9 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { pub(super) fn left_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
@ -578,7 +547,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all rooms a user left. /// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -605,7 +574,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -614,7 +583,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -623,7 +592,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -632,7 +601,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -641,7 +610,9 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { pub(super) fn servers_invite_via<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let key = room_id.as_bytes().to_vec(); let key = room_id.as_bytes().to_vec();
Box::new( Box::new(
@ -665,7 +636,7 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self let mut prev_servers = self
.servers_invite_via(room_id) .servers_invite_via(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)

View file

@ -1,6 +1,8 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::Server;
use data::Data; use data::Data;
use database::KeyValueDatabase;
use itertools::Itertools; use itertools::Itertools;
use ruma::{ use ruma::{
events::{ events::{
@ -24,10 +26,16 @@ use crate::{service::appservice::RegistrationInfo, services, user_is_local, Erro
mod data; mod data;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Update current membership data. /// Update current membership data.
#[tracing::instrument(skip(self, last_state))] #[tracing::instrument(skip(self, last_state))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]

View file

@ -1,21 +1,28 @@
use std::{collections::HashSet, mem::size_of, sync::Arc}; use std::{collections::HashSet, mem::size_of, sync::Arc};
use database::KvTree;
use super::CompressedStateEvent; use super::CompressedStateEvent;
use crate::{utils, Error, KeyValueDatabase, Result}; use crate::{utils, Error, KeyValueDatabase, Result};
pub struct StateDiff { pub(super) struct StateDiff {
pub parent: Option<u64>, pub(super) parent: Option<u64>,
pub added: Arc<HashSet<CompressedStateEvent>>, pub(super) added: Arc<HashSet<CompressedStateEvent>>,
pub removed: Arc<HashSet<CompressedStateEvent>>, pub(super) removed: Arc<HashSet<CompressedStateEvent>>,
} }
pub trait Data: Send + Sync { pub struct Data {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff>; shortstatehash_statediff: Arc<dyn KvTree>,
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
shortstatehash_statediff: db.shortstatehash_statediff.clone(),
}
}
pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self let value = self
.shortstatehash_statediff .shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())? .get(&shortstatehash.to_be_bytes())?
@ -53,7 +60,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { pub(super) fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() { for new in diff.added.iter() {
value.extend_from_slice(&new[..]); value.extend_from_slice(&new[..]);

View file

@ -1,3 +1,8 @@
use std::sync::Mutex as StdMutex;
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
collections::HashSet, collections::HashSet,
@ -41,16 +46,25 @@ type ParentStatesVec = Vec<(
)>; )>;
type HashSetCompressStateEvent = Result<(u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>)>; type HashSetCompressStateEvent = Result<(u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>)>;
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub stateinfo_cache: StateInfoLruCache, pub stateinfo_cache: StateInfoLruCache,
} }
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
impl Service { impl Service {
pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
let config = &server.config;
Ok(Self {
db: Data::new(db),
stateinfo_cache: StdMutex::new(LruCache::new(
(f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
})
}
/// Returns a stack with info on shortstatehash, full state, added diff and /// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer. /// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]

View file

@ -1,22 +1,24 @@
use std::mem::size_of; use std::{mem::size_of, sync::Arc};
use database::KvTree;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>; type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
pub trait Data: Send + Sync { pub struct Data {
fn threads_until<'a>( threadid_userids: Arc<dyn KvTree>,
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads,
) -> PduEventIterResult<'a>;
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>;
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>>;
} }
impl Data for KeyValueDatabase { impl Data {
fn threads_until<'a>( pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
threadid_userids: db.threadid_userids.clone(),
}
}
pub(super) fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> { ) -> PduEventIterResult<'a> {
let prefix = services() let prefix = services()
@ -50,7 +52,7 @@ impl Data for KeyValueDatabase {
)) ))
} }
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants let users = participants
.iter() .iter()
.map(|user| user.as_bytes()) .map(|user| user.as_bytes())
@ -62,7 +64,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> { pub(super) fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? { if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some( Ok(Some(
users users

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
@ -13,10 +16,16 @@ use serde_json::json;
use crate::{services, Error, PduEvent, Result}; use crate::{services, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn threads_until<'a>( pub fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads,
) -> Result<impl Iterator<Item = Result<(u64, PduEvent)>> + 'a> { ) -> Result<impl Iterator<Item = Result<(u64, PduEvent)>> + 'a> {

View file

@ -1,76 +1,36 @@
use std::{collections::hash_map, mem::size_of, sync::Arc}; use std::{collections::hash_map, mem::size_of, sync::Arc};
use database::KvTree;
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use tracing::error; use tracing::error;
use super::PduCount; use super::PduCount;
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
pub trait Data: Send + Sync { pub struct Data {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount>; eventid_pduid: Arc<dyn KvTree>,
pduid_pdu: Arc<dyn KvTree>,
/// Returns the `count` of this pdu's id. eventid_outlierpdu: Arc<dyn KvTree>,
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>>; userroomid_notificationcount: Arc<dyn KvTree>,
userroomid_highlightcount: Arc<dyn KvTree>,
/// Returns the json of a pdu. db: Arc<KeyValueDatabase>,
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>>;
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>>;
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>>;
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>>;
/// Adds a new pdu to the timeline
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()>;
// Adds a new pdu to the backfilled timeline
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()>;
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()>;
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
#[allow(clippy::type_complexity)]
fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
/// Returns an iterator over all events in a room that happened after the
/// event with id `from` in chronological order.
#[allow(clippy::type_complexity)]
fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()>;
} }
impl Data for KeyValueDatabase { impl Data {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
eventid_pduid: db.eventid_pduid.clone(),
pduid_pdu: db.pduid_pdu.clone(),
eventid_outlierpdu: db.eventid_outlierpdu.clone(),
userroomid_notificationcount: db.userroomid_notificationcount.clone(),
userroomid_highlightcount: db.userroomid_highlightcount.clone(),
db: db.clone(),
}
}
pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self match self
.db
.lasttimelinecount_cache .lasttimelinecount_cache
.lock() .lock()
.unwrap() .unwrap()
@ -96,7 +56,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> { pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id)) .map(|pdu_id| pdu_count(&pdu_id))
@ -104,7 +64,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else( self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| { || {
self.eventid_outlierpdu self.eventid_outlierpdu
@ -117,7 +77,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
@ -131,10 +91,12 @@ impl Data for KeyValueDatabase {
} }
/// Returns the pdu's id. /// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) } pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
self.eventid_pduid.get(event_id.as_bytes())
}
/// Returns the pdu. /// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
@ -150,7 +112,7 @@ impl Data for KeyValueDatabase {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { pub(super) fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self if let Some(pdu) = self
.get_non_outlier_pdu(event_id)? .get_non_outlier_pdu(event_id)?
.map_or_else( .map_or_else(
@ -173,7 +135,7 @@ impl Data for KeyValueDatabase {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// This does __NOT__ check the outliers `Tree`. /// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
@ -182,7 +144,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
@ -190,13 +152,16 @@ impl Data for KeyValueDatabase {
}) })
} }
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> { pub(super) fn append_pdu(
&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64,
) -> Result<()> {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?; )?;
self.lasttimelinecount_cache self.db
.lasttimelinecount_cache
.lock() .lock()
.unwrap() .unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count)); .insert(pdu.room_id.clone(), PduCount::Normal(count));
@ -207,7 +172,9 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> { pub(super) fn prepend_backfill_pdu(
&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject,
) -> Result<()> {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
@ -220,7 +187,7 @@ impl Data for KeyValueDatabase {
} }
/// Removes a pdu and creates a new one with the same id. /// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() { if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert( self.pduid_pdu.insert(
pdu_id, pdu_id,
@ -236,7 +203,7 @@ impl Data for KeyValueDatabase {
/// Returns an iterator over all events and their tokens in a room that /// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological /// happened before the event with id `until` in reverse-chronological
/// order. /// order.
fn pdus_until<'a>( pub(super) fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?; let (prefix, current) = count_to_id(room_id, until, 1, true)?;
@ -260,7 +227,7 @@ impl Data for KeyValueDatabase {
)) ))
} }
fn pdus_after<'a>( pub(super) fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?; let (prefix, current) = count_to_id(room_id, from, 1, false)?;
@ -284,7 +251,7 @@ impl Data for KeyValueDatabase {
)) ))
} }
fn increment_notification_counts( pub(super) fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>, &self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> { ) -> Result<()> {
let mut notifies_batch = Vec::new(); let mut notifies_batch = Vec::new();
@ -311,7 +278,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 = let second_last_u64 =
@ -324,7 +291,9 @@ fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
} }
} }
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> { pub(super) fn count_to_id(
room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services() let prefix = services()
.rooms .rooms
.short .short

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::{ use std::{
@ -74,12 +77,19 @@ struct ExtractBody {
} }
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>, pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
})
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?

View file

@ -1,5 +1,7 @@
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
use conduit::Server;
use database::KeyValueDatabase;
use ruma::{ use ruma::{
api::federation::transactions::edu::{Edu, TypingContent}, api::federation::transactions::edu::{Edu, TypingContent},
events::SyncEphemeralRoomEvent, events::SyncEphemeralRoomEvent,
@ -23,6 +25,14 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, _db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
typing: RwLock::new(BTreeMap::new()),
last_typing_update: RwLock::new(BTreeMap::new()),
typing_update_sender: broadcast::channel(100).0,
})
}
/// Sets a user as typing until the timeout timestamp is reached or /// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called. /// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {

View file

@ -1,28 +1,30 @@
use std::sync::Arc;
use database::KvTree;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync { pub struct Data {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; userroomid_notificationcount: Arc<dyn KvTree>,
userroomid_highlightcount: Arc<dyn KvTree>,
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; roomuserid_lastnotificationread: Arc<dyn KvTree>,
roomsynctoken_shortstatehash: Arc<dyn KvTree>,
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; userroomid_joined: Arc<dyn KvTree>,
// Returns the count at which the last reset_notification_counts was called
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()>;
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>>;
fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>>;
} }
impl Data for KeyValueDatabase { impl Data {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
userroomid_notificationcount: db.userroomid_notificationcount.clone(),
userroomid_highlightcount: db.userroomid_highlightcount.clone(),
roomuserid_lastnotificationread: db.roomuserid_lastnotificationread.clone(),
roomsynctoken_shortstatehash: db.roomsynctoken_shortstatehash.clone(),
userroomid_joined: db.userroomid_joined.clone(),
}
}
pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -41,7 +43,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -53,7 +55,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -65,7 +67,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -81,7 +83,9 @@ impl Data for KeyValueDatabase {
.unwrap_or(0)) .unwrap_or(0))
} }
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { pub(super) fn associate_token_shortstatehash(
&self, room_id: &RoomId, token: u64, shortstatehash: u64,
) -> Result<()> {
let shortroomid = services() let shortroomid = services()
.rooms .rooms
.short .short
@ -95,7 +99,7 @@ impl Data for KeyValueDatabase {
.insert(&key, &shortstatehash.to_be_bytes()) .insert(&key, &shortstatehash.to_be_bytes())
} }
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services() let shortroomid = services()
.rooms .rooms
.short .short
@ -114,7 +118,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn get_shared_rooms<'a>( pub(super) fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>, &'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> { ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| { let iterators = users.into_iter().map(move |user_id| {

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id) self.db.reset_notification_counts(user_id, room_id)
} }

View file

@ -1,31 +1,31 @@
use std::sync::Arc;
use ruma::{ServerName, UserId}; use ruma::{ServerName, UserId};
use super::{Destination, SendingEvent}; use super::{Destination, SendingEvent};
use crate::{services, utils, Error, KeyValueDatabase, Result}; use crate::{services, utils, Error, KeyValueDatabase, KvTree, Result};
type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>; type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>;
type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>;
pub trait Data: Send + Sync { pub struct Data {
fn active_requests(&self) -> OutgoingSendingIter<'_>; servercurrentevent_data: Arc<dyn KvTree>,
fn active_requests_for(&self, destination: &Destination) -> SendingEventIter<'_>; servernameevent_data: Arc<dyn KvTree>,
fn delete_active_request(&self, key: Vec<u8>) -> Result<()>; servername_educount: Arc<dyn KvTree>,
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()>; _db: Arc<KeyValueDatabase>,
/// TODO: use this?
#[allow(dead_code)]
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()>;
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>>;
fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a>;
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()>;
fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>;
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64>;
} }
impl Data for KeyValueDatabase { impl Data {
fn active_requests<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a> { pub(super) fn new(db: Arc<KeyValueDatabase>) -> Self {
Self {
servercurrentevent_data: db.servercurrentevent_data.clone(),
servernameevent_data: db.servernameevent_data.clone(),
servername_educount: db.servername_educount.clone(),
_db: db,
}
}
pub fn active_requests(&self) -> OutgoingSendingIter<'_> {
Box::new( Box::new(
self.servercurrentevent_data self.servercurrentevent_data
.iter() .iter()
@ -33,9 +33,7 @@ impl Data for KeyValueDatabase {
) )
} }
fn active_requests_for<'a>( pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> {
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a> {
let prefix = destination.get_prefix(); let prefix = destination.get_prefix();
Box::new( Box::new(
self.servercurrentevent_data self.servercurrentevent_data
@ -44,9 +42,9 @@ impl Data for KeyValueDatabase {
) )
} }
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) } pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) }
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix(); let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?; self.servercurrentevent_data.remove(&key)?;
@ -55,7 +53,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix(); let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap(); self.servercurrentevent_data.remove(&key).unwrap();
@ -68,7 +66,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> { pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new(); let mut batch = Vec::new();
let mut keys = Vec::new(); let mut keys = Vec::new();
for (destination, event) in requests { for (destination, event) in requests {
@ -91,7 +89,7 @@ impl Data for KeyValueDatabase {
Ok(keys) Ok(keys)
} }
fn queued_requests<'a>( pub fn queued_requests<'a>(
&'a self, destination: &Destination, &'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> { ) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix(); let prefix = destination.get_prefix();
@ -102,7 +100,7 @@ impl Data for KeyValueDatabase {
); );
} }
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> { pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events { for (e, key) in events {
if key.is_empty() { if key.is_empty() {
continue; continue;
@ -120,12 +118,12 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes()) .insert(server_name.as_bytes(), &last_count.to_be_bytes())
} }
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { pub fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount self.servername_educount
.get(server_name.as_bytes())? .get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| { .map_or(Ok(0), |bytes| {

View file

@ -1,3 +1,5 @@
use conduit::Server;
mod appservice; mod appservice;
mod data; mod data;
pub mod resolve; pub mod resolve;
@ -15,10 +17,10 @@ use ruma::{
use tokio::{sync::Mutex, task::JoinHandle}; use tokio::{sync::Mutex, task::JoinHandle};
use tracing::{error, warn}; use tracing::{error, warn};
use crate::{server_is_ours, services, Config, Error, Result}; use crate::{server_is_ours, services, Error, KeyValueDatabase, Result};
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
/// The state for a given state hash. /// The state for a given state hash.
sender: loole::Sender<Msg>, sender: loole::Sender<Msg>,
@ -51,16 +53,17 @@ pub enum SendingEvent {
} }
impl Service { impl Service {
pub fn build(db: Arc<dyn Data>, config: &Config) -> Arc<Self> { pub fn build(server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Arc<Self>> {
let config = &server.config;
let (sender, receiver) = loole::unbounded(); let (sender, receiver) = loole::unbounded();
Arc::new(Self { Ok(Arc::new(Self {
db, db: Data::new(db.clone()),
sender, sender,
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
handler_join: Mutex::new(None), handler_join: Mutex::new(None),
startup_netburst: config.startup_netburst, startup_netburst: config.startup_netburst,
startup_netburst_keep: config.startup_netburst_keep, startup_netburst_keep: config.startup_netburst_keep,
}) }))
} }
pub async fn close(&self) { pub async fn close(&self) {

View file

@ -144,7 +144,7 @@ impl Service {
if self.startup_netburst_keep >= 0 && entry.len() >= keep { if self.startup_netburst_keep >= 0 && entry.len() >= keep {
warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key));
self.db self.db
.delete_active_request(key) .delete_active_request(&key)
.expect("active request deleted"); .expect("active request deleted");
} else { } else {
entry.push(event); entry.push(event);

View file

@ -1,12 +1,7 @@
use std::{ use std::sync::Arc;
collections::{BTreeMap, HashMap},
sync::{Arc, Mutex as StdMutex},
};
use conduit::{debug_info, Result, Server}; use conduit::{debug_info, Result, Server};
use database::KeyValueDatabase; use database::KeyValueDatabase;
use lru_cache::LruCache;
use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use crate::{ use crate::{
@ -15,9 +10,9 @@ use crate::{
}; };
pub struct Services { pub struct Services {
pub rooms: rooms::Service,
pub appservice: appservice::Service, pub appservice: appservice::Service,
pub pusher: pusher::Service, pub pusher: pusher::Service,
pub rooms: rooms::Service,
pub transaction_ids: transaction_ids::Service, pub transaction_ids: transaction_ids::Service,
pub uiaa: uiaa::Service, pub uiaa: uiaa::Service,
pub users: users::Service, pub users: users::Service,
@ -34,111 +29,41 @@ pub struct Services {
impl Services { impl Services {
pub async fn build(server: Arc<Server>, db: Arc<KeyValueDatabase>) -> Result<Self> { pub async fn build(server: Arc<Server>, db: Arc<KeyValueDatabase>) -> Result<Self> {
let config = &server.config;
Ok(Self { Ok(Self {
appservice: appservice::Service::build(db.clone())?,
pusher: pusher::Service {
db: db.clone(),
},
rooms: rooms::Service { rooms: rooms::Service {
alias: rooms::alias::Service { alias: rooms::alias::Service::build(&server, &db)?,
db: db.clone(), auth_chain: rooms::auth_chain::Service::build(&server, &db)?,
}, directory: rooms::directory::Service::build(&server, &db)?,
auth_chain: rooms::auth_chain::Service { event_handler: rooms::event_handler::Service::build(&server, &db)?,
db: db.clone(), lazy_loading: rooms::lazy_loading::Service::build(&server, &db)?,
}, metadata: rooms::metadata::Service::build(&server, &db)?,
directory: rooms::directory::Service { outlier: rooms::outlier::Service::build(&server, &db)?,
db: db.clone(), pdu_metadata: rooms::pdu_metadata::Service::build(&server, &db)?,
}, read_receipt: rooms::read_receipt::Service::build(&server, &db)?,
event_handler: rooms::event_handler::Service, search: rooms::search::Service::build(&server, &db)?,
lazy_loading: rooms::lazy_loading::Service { short: rooms::short::Service::build(&server, &db)?,
db: db.clone(), state: rooms::state::Service::build(&server, &db)?,
lazy_load_waiting: Mutex::new(HashMap::new()), state_accessor: rooms::state_accessor::Service::build(&server, &db)?,
}, state_cache: rooms::state_cache::Service::build(&server, &db)?,
metadata: rooms::metadata::Service { state_compressor: rooms::state_compressor::Service::build(&server, &db)?,
db: db.clone(), timeline: rooms::timeline::Service::build(&server, &db)?,
}, threads: rooms::threads::Service::build(&server, &db)?,
outlier: rooms::outlier::Service { typing: rooms::typing::Service::build(&server, &db)?,
db: db.clone(), spaces: rooms::spaces::Service::build(&server, &db)?,
}, user: rooms::user::Service::build(&server, &db)?,
pdu_metadata: rooms::pdu_metadata::Service {
db: db.clone(),
},
read_receipt: rooms::read_receipt::Service {
db: db.clone(),
},
search: rooms::search::Service {
db: db.clone(),
},
short: rooms::short::Service {
db: db.clone(),
},
state: rooms::state::Service {
db: db.clone(),
},
state_accessor: rooms::state_accessor::Service {
db: db.clone(),
server_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier)
as usize,
)),
user_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier)
as usize,
)),
},
state_cache: rooms::state_cache::Service {
db: db.clone(),
},
state_compressor: rooms::state_compressor::Service {
db: db.clone(),
stateinfo_cache: StdMutex::new(LruCache::new(
(f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
},
timeline: rooms::timeline::Service {
db: db.clone(),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
},
threads: rooms::threads::Service {
db: db.clone(),
},
typing: rooms::typing::Service {
typing: RwLock::new(BTreeMap::new()),
last_typing_update: RwLock::new(BTreeMap::new()),
typing_update_sender: broadcast::channel(100).0,
},
spaces: rooms::spaces::Service {
roomid_spacehierarchy_cache: Mutex::new(LruCache::new(
(f64::from(config.roomid_spacehierarchy_cache_capacity)
* config.conduit_cache_capacity_modifier) as usize,
)),
},
user: rooms::user::Service {
db: db.clone(),
},
}, },
transaction_ids: transaction_ids::Service { appservice: appservice::Service::build(&server, &db)?,
db: db.clone(), pusher: pusher::Service::build(&server, &db)?,
}, transaction_ids: transaction_ids::Service::build(&server, &db)?,
uiaa: uiaa::Service { uiaa: uiaa::Service::build(&server, &db)?,
db: db.clone(), users: users::Service::build(&server, &db)?,
}, account_data: account_data::Service::build(&server, &db)?,
users: users::Service { presence: presence::Service::build(&server, &db)?,
db: db.clone(), admin: admin::Service::build(&server, &db)?,
connections: StdMutex::new(BTreeMap::new()), key_backups: key_backups::Service::build(&server, &db)?,
}, media: media::Service::build(&server, &db)?,
account_data: account_data::Service { sending: sending::Service::build(&server, &db)?,
db: db.clone(), globals: globals::Service::build(&server, &db)?,
},
presence: presence::Service::build(db.clone(), config),
admin: admin::Service::build(),
key_backups: key_backups::Service {
db: db.clone(),
},
media: media::Service::build(&server, &db),
sending: sending::Service::build(db.clone(), config),
globals: globals::Service::load(db.clone(), config)?,
server, server,
db, db,
}) })

View file

@ -1,19 +1,21 @@
use std::sync::Arc;
use conduit::Result;
use database::{KeyValueDatabase, KvTree};
use ruma::{DeviceId, TransactionId, UserId}; use ruma::{DeviceId, TransactionId, UserId};
use crate::{KeyValueDatabase, Result}; pub struct Data {
userdevicetxnid_response: Arc<dyn KvTree>,
pub(crate) trait Data: Send + Sync {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()>;
fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>>;
} }
impl Data for KeyValueDatabase { impl Data {
fn add_txnid( pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
userdevicetxnid_response: db.userdevicetxnid_response.clone(),
}
}
pub(super) fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -27,7 +29,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn existing_txnid( pub(super) fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -8,10 +11,16 @@ use ruma::{DeviceId, TransactionId, UserId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub(super) db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
pub fn add_txnid( pub fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> { ) -> Result<()> {

View file

@ -1,29 +1,30 @@
use std::sync::Arc;
use conduit::{Error, Result};
use database::{KeyValueDatabase, KvTree};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo}, api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId, CanonicalJsonValue, DeviceId, UserId,
}; };
use crate::{Error, KeyValueDatabase, Result}; pub struct Data {
userdevicesessionid_uiaainfo: Arc<dyn KvTree>,
pub(crate) trait Data: Send + Sync { db: Arc<KeyValueDatabase>,
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()>;
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue>;
fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()>;
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo>;
} }
impl Data for KeyValueDatabase { impl Data {
fn set_uiaa_request( pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
userdevicesessionid_uiaainfo: db.userdevicesessionid_uiaainfo.clone(),
db: db.clone(),
}
}
pub(super) fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> { ) -> Result<()> {
self.userdevicesessionid_uiaarequest self.db
.userdevicesessionid_uiaarequest
.write() .write()
.unwrap() .unwrap()
.insert( .insert(
@ -34,15 +35,18 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> { pub(super) fn get_uiaa_request(
self.userdevicesessionid_uiaarequest &self, user_id: &UserId, device_id: &DeviceId, session: &str,
) -> Option<CanonicalJsonValue> {
self.db
.userdevicesessionid_uiaarequest
.read() .read()
.unwrap() .unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
} }
fn update_uiaa_session( pub(super) fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> { ) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec(); let mut userdevicesessionid = user_id.as_bytes().to_vec();
@ -64,7 +68,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> { pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec(); let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF); userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes()); userdevicesessionid.extend_from_slice(device_id.as_bytes());

View file

@ -1,3 +1,6 @@
use conduit::Server;
use database::KeyValueDatabase;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
@ -18,10 +21,16 @@ use crate::services;
pub const SESSION_ID_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32;
pub struct Service { pub struct Service {
pub(super) db: Arc<dyn Data>, pub db: Data,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db),
})
}
/// Creates a new Uiaa session. Make sure the session token is unique. /// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create( pub fn create(
&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue,

View file

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, mem::size_of}; use std::{collections::BTreeMap, mem::size_of, sync::Arc};
use ruma::{ use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
@ -10,149 +10,60 @@ use ruma::{
}; };
use tracing::warn; use tracing::warn;
use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result}; use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, KvTree, Result};
pub trait Data: Send + Sync { pub struct Data {
/// Check if a user has an account on this homeserver. userid_password: Arc<dyn KvTree>,
fn exists(&self, user_id: &UserId) -> Result<bool>; token_userdeviceid: Arc<dyn KvTree>,
userid_displayname: Arc<dyn KvTree>,
/// Check if account is deactivated userid_avatarurl: Arc<dyn KvTree>,
fn is_deactivated(&self, user_id: &UserId) -> Result<bool>; userid_blurhash: Arc<dyn KvTree>,
userid_devicelistversion: Arc<dyn KvTree>,
/// Returns the number of users registered on this server. userdeviceid_token: Arc<dyn KvTree>,
fn count(&self) -> Result<usize>; userdeviceid_metadata: Arc<dyn KvTree>,
onetimekeyid_onetimekeys: Arc<dyn KvTree>,
/// Find out which user an access token belongs to. userid_lastonetimekeyupdate: Arc<dyn KvTree>,
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>>; keyid_key: Arc<dyn KvTree>,
userid_masterkeyid: Arc<dyn KvTree>,
/// Returns an iterator over all users on this homeserver. userid_selfsigningkeyid: Arc<dyn KvTree>,
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; userid_usersigningkeyid: Arc<dyn KvTree>,
keychangeid_userid: Arc<dyn KvTree>,
/// Returns a list of local users as list of usernames. todeviceid_events: Arc<dyn KvTree>,
/// userfilterid_filter: Arc<dyn KvTree>,
/// A user account is considered `local` if the length of it's password is _db: Arc<KeyValueDatabase>,
/// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>>;
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>>;
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>;
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>>;
/// Sets a new displayname or removes it if displayname is None. You still
/// need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()>;
/// Get the avatar_url of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>>;
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()>;
/// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>>;
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()>;
/// Adds a new device to a user.
fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()>;
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
/// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a>;
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>;
fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()>;
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64>;
fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>>;
fn count_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId)
-> Result<BTreeMap<DeviceKeyAlgorithm, UInt>>;
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()>;
fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()>;
fn sign_key(&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId)
-> Result<()>;
fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>;
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>;
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>>;
fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)>;
fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>>;
fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>>;
fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>>;
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>>;
fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) -> Result<()>;
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>>;
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()>;
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()>;
/// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>>;
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>>;
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a>;
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String>;
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>>;
} }
impl Data for KeyValueDatabase { impl Data {
pub(super) fn new(db: Arc<KeyValueDatabase>) -> Self {
Self {
userid_password: db.userid_password.clone(),
token_userdeviceid: db.token_userdeviceid.clone(),
userid_displayname: db.userid_displayname.clone(),
userid_avatarurl: db.userid_avatarurl.clone(),
userid_blurhash: db.userid_blurhash.clone(),
userid_devicelistversion: db.userid_devicelistversion.clone(),
userdeviceid_token: db.userdeviceid_token.clone(),
userdeviceid_metadata: db.userdeviceid_metadata.clone(),
onetimekeyid_onetimekeys: db.onetimekeyid_onetimekeys.clone(),
userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(),
keyid_key: db.keyid_key.clone(),
userid_masterkeyid: db.userid_masterkeyid.clone(),
userid_selfsigningkeyid: db.userid_selfsigningkeyid.clone(),
userid_usersigningkeyid: db.userid_usersigningkeyid.clone(),
keychangeid_userid: db.keychangeid_userid.clone(),
todeviceid_events: db.todeviceid_events.clone(),
userfilterid_filter: db.userfilterid_filter.clone(),
_db: db,
}
}
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } pub(super) fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
}
/// Check if account is deactivated /// Check if account is deactivated
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self Ok(self
.userid_password .userid_password
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
@ -161,10 +72,10 @@ impl Data for KeyValueDatabase {
} }
/// Returns the number of users registered on this server. /// Returns the number of users registered on this server.
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) } pub(super) fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { pub(super) fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid self.token_userdeviceid
.get(token.as_bytes())? .get(token.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -189,7 +100,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns an iterator over all users on this homeserver. /// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| { Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse( UserId::parse(
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)
@ -203,7 +114,7 @@ impl Data for KeyValueDatabase {
/// ///
/// A user account is considered `local` if the length of it's password is /// A user account is considered `local` if the length of it's password is
/// greater then zero. /// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>> { pub(super) fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self let users: Vec<String> = self
.userid_password .userid_password
.iter() .iter()
@ -213,7 +124,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the password hash for the given user. /// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { pub(super) fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password self.userid_password
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -224,7 +135,7 @@ impl Data for KeyValueDatabase {
} }
/// Hash and set the user's password to the Argon2 hash /// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password { if let Some(password) = password {
if let Ok(hash) = utils::hash::password(password) { if let Ok(hash) = utils::hash::password(password) {
self.userid_password self.userid_password
@ -243,7 +154,7 @@ impl Data for KeyValueDatabase {
} }
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { pub(super) fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -256,7 +167,7 @@ impl Data for KeyValueDatabase {
/// Sets a new displayname or removes it if displayname is None. You still /// Sets a new displayname or removes it if displayname is None. You still
/// need to nofify all rooms of this change. /// need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname { if let Some(displayname) = displayname {
self.userid_displayname self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?; .insert(user_id.as_bytes(), displayname.as_bytes())?;
@ -268,7 +179,7 @@ impl Data for KeyValueDatabase {
} }
/// Get the `avatar_url` of a user. /// Get the `avatar_url` of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> { pub(super) fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> {
self.userid_avatarurl self.userid_avatarurl
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -283,7 +194,7 @@ impl Data for KeyValueDatabase {
} }
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url { if let Some(avatar_url) = avatar_url {
self.userid_avatarurl self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
@ -295,7 +206,7 @@ impl Data for KeyValueDatabase {
} }
/// Get the blurhash of a user. /// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { pub(super) fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash self.userid_blurhash
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -307,8 +218,8 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
/// Sets a new blurhash or removes it if blurhash is None. /// Sets a new avatar_url or removes it if avatar_url is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash { if let Some(blurhash) = blurhash {
self.userid_blurhash self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?; .insert(user_id.as_bytes(), blurhash.as_bytes())?;
@ -320,7 +231,7 @@ impl Data for KeyValueDatabase {
} }
/// Adds a new device to a user. /// Adds a new device to a user.
fn create_device( pub(super) fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>, &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> { ) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert // This method should never be called for nonexistent users. We shouldn't assert
@ -354,7 +265,7 @@ impl Data for KeyValueDatabase {
} }
/// Removes a device from a user. /// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -384,7 +295,9 @@ impl Data for KeyValueDatabase {
} }
/// Returns an iterator over all device ids of this user. /// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> { pub(super) fn all_device_ids<'a>(
&'a self, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
// All devices have metadata // All devices have metadata
@ -405,7 +318,7 @@ impl Data for KeyValueDatabase {
} }
/// Replaces the access token of one device. /// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -436,7 +349,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn add_one_time_key( pub(super) fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>, one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> { ) -> Result<()> {
@ -478,7 +391,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| { .map_or(Ok(0), |bytes| {
@ -487,7 +400,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn take_one_time_key( pub(super) fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
@ -521,7 +434,7 @@ impl Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn count_one_time_keys( pub(super) fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId, &self, user_id: &UserId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> { ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
@ -551,7 +464,9 @@ impl Data for KeyValueDatabase {
Ok(counts) Ok(counts)
} }
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> { pub(super) fn add_device_keys(
&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>,
) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -566,7 +481,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn add_cross_signing_keys( pub(super) fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>, &self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool, user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> { ) -> Result<()> {
@ -574,7 +489,7 @@ impl Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?;
self.keyid_key self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?; .insert(&master_key_key, master_key.json().get().as_bytes())?;
@ -647,7 +562,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn sign_key( pub(super) fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> { ) -> Result<()> {
let mut key = target_id.as_bytes().to_vec(); let mut key = target_id.as_bytes().to_vec();
@ -685,7 +600,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn keys_changed<'a>( pub(super) fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>, &'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec(); let mut prefix = user_or_room_id.as_bytes().to_vec();
@ -724,7 +639,7 @@ impl Data for KeyValueDatabase {
) )
} }
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
for room_id in services() for room_id in services()
.rooms .rooms
@ -757,7 +672,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> { pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
@ -769,8 +684,8 @@ impl Data for KeyValueDatabase {
}) })
} }
fn parse_master_key( pub(super) fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> { ) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@ -793,7 +708,7 @@ impl Data for KeyValueDatabase {
Ok((master_key_key, master_key)) Ok((master_key_key, master_key))
} }
fn get_key( pub(super) fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
@ -807,7 +722,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn get_master_key( pub(super) fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_masterkeyid self.userid_masterkeyid
@ -815,7 +730,7 @@ impl Data for KeyValueDatabase {
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
} }
fn get_self_signing_key( pub(super) fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_selfsigningkeyid self.userid_selfsigningkeyid
@ -823,7 +738,7 @@ impl Data for KeyValueDatabase {
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
} }
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid self.userid_usersigningkeyid
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |key| { .map_or(Ok(None), |key| {
@ -836,7 +751,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn add_to_device_event( pub(super) fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value, content: serde_json::Value,
) -> Result<()> { ) -> Result<()> {
@ -858,7 +773,9 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> { pub(super) fn get_to_device_events(
&self, user_id: &UserId, device_id: &DeviceId,
) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
let mut events = Vec::new(); let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
@ -876,7 +793,7 @@ impl Data for KeyValueDatabase {
Ok(events) Ok(events)
} }
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
@ -905,7 +822,7 @@ impl Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -935,7 +852,7 @@ impl Data for KeyValueDatabase {
} }
/// Get device metadata. /// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> { pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -949,7 +866,7 @@ impl Data for KeyValueDatabase {
}) })
} }
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion self.userid_devicelistversion
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -959,7 +876,9 @@ impl Data for KeyValueDatabase {
}) })
} }
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> { pub(super) fn all_devices_metadata<'a>(
&'a self, user_id: &UserId,
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
@ -974,7 +893,7 @@ impl Data for KeyValueDatabase {
} }
/// Creates a new sync filter. Returns the filter id. /// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> { pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
let filter_id = utils::random_string(4); let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
@ -987,7 +906,7 @@ impl Data for KeyValueDatabase {
Ok(filter_id) Ok(filter_id)
} }
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> { pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes()); key.extend_from_slice(filter_id.as_bytes());
@ -1006,7 +925,7 @@ impl Data for KeyValueDatabase {
/// username could be successfully parsed. /// username could be successfully parsed.
/// If `utils::string_from_bytes`(...) returns an error that username will be /// If `utils::string_from_bytes`(...) returns an error that username will be
/// skipped and the error will be logged. /// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> { pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty // A valid password is not empty
if password.is_empty() { if password.is_empty() {
None None

View file

@ -1,8 +1,10 @@
use conduit::Server;
mod data; mod data;
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
mem, mem,
sync::{Arc, Mutex}, sync::{Arc, Mutex, Mutex as StdMutex},
}; };
use data::Data; use data::Data;
@ -22,7 +24,7 @@ use ruma::{
UInt, UserId, UInt, UserId,
}; };
use crate::{service, services, Error, Result}; use crate::{database::KeyValueDatabase, service, services, Error, Result};
pub struct SlidingSyncCache { pub struct SlidingSyncCache {
lists: BTreeMap<String, SyncRequestList>, lists: BTreeMap<String, SyncRequestList>,
@ -34,11 +36,18 @@ pub struct SlidingSyncCache {
type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>; type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>;
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Data,
pub connections: DbConnections, pub connections: DbConnections,
} }
impl Service { impl Service {
pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
Ok(Self {
db: Data::new(db.clone()),
connections: StdMutex::new(BTreeMap::new()),
})
}
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) } pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) }
@ -381,7 +390,7 @@ impl Service {
pub fn parse_master_key( pub fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, &self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> { ) -> Result<(Vec<u8>, CrossSigningKey)> {
self.db.parse_master_key(user_id, master_key) Data::parse_master_key(user_id, master_key)
} }
pub fn get_key( pub fn get_key(