dissolve key_value/*

Signed-off-by: Jason Volk <jason@zemos.net>
Jason Volk 2024-05-26 21:29:19 +00:00 committed by June 🍓🦴
parent 3122648767
commit b94045a468
88 changed files with 4556 additions and 4751 deletions
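
In short: each `impl Data for KeyValueDatabase` moves out of the standalone `key_value/*` modules and into the `data` module of the service that declares the trait, visibility is tightened along the way (`pub use data::Data` becomes `use data::Data`, `pub mod data` becomes `mod data`, `pub db: Arc<dyn Data>` becomes `pub(super) db`), and callers go through service wrappers such as `services().globals.cork_and_flush()`. A minimal, self-contained sketch of the resulting shape; the counter field and method bodies are illustrative, not the project's actual code:

use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};

// The service declares its storage trait next to the code that uses it.
pub(crate) trait Data: Send + Sync {
    fn next_count(&self) -> u64;
}

// The concrete backend implements the trait in the same module, instead of
// in a separate key_value/* file.
struct KeyValueDatabase {
    counter: AtomicU64,
}

impl Data for KeyValueDatabase {
    fn next_count(&self) -> u64 {
        self.counter.fetch_add(1, Ordering::SeqCst) + 1
    }
}

// The service keeps the trait object private and exposes wrapper methods,
// so callers never reach through `.db` directly.
pub struct Service {
    db: Arc<dyn Data>,
}

impl Service {
    pub fn next_count(&self) -> u64 { self.db.next_count() }
}

fn main() {
    let globals = Service {
        db: Arc::new(KeyValueDatabase {
            counter: AtomicU64::new(0),
        }),
    };
    // Callers use the wrapper rather than the backend type.
    assert_eq!(globals.next_count(), 1);
}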


@@ -142,7 +142,7 @@ pub(crate) async fn sync_events_route(
.collect::<Vec<_>>();
// Coalesce database writes for the remainder of this scope.
let _cork = services().globals.db.cork_and_flush()?;
let _cork = services().globals.cork_and_flush()?;
for room_id in all_joined_rooms {
let room_id = room_id?;
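
The `_cork` binding in the hunk above is held for the rest of the scope; per the comment, database writes are coalesced while the guard is alive, and the `cork_and_flush` name suggests the batch is flushed when the guard is released. A rough RAII sketch of that idea only; `FakeEngine` and its methods are stand-ins, not the real `Cork` or database engine:

// Stand-in for the engine handle the real Cork borrows.
struct FakeEngine;

impl FakeEngine {
    fn cork(&self) { /* start coalescing writes */ }
    fn uncork(&self) { /* resume normal write behavior */ }
    fn flush(&self) { /* persist the coalesced batch */ }
}

struct CorkSketch<'a> {
    engine: &'a FakeEngine,
    flush_on_drop: bool,
}

impl<'a> CorkSketch<'a> {
    fn new(engine: &'a FakeEngine, flush_on_drop: bool) -> Self {
        engine.cork();
        Self { engine, flush_on_drop }
    }
}

impl Drop for CorkSketch<'_> {
    fn drop(&mut self) {
        // Releasing the guard ends the batching scope.
        self.engine.uncork();
        if self.flush_on_drop {
            self.engine.flush();
        }
    }
}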


@@ -1,12 +1,14 @@
use std::collections::HashMap;
use ruma::{
api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use tracing::warn;
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the
@@ -26,3 +28,129 @@ pub trait Data: Send + Sync {
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
}
impl Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes());
if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Account data doesn't have all required fields.",
));
}
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?;
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?;
}
Ok(())
}
/// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))]
fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose()
}
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))]
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
for r in self
.roomuserdataid_accountdata
.iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
RoomAccountDataEventType::from(
utils::string_from_bytes(
k.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
)
.map_err(|e| {
warn!("RoomUserData ID in database is invalid: {}", e);
Error::bad_database("RoomUserData ID in db is invalid.")
})?,
),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
))
}) {
let (kind, data) = r?;
userdata.insert(kind, data);
}
Ok(userdata)
}
}
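
The `update` implementation above writes two 0xFF-delimited keys: `roomuserdataid`, which embeds the global count so `changes_since` can scan forward from a given count, and the count-free `roomusertype` key, which always points at the latest entry (the previous one is removed). A standalone sketch of that layout; the helper below is hypothetical, not part of the diff:

/// Builds the two keys the same way `update` does above:
///   roomuserdataid = [room_id] 0xFF user_id 0xFF count(u64 BE) 0xFF event_type
///   roomusertype   = [room_id] 0xFF user_id 0xFF event_type
fn account_data_keys(room_id: Option<&str>, user_id: &str, count: u64, event_type: &str) -> (Vec<u8>, Vec<u8>) {
    let mut prefix = room_id.unwrap_or_default().as_bytes().to_vec();
    prefix.push(0xFF);
    prefix.extend_from_slice(user_id.as_bytes());
    prefix.push(0xFF);

    let mut roomuserdataid = prefix.clone();
    roomuserdataid.extend_from_slice(&count.to_be_bytes());
    roomuserdataid.push(0xFF);
    roomuserdataid.extend_from_slice(event_type.as_bytes());

    let mut roomusertype = prefix;
    roomusertype.extend_from_slice(event_type.as_bytes());

    (roomuserdataid, roomusertype)
}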


@@ -1,6 +1,6 @@
use ruma::api::appservice::Registration;
use crate::Result;
use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Registers an appservice and returns the ID to the caller
@@ -19,3 +19,55 @@ pub trait Data: Send + Sync {
fn all(&self) -> Result<Vec<(String, Registration)>>;
}
impl Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str();
self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
Ok(id.to_owned())
}
/// Remove an appservice registration
///
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
Ok(())
}
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.id_appserviceregistrations
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes)
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
.transpose()
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
})))
}
fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()?
.filter_map(Result::ok)
.map(move |id| {
Ok((
id.clone(),
self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
))
})
.collect()
}
}


@@ -1,13 +1,19 @@
use std::{collections::BTreeMap, error::Error};
use std::collections::{BTreeMap, HashMap};
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, OwnedServerSigningKeyId, ServerName, UserId,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use tracing::trace;
use crate::{database::Cork, Result};
use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result};
const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait]
pub trait Data: Send + Sync {
@@ -41,7 +47,290 @@ pub trait Data: Send + Sync {
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
fn database_version(&self) -> Result<u64>;
fn bump_database_version(&self, new_version: u64) -> Result<()>;
fn backup(&self) -> Result<(), Box<dyn Error>> { unimplemented!() }
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { unimplemented!() }
fn backup_list(&self) -> Result<String> { Ok(String::new()) }
fn file_list(&self) -> Result<String> { Ok(String::new()) }
}
#[async_trait]
impl Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
}
fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
}
fn last_check_for_updates_id(&self) -> Result<u64> {
self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
}
#[allow(unused_qualifications)] // async traits
#[tracing::instrument(skip(self))]
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new();
// Return when *any* user changed their key
// TODO: only send for user they share a room with
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services()
.rooms
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something
trace!(futures = futures.len(), "watch started");
futures.next().await;
trace!(futures = futures.len(), "watch finished");
Ok(())
}
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
fn flush(&self) -> Result<()> { self.db.flush() }
fn cork(&self) -> Result<Cork> { Ok(Cork::new(&self.db, false, false)) }
fn cork_and_flush(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, false)) }
fn cork_and_sync(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, true)) }
fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
format!(
"\
auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}",
self.db.memory_usage().unwrap_or_default()
)
}
fn clear_caches(&self, amount: u32) {
if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity());
}
if amount > 2 {
let c = &mut *self.our_real_users_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 3 {
let c = &mut *self.appservice_in_room_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 4 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new();
}
}
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
Ok,
)?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes(
// 1. version
parts
.next()
.expect("splitn always returns at least one element"),
)
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| {
// 2. key
parts
.next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.map(|key| (version, key))
})
.and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
})
}
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
let ServerSigningKeys {
verify_keys,
old_verify_keys,
..
} = new_keys;
keys.verify_keys.extend(verify_keys);
keys.old_verify_keys.extend(old_verify_keys);
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
});
Ok(signingkeys)
}
fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() }
}


@@ -9,7 +9,7 @@ use std::{
use argon2::Argon2;
use base64::{engine::general_purpose, Engine as _};
pub use data::Data;
use data::Data;
use hickory_resolver::TokioAsyncResolver;
use ipaddress::IPAddress;
use regex::RegexSet;
@@ -29,10 +29,10 @@ use tokio::{
use tracing::{error, trace};
use url::Url;
use crate::{services, Config, Result};
use crate::{database::Cork, services, Config, Result};
mod client;
pub mod data;
mod data;
pub(crate) mod emerg_access;
pub(crate) mod migrations;
mod resolver;
@@ -200,6 +200,10 @@ impl Service {
#[allow(dead_code)]
pub fn flush(&self) -> Result<()> { self.db.flush() }
pub fn cork(&self) -> Result<Cork> { self.db.cork() }
pub fn cork_and_flush(&self) -> Result<Cork> { self.db.cork_and_flush() }
pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
pub fn max_request_size(&self) -> u32 { self.config.max_request_size }


@@ -1,14 +1,17 @@
use std::collections::BTreeMap;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind,
},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
pub(crate) trait Data: Send + Sync {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
@@ -45,3 +48,308 @@ pub trait Data: Send + Sync {
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version)
}
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, _)| {
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
.transpose()
}
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
Ok((
version,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
))
})
.transpose()
}
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
})
}
fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(())
}
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
}
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
&self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
)
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string())
}
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse(
utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((room_id, session_id, key_data))
}) {
let (room_id, session_id, key_data) = result?;
rooms
.entry(room_id)
.or_insert_with(|| RoomKeyBackup {
sessions: BTreeMap::new(),
})
.sessions
.insert(session_id, key_data);
}
Ok(rooms)
}
fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
.collect())
}
fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
})
.transpose()
}
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
}
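
The backup methods above all derive their keys from one nested, 0xFF-delimited scheme, which is why deleting a whole backup, a room's keys, or a single session reduces to a prefix scan over `backupkeyid_backup`. A sketch of that layout; the helper is hypothetical, not part of the diff:

/// Key layout used by the backup tables above:
///   backupid_algorithm / backupid_etag : user_id 0xFF version
///   backupkeyid_backup                 : user_id 0xFF version 0xFF room_id 0xFF session_id
/// Appending 0xFF to a shorter form yields the scan prefix for everything under it.
fn backup_key(user_id: &str, version: &str, room_id: Option<&str>, session_id: Option<&str>) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(version.as_bytes());
    if let Some(room_id) = room_id {
        key.push(0xFF);
        key.extend_from_slice(room_id.as_bytes());
    }
    if let Some(session_id) = session_id {
        key.push(0xFF);
        key.extend_from_slice(session_id.as_bytes());
    }
    key
}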


@@ -1,7 +1,7 @@
mod data;
use std::{collections::BTreeMap, sync::Arc};
pub use data::Data;
use data::Data;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
@@ -11,7 +11,7 @@ use ruma::{
use crate::Result;
pub struct Service {
pub db: Arc<dyn Data>,
pub(super) db: Arc<dyn Data>,
}
impl Service {


@@ -1,137 +0,0 @@
use std::collections::HashMap;
use ruma::{
api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes());
if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Account data doesn't have all required fields.",
));
}
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?;
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?;
}
Ok(())
}
/// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))]
fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose()
}
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))]
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
for r in self
.roomuserdataid_accountdata
.iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
RoomAccountDataEventType::from(
utils::string_from_bytes(
k.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
)
.map_err(|e| {
warn!("RoomUserData ID in database is invalid: {}", e);
Error::bad_database("RoomUserData ID in db is invalid.")
})?,
),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
))
}) {
let (kind, data) = r?;
userdata.insert(kind, data);
}
Ok(userdata)
}
}


@@ -1,55 +0,0 @@
use ruma::api::appservice::Registration;
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str();
self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
Ok(id.to_owned())
}
/// Remove an appservice registration
///
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
Ok(())
}
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.id_appserviceregistrations
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes)
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
.transpose()
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
})))
}
fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()?
.filter_map(Result::ok)
.map(move |id| {
Ok((
id.clone(),
self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
))
})
.collect()
}
}


@@ -1,299 +0,0 @@
use std::collections::{BTreeMap, HashMap};
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use tracing::trace;
use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result};
const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait]
impl crate::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
}
fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
}
fn last_check_for_updates_id(&self) -> Result<u64> {
self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
}
#[allow(unused_qualifications)] // async traits
#[tracing::instrument(skip(self))]
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new();
// Return when *any* user changed their key
// TODO: only send for user they share a room with
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services()
.rooms
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something
trace!(futures = futures.len(), "watch started");
futures.next().await;
trace!(futures = futures.len(), "watch finished");
Ok(())
}
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
fn flush(&self) -> Result<()> { self.db.flush() }
fn cork(&self) -> Result<Cork> { Ok(Cork::new(&self.db, false, false)) }
fn cork_and_flush(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, false)) }
fn cork_and_sync(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, true)) }
fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
format!(
"\
auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}",
self.db.memory_usage().unwrap_or_default()
)
}
fn clear_caches(&self, amount: u32) {
if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity());
}
if amount > 2 {
let c = &mut *self.our_real_users_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 3 {
let c = &mut *self.appservice_in_room_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 4 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new();
}
}
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
Ok,
)?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes(
// 1. version
parts
.next()
.expect("splitn always returns at least one element"),
)
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| {
// 2. key
parts
.next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.map(|key| (version, key))
})
.and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
})
}
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
let ServerSigningKeys {
verify_keys,
old_verify_keys,
..
} = new_keys;
keys.verify_keys.extend(verify_keys);
keys.old_verify_keys.extend(old_verify_keys);
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
});
Ok(signingkeys)
}
fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() }
}


@@ -1,317 +0,0 @@
use std::collections::BTreeMap;
use ruma::{
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind,
},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::key_backups::Data for KeyValueDatabase {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version)
}
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, _)| {
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
.transpose()
}
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
Ok((
version,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
))
})
.transpose()
}
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
})
}
fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(())
}
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
}
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
&self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
)
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string())
}
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse(
utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((room_id, session_id, key_data))
}) {
let (room_id, session_id, key_data) = result?;
rooms
.entry(room_id)
.or_insert_with(|| RoomKeyBackup {
sessions: BTreeMap::new(),
})
.sessions
.insert(session_id, key_data);
}
Ok(rooms)
}
fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
.collect())
}
fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
})
.transpose()
}
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
}


@@ -1,235 +0,0 @@
use conduit::debug_info;
use ruma::api::client::error::ErrorKind;
use tracing::debug;
use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result};
impl crate::media::Data for KeyValueDatabase {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
if let Some(user) = sender_user {
let key = mxc.as_bytes().to_vec();
let user = user.as_bytes().to_vec();
self.mediaid_user.insert(&key, &user)?;
}
Ok(key)
}
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
debug!("MXC db prefix: {prefix:?}");
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(&key)?;
}
for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) {
if key == mxc.as_bytes().to_vec() {
let user = string_from_bytes(&value).unwrap_or_default();
debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}");
self.mediaid_user.remove(&key)?;
}
}
Ok(())
}
/// Searches for all files with the given MXC
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
let keys: Vec<Vec<u8>> = self
.mediaid_file
.scan_prefix(prefix)
.map(|(key, _)| key)
.collect();
if keys.is_empty() {
return Err(Error::bad_database(
"Failed to find any keys in database with the provided MXC.",
));
}
debug!("Got the following keys: {:?}", keys);
Ok(keys)
}
fn search_file_metadata(
&self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xFF);
let (key, _) = self
.mediaid_file
.scan_prefix(prefix)
.next()
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
})
.transpose()?;
let content_disposition_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
)
};
Ok((content_disposition, content_type, key))
}
/// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc)
fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() }
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF);
value.extend_from_slice(
data.title
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.description
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.image
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
self.url_previews.insert(url.as_bytes(), &value)
}
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??;
let mut values = values.split(|&b| b == 0xFF);
let _ts = values.next();
/* if we ever decide to use timestamp, this is here.
match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
Some(0) => None,
x => x,
};*/
let title = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
Some(UrlPreviewData {
title,
description,
image,
image_size,
image_width,
image_height,
})
}
}


@@ -1,14 +0,0 @@
mod account_data;
//mod admin;
mod appservice;
mod globals;
mod key_backups;
mod media;
//mod pdu;
mod presence;
mod pusher;
mod rooms;
mod sending;
mod transaction_ids;
mod uiaa;
mod users;


@@ -1,127 +0,0 @@
use conduit::debug_info;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{
presence::Presence,
services,
utils::{self, user_id_from_bytes},
Error, KeyValueDatabase, Result,
};
impl crate::presence::Data for KeyValueDatabase {
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?))
})
.transpose()
} else {
Ok(None)
}
}
fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id)?;
let state_changed = match last_presence {
None => true,
Some(ref presence) => presence.1.content.presence != *presence_state,
};
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
None => 0,
Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
None => now,
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
};
// tighten for state flicker?
if !state_changed && last_active_ts <= last_last_active_ts {
debug_info!(
"presence spam {:?} last_active_ts:{:?} <= {:?}",
user_id,
last_active_ts,
last_last_active_ts
);
return Ok(());
}
let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) {
None
} else {
status_msg
};
let presence = Presence::new(
presence_state.to_owned(),
currently_active.unwrap_or(false),
last_active_ts,
status_msg,
);
let count = services().globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.insert(&key, &presence.to_json_bytes()?)?;
self.userid_presenceid
.insert(user_id.as_bytes(), &count.to_be_bytes())?;
if let Some((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id);
self.presenceid_presence.remove(&key)?;
}
Ok(())
}
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key)?;
self.userid_presenceid.remove(user_id.as_bytes())?;
}
Ok(())
}
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new(
self.presenceid_presence
.iter()
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> {
let (count, user_id) = presenceid_parse(&key)?;
Ok((user_id, count, presence_bytes))
})
.filter(move |(_, count, _)| *count > since),
)
}
}
#[inline]
fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> {
[count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat()
}
#[inline]
fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> {
let (count, user_id) = key.split_at(8);
let user_id = user_id_from_bytes(user_id)?;
let count = utils::u64_from_bytes(count).unwrap();
Ok((count, user_id))
}
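
A hedged, std-only sketch of the key layout behind presenceid_key/presenceid_parse above: an 8-byte big-endian counter followed by the raw user-ID bytes, so keys sort by the global counter. A plain String stands in for the user ID here:

    fn presenceid_key(count: u64, user_id: &str) -> Vec<u8> {
        [count.to_be_bytes().as_slice(), user_id.as_bytes()].concat()
    }

    fn presenceid_parse(key: &[u8]) -> Option<(u64, String)> {
        // The first 8 bytes are the big-endian counter; the rest is the user ID.
        if key.len() < 8 {
            return None;
        }
        let (count, user_id) = key.split_at(8);
        Some((
            u64::from_be_bytes(count.try_into().ok()?),
            String::from_utf8(user_id.to_vec()).ok()?,
        ))
    }

    fn main() {
        let key = presenceid_key(42, "@alice:example.org");
        assert_eq!(
            presenceid_parse(&key),
            Some((42, "@alice:example.org".to_owned()))
        );
    }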

View file

@ -1,65 +0,0 @@
use ruma::{
api::client::push::{set_pusher, Pusher},
UserId,
};
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
Ok(())
},
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into)
},
}
}
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher
.get(&senderkey)?
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.transpose()
}
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
self.senderkey_pusher
.scan_prefix(prefix)
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.collect()
}
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next();
let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
Ok(push_key_string)
}))
}
}
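
The composite keys in this table rely on 0xFF never appearing in valid UTF-8, which makes it a collision-free separator between the sender's user ID and the push key. A small stand-alone sketch of that sender‖0xFF‖pushkey scheme (function names are illustrative):

    fn senderkey(sender: &str, pushkey: &str) -> Vec<u8> {
        let mut key = sender.as_bytes().to_vec();
        key.push(0xFF);
        key.extend_from_slice(pushkey.as_bytes());
        key
    }

    fn pushkey_from_senderkey(key: &[u8]) -> Option<String> {
        // Split on the first 0xFF; everything after it is the push key.
        let mut parts = key.splitn(2, |&b| b == 0xFF);
        let _sender = parts.next()?;
        String::from_utf8(parts.next()?.to_vec()).ok()
    }

    fn main() {
        let key = senderkey("@bob:example.org", "abc123");
        assert_eq!(pushkey_from_senderkey(&key), Some("abc123".to_owned()));
    }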

View file

@ -1,75 +0,0 @@
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::alias::Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id;
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
}
self.alias_roomid.remove(alias.alias().as_bytes())?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
}
Ok(())
}
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
})
.transpose()
}
fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
}))
}
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart))
}),
)
}
}

View file

@ -1,59 +0,0 @@
use std::{mem::size_of, sync::Arc};
use crate::{utils, KeyValueDatabase, Result};
impl crate::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
}
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain));
}
}
Ok(None)
}
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
}
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(())
}
}
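
The database side of this cache flattens the Arc<[u64]> of short event IDs into one contiguous value of big-endian bytes and rebuilds it with chunks_exact. A std-only sketch of that round trip:

    use std::{mem::size_of, sync::Arc};

    fn pack(chain: &[u64]) -> Vec<u8> {
        chain.iter().flat_map(|s| s.to_be_bytes()).collect()
    }

    fn unpack(bytes: &[u8]) -> Arc<[u64]> {
        bytes
            .chunks_exact(size_of::<u64>())
            .map(|chunk| u64::from_be_bytes(chunk.try_into().expect("chunk is 8 bytes")))
            .collect()
    }

    fn main() {
        let chain: Vec<u64> = vec![3, 5, 8];
        let packed = pack(&chain);
        assert_eq!(packed.len(), chain.len() * size_of::<u64>());
        assert_eq!(&*unpack(&packed), chain.as_slice());
    }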

View file

@ -1,23 +0,0 @@
use ruma::{OwnedRoomId, RoomId};
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}

View file

@ -1,53 +0,0 @@
use ruma::{DeviceId, RoomId, UserId};
use crate::{KeyValueDatabase, Result};
impl crate::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}
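
lazy_load_reset clears everything under a user‖0xFF‖device‖0xFF‖room‖0xFF prefix by scanning and deleting key by key. A rough analogue of that prefix-delete pattern, with a BTreeMap standing in for the ordered tree (an assumption; the real trees are database-backed):

    use std::collections::BTreeMap;

    fn delete_prefix(map: &mut BTreeMap<Vec<u8>, Vec<u8>>, prefix: &[u8]) {
        // Collect matching keys first, then remove them.
        let doomed: Vec<Vec<u8>> = map
            .range(prefix.to_vec()..)
            .take_while(|(k, _)| k.starts_with(prefix))
            .map(|(k, _)| k.clone())
            .collect();
        for key in doomed {
            map.remove(&key);
        }
    }

    fn main() {
        let mut map = BTreeMap::new();
        map.insert(b"user\xFFdev\xFFroom\xFFalice".to_vec(), vec![]);
        map.insert(b"user\xFFdev\xFFroom\xFFbob".to_vec(), vec![]);
        map.insert(b"user\xFFdev\xFFother".to_vec(), vec![]);
        delete_prefix(&mut map, b"user\xFFdev\xFFroom\xFF");
        assert_eq!(map.len(), 1);
    }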

View file

@ -1,76 +0,0 @@
use ruma::{OwnedRoomId, RoomId};
use tracing::error;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}

View file

@ -1,21 +0,0 @@
mod alias;
mod auth_chain;
mod directory;
mod lazy_load;
mod metadata;
mod outlier;
mod pdu_metadata;
mod read_receipt;
mod search;
mod short;
mod state;
mod state_accessor;
mod state_cache;
mod state_compressor;
mod threads;
mod timeline;
mod user;
use crate::KeyValueDatabase;
impl crate::rooms::Data for KeyValueDatabase {}

View file

@ -1,28 +0,0 @@
use ruma::{CanonicalJsonObject, EventId};
use crate::{Error, KeyValueDatabase, PduEvent, Result};
impl crate::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}

View file

@ -1,79 +0,0 @@
use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
impl crate::rooms::pdu_metadata::Data for KeyValueDatabase {
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
}
fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX.saturating_sub(x).saturating_sub(1)
},
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
}
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
}
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
}
}
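
Relations are keyed as the 8-byte big-endian target count followed by the 8-byte source count, so every relation pointing at the same target shares a prefix and the source can be read back from the tail. A small sketch of that key round trip (helper names are illustrative):

    use std::mem::size_of;

    fn relation_key(to: u64, from: u64) -> Vec<u8> {
        let mut key = to.to_be_bytes().to_vec();
        key.extend_from_slice(&from.to_be_bytes());
        key
    }

    fn from_of(key: &[u8]) -> Option<u64> {
        // The `from` count lives in the second 8 bytes of the key.
        key.get(size_of::<u64>()..)
            .and_then(|tail| tail.try_into().ok())
            .map(u64::from_be_bytes)
    }

    fn main() {
        let key = relation_key(7, 99);
        assert!(key.starts_with(&7u64.to_be_bytes()));
        assert_eq!(from_of(&key), Some(99));
    }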

View file

@ -1,120 +0,0 @@
use std::mem;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
)?;
Ok(())
}
fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
}
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
}
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
}
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
}
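
readreceipt_update locates the newest receipt for a room by seeking to room‖0xFF‖u64::MAX and walking backwards while keys still carry the room prefix. A rough std-only analogue, with a BTreeMap standing in for the ordered tree (an assumption):

    use std::collections::BTreeMap;

    fn newest_in_room<'a>(
        map: &'a BTreeMap<Vec<u8>, Vec<u8>>,
        room_prefix: &[u8],
    ) -> Option<(&'a Vec<u8>, &'a Vec<u8>)> {
        let mut last_possible_key = room_prefix.to_vec();
        last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
        map.range(..=last_possible_key)
            .rev()
            .take_while(|(k, _)| k.starts_with(room_prefix))
            .next()
    }

    fn main() {
        let key = |count: u64| {
            let mut k = b"!room:example.org\xFF".to_vec();
            k.extend_from_slice(&count.to_be_bytes());
            k
        };
        let mut map = BTreeMap::new();
        map.insert(key(1), b"old".to_vec());
        map.insert(key(9), b"new".to_vec());
        let (_, newest) = newest_in_room(&map, b"!room:example.org\xFF").unwrap();
        assert_eq!(newest.as_slice(), b"new");
    }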

View file

@ -1,64 +0,0 @@
use ruma::RoomId;
use crate::{services, utils, KeyValueDatabase, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
impl crate::rooms::search::Data for KeyValueDatabase {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new())
});
self.tokenids.insert_batch(&mut batch)
}
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.map(str::to_lowercase)
.collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
}
}
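
Indexing and querying only line up if both sides tokenize identically: split on any non-alphanumeric character, drop empty tokens and (when indexing) words longer than 50 bytes, and lowercase the rest. A sketch of that tokenizer as a stand-alone function (the name is illustrative):

    fn tokenize(body: &str) -> Vec<String> {
        body.split_terminator(|c: char| !c.is_alphanumeric())
            .filter(|s| !s.is_empty())
            .filter(|word| word.len() <= 50)
            .map(str::to_lowercase)
            .collect()
    }

    fn main() {
        assert_eq!(
            tokenize("Hello, Matrix world!"),
            vec!["hello", "matrix", "world"]
        );
    }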

View file

@ -1,164 +0,0 @@
use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
Ok(short)
}
fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.iter()
.enumerate()
{
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
),
None => {
let short = services().globals.next_count()?;
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
},
}
}
Ok(ret)
}
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
Ok(short)
}
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
}
}
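
Every get_or_create_short* above follows the same interning pattern: look the value up in the forward map, otherwise take the next global counter value and record both directions. A std-only sketch with HashMaps standing in for the two trees (an assumption):

    use std::collections::HashMap;

    #[derive(Default)]
    struct Interner {
        next: u64,
        forward: HashMap<String, u64>,
        backward: HashMap<u64, String>,
    }

    impl Interner {
        fn get_or_create(&mut self, value: &str) -> u64 {
            if let Some(&short) = self.forward.get(value) {
                return short;
            }
            // Allocate the next counter value and record both directions.
            self.next += 1;
            let short = self.next;
            self.forward.insert(value.to_owned(), short);
            self.backward.insert(short, value.to_owned());
            short
        }
    }

    fn main() {
        let mut interner = Interner::default();
        let a = interner.get_or_create("$event_a");
        let b = interner.get_or_create("$event_b");
        assert_ne!(a, b);
        assert_eq!(interner.get_or_create("$event_a"), a);
        assert_eq!(interner.backward[&a], "$event_a");
    }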

View file

@ -1,73 +0,0 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(
&self,
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -1,165 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
#[async_trait]
impl crate::rooms::state_accessor::Data for KeyValueDatabase {
#[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
#[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let (_, eventid) = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
}
/// Returns the full room state.
#[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
}
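
state_full_ids and state_full yield back to the scheduler every 100 state events so loading a large room state does not monopolize a worker. A minimal sketch of that cooperative-yield pattern, assuming the tokio runtime this crate already uses:

    #[tokio::main]
    async fn main() {
        let shortstatekeys: Vec<u64> = (0..1_000).collect();
        let mut processed = 0usize;
        let mut i: u8 = 0;
        for _key in &shortstatekeys {
            processed += 1; // stand-in for parsing one compressed state event
            i = i.wrapping_add(1);
            if i % 100 == 0 {
                // Let other tasks run between batches of work.
                tokio::task::yield_now().await;
            }
        }
        assert_eq!(processed, shortstatekeys.len());
    }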

View file

@ -1,626 +0,0 @@
use std::{collections::HashSet, sync::Arc};
use itertools::Itertools;
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{
appservice::RegistrationInfo,
services, user_is_local,
utils::{self},
Error, KeyValueDatabase, Result,
};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
impl crate::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
}
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
let mut real_users = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) {
real_users.insert(joined);
}
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
self.our_real_users_cache
.write()
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
if let Some(users) = maybe {
Ok(users)
} else {
self.update_joined_count(room_id)?;
Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
))
}
}
#[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self))]
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))]
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator over all joined members of a room.
#[tracing::instrument(skip(self))]
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self))]
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self))]
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid invitedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self))]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self))]
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_leftstate is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_leftstate is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
self.roomid_inviteviaservers
.get(room_id.as_bytes())?
.map(|servers| {
servers
.split(|&b| b == 0xFF)
.map(|bytes| {
ServerName::parse(
utils::string_from_bytes(bytes)
.map_err(|e| {
error!("Invalid servername bytes in roomid_inviteviaservers: {e}");
Error::bad_database("Invalid servername bytes in roomid_inviteviaservers.")
})?,
)
.map_err(|e| {
error!("Invalid servername in roomid_inviteviaservers: {e}");
Error::bad_database("Invalid servername in roomid_inviteviaservers.")
})
})
.collect()
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
}
}
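
Both mark_as_invited and add_servers_invite_via merge new servers into the stored list, deduplicate while keeping each name's last occurrence in order (rev().unique().rev() with itertools), and join the names with 0xFF. A std-only sketch of that merge step (the helper name is illustrative):

    use std::collections::HashSet;

    fn merge_via_servers(mut prev: Vec<String>, new: Vec<String>) -> Vec<u8> {
        prev.extend(new);
        // Deduplicate, keeping the order of each name's last occurrence.
        let mut seen = HashSet::new();
        let mut deduped: Vec<String> = prev
            .into_iter()
            .rev()
            .filter(|server| seen.insert(server.clone()))
            .collect();
        deduped.reverse();
        deduped
            .iter()
            .map(|server| server.as_bytes())
            .collect::<Vec<_>>()
            .join(&[0xFF][..])
    }

    fn main() {
        let merged = merge_via_servers(
            vec!["a.org".into(), "b.org".into()],
            vec!["a.org".into(), "c.org".into()],
        );
        assert_eq!(merged.split(|&b| b == 0xFF).count(), 3);
    }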

View file

@ -1,60 +0,0 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{rooms::state_compressor::data::StateDiff, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
} else {
None
};
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
}
}
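
A state diff value is laid out as the parent shortstatehash (0 meaning no parent), the added compressed events, then a zero u64 sentinel followed by the removed ones; each compressed event is a fixed 16-byte entry. A hedged sketch of that encode/decode round trip, using plain 16-byte arrays for the compressed events (type and function names are illustrative):

    use std::{collections::HashSet, mem::size_of};

    type CompressedEvent = [u8; 16]; // shortstatekey ++ shorteventid, both big-endian

    fn encode(parent: Option<u64>, added: &[CompressedEvent], removed: &[CompressedEvent]) -> Vec<u8> {
        let mut value = parent.unwrap_or(0).to_be_bytes().to_vec();
        for new in added {
            value.extend_from_slice(new);
        }
        if !removed.is_empty() {
            // A zero u64 acts as the separator; real shortstatekeys are never 0.
            value.extend_from_slice(&0_u64.to_be_bytes());
            for gone in removed {
                value.extend_from_slice(gone);
            }
        }
        value
    }

    fn decode(value: &[u8]) -> (Option<u64>, HashSet<CompressedEvent>, HashSet<CompressedEvent>) {
        let parent = u64::from_be_bytes(value[..8].try_into().unwrap());
        let parent = (parent != 0).then_some(parent);
        let (mut added, mut removed) = (HashSet::new(), HashSet::new());
        let mut add_mode = true;
        let mut i = size_of::<u64>();
        while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
            if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
                add_mode = false;
                i += size_of::<u64>();
                continue;
            }
            let entry: CompressedEvent = v.try_into().unwrap();
            if add_mode { added.insert(entry) } else { removed.insert(entry) };
            i += 2 * size_of::<u64>();
        }
        (parent, added, removed)
    }

    fn main() {
        let mut add = [0u8; 16];
        add[..8].copy_from_slice(&1u64.to_be_bytes());
        add[8..].copy_from_slice(&7u64.to_be_bytes());
        let mut del = [0u8; 16];
        del[..8].copy_from_slice(&2u64.to_be_bytes());
        del[8..].copy_from_slice(&9u64.to_be_bytes());
        let bytes = encode(Some(5), &[add], &[del]);
        let (parent, added, removed) = decode(&bytes);
        assert_eq!(parent, Some(5));
        assert!(added.contains(&add) && removed.contains(&del));
    }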

View file

@ -1,75 +0,0 @@
use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
impl crate::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(until - 1).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
}
}

View file

@ -1,295 +0,0 @@
use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use tracing::error;
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
impl crate::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.unwrap()
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
},
hash_map::Entry::Occupied(o) => Ok(*o.get()),
}
}
/// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
}
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
/// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.lasttimelinecount_cache
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
}
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
}
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
}
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
self.userroomid_notificationcount
.increment_batch(&mut notifies_batch.into_iter())?;
self.userroomid_highlightcount
.increment_batch(&mut highlights_batch.into_iter())?;
Ok(())
}
}
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
} else {
Ok(PduCount::Normal(last_u64))
}
}
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
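The two helpers above fix the pdu_id layout: an 8-byte big-endian shortroomid, then for backfilled events a zero marker, then an 8-byte big-endian count (stored as u64::MAX minus the backfill index so older events sort first). A minimal, self-contained sketch of that round trip; PduCount here is a stand-in, not the crate's type:

// Illustrative only: mirrors the encoding assumed by pdu_count()/count_to_id().
#[derive(Debug, PartialEq)]
enum PduCount {
    Normal(u64),
    Backfilled(u64),
}

fn encode_pdu_id(shortroomid: u64, count: &PduCount) -> Vec<u8> {
    let mut pdu_id = shortroomid.to_be_bytes().to_vec();
    match count {
        PduCount::Normal(n) => pdu_id.extend_from_slice(&n.to_be_bytes()),
        PduCount::Backfilled(n) => {
            // Zero marker, then the inverted count, so backfilled entries sort
            // before normal ones and older backfills sort before newer ones.
            pdu_id.extend_from_slice(&0_u64.to_be_bytes());
            pdu_id.extend_from_slice(&(u64::MAX - n).to_be_bytes());
        },
    }
    pdu_id
}

fn decode_count(pdu_id: &[u8]) -> PduCount {
    let len = pdu_id.len();
    let last = u64::from_be_bytes(pdu_id[len - 8..].try_into().unwrap());
    let second_last = u64::from_be_bytes(pdu_id[len - 16..len - 8].try_into().unwrap());
    if second_last == 0 {
        PduCount::Backfilled(u64::MAX - last)
    } else {
        PduCount::Normal(last)
    }
}

fn main() {
    let id = encode_pdu_id(42, &PduCount::Normal(1234));
    assert_eq!(decode_count(&id), PduCount::Normal(1234));
    let id = encode_pdu_id(42, &PduCount::Backfilled(7));
    assert_eq!(decode_count(&id), PduCount::Backfilled(7));
}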

View file

@@ -1,137 +0,0 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
}
fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
}
}
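get_shared_rooms above relies on userroomid_joined keys having the shape user_id ++ 0xFF ++ room_id and recovers the room id as everything after the first 0xFF. A standalone sketch of that split, with illustrative values:

// Illustrative: recover the room-id suffix from a `user_id ++ 0xFF ++ room_id` key.
fn room_id_suffix(key: &[u8]) -> Option<&[u8]> {
    let sep = key.iter().position(|&b| b == 0xFF)?;
    Some(&key[sep + 1..])
}

fn main() {
    let mut key = b"@alice:example.org".to_vec();
    key.push(0xFF);
    key.extend_from_slice(b"!room:example.org");
    assert_eq!(room_id_suffix(&key), Some(&b"!room:example.org"[..]));
}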

View file

@@ -1,191 +0,0 @@
use ruma::{ServerName, UserId};
use crate::{
sending::{Destination, SendingEvent},
services, utils, Error, KeyValueDatabase, Result,
};
impl crate::sending::Data for KeyValueDatabase {
fn active_requests<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
}
fn active_requests_for<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a> {
let prefix = destination.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
}
Ok(())
}
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap();
}
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap();
}
Ok(())
}
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
}
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events {
if key.is_empty() {
continue;
}
let value = if let SendingEvent::Edu(value) = &e {
&**value
} else {
&[]
};
self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?;
}
Ok(())
}
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
}
}
#[tracing::instrument(skip(key))]
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Appservice(server),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
(
Destination::Push(user_id, pushkey_string),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value)
},
)
} else {
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Normal(
ServerName::parse(server)
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
})
}
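parse_servercurrentevent above selects the destination purely by key prefix: a leading '+' marks an appservice key, '$' marks a push key, and anything else is a federation server name, with the remaining fields 0xFF-separated and an empty value meaning the event is a PDU id. A simplified, self-contained sketch of that convention; Dest is a stand-in for the crate's Destination, not the real type:

// Illustrative stand-in types; only the key-prefix convention is the point here.
#[derive(Debug, PartialEq)]
enum Dest<'a> {
    Appservice(&'a [u8]),
    Push { user: &'a [u8], pushkey: &'a [u8] },
    Normal(&'a [u8]),
}

fn parse_key(key: &[u8]) -> Option<(Dest<'_>, &[u8])> {
    if let Some(rest) = key.strip_prefix(b"+") {
        // Appservice: `+` ++ appservice id ++ 0xFF ++ event
        let mut parts = rest.splitn(2, |&b| b == 0xFF);
        Some((Dest::Appservice(parts.next()?), parts.next()?))
    } else if let Some(rest) = key.strip_prefix(b"$") {
        // Push: `$` ++ user id ++ 0xFF ++ pushkey ++ 0xFF ++ event
        let mut parts = rest.splitn(3, |&b| b == 0xFF);
        let user = parts.next()?;
        let pushkey = parts.next()?;
        Some((Dest::Push { user, pushkey }, parts.next()?))
    } else {
        // Federation: server name ++ 0xFF ++ event
        let mut parts = key.splitn(2, |&b| b == 0xFF);
        Some((Dest::Normal(parts.next()?), parts.next()?))
    }
}

fn main() {
    let mut key = b"$@alice:example.org".to_vec();
    key.push(0xFF);
    key.extend_from_slice(b"pushkey1");
    key.push(0xFF);
    key.extend_from_slice(b"$event_id");
    let (dest, event) = parse_key(&key).unwrap();
    assert!(matches!(dest, Dest::Push { .. }));
    assert_eq!(event, &b"$event_id"[..]);
}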

View file

@@ -1,32 +0,0 @@
use ruma::{DeviceId, TransactionId, UserId};
use crate::{KeyValueDatabase, Result};
impl crate::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?;
Ok(())
}
fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction
self.userdevicetxnid_response.get(&key)
}
}
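add_txnid/existing_txnid above deduplicate client requests by keying the cached response on user ++ 0xFF ++ device ++ 0xFF ++ txn_id. A rough in-memory analogue of that lookup, with a HashMap standing in for the userdevicetxnid_response tree:

use std::collections::HashMap;

// Illustrative: same key shape as userdevicetxnid_response, kept in memory.
fn txn_key(user: &str, device: Option<&str>, txn_id: &str) -> Vec<u8> {
    let mut key = user.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(device.unwrap_or_default().as_bytes());
    key.push(0xFF);
    key.extend_from_slice(txn_id.as_bytes());
    key
}

fn main() {
    let mut seen: HashMap<Vec<u8>, Vec<u8>> = HashMap::new();
    let key = txn_key("@alice:example.org", Some("DEVICEID"), "txn-1");
    assert!(seen.get(&key).is_none()); // first time: handle the request
    seen.insert(key.clone(), b"cached response".to_vec());
    assert!(seen.get(&key).is_some()); // replay: return the cached response
}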

View file

@@ -1,68 +0,0 @@
use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId,
};
use crate::{Error, KeyValueDatabase, Result};
impl crate::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(ToOwned::to_owned)
}
fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
}
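The UIAA impl above persists a serialized UiaaInfo under a user ++ 0xFF ++ device ++ 0xFF ++ session key. A standalone sketch of the same store-and-reload pattern, assuming serde/serde_json and using a stand-in struct rather than ruma's UiaaInfo:

use serde::{Deserialize, Serialize};

// Stand-in for UiaaInfo; only the key shape and the JSON round trip are the point.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct SessionInfo {
    session: String,
    completed: Vec<String>,
}

fn session_key(user: &str, device: &str, session: &str) -> Vec<u8> {
    let mut key = user.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(device.as_bytes());
    key.push(0xFF);
    key.extend_from_slice(session.as_bytes());
    key
}

fn main() -> serde_json::Result<()> {
    let info = SessionInfo {
        session: "abc".to_owned(),
        completed: vec!["m.login.password".to_owned()],
    };
    let key = session_key("@alice:example.org", "DEVICEID", "abc");
    let stored = serde_json::to_vec(&info)?; // what would go into userdevicesessionid_uiaainfo
    let loaded: SessionInfo = serde_json::from_slice(&stored)?;
    assert_eq!(loaded, info);
    assert_eq!(key.iter().filter(|&&b| b == 0xFF).count(), 2);
    Ok(())
}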

View file

@@ -1,898 +0,0 @@
use std::{collections::BTreeMap, mem::size_of};
use argon2::{password_hash::SaltString, PasswordHasher};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result};
impl crate::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
/// Check if account is deactivated
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self
.userid_password
.get(user_id.as_bytes())?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
.is_empty())
}
/// Returns the number of users registered on this server.
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
/// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
let device_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
Ok(Some((
UserId::parse(
utils::string_from_bytes(user_bytes)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
utils::string_from_bytes(device_bytes)
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
)))
})
}
/// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
}))
}
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of its password is
/// greater than zero.
fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
.iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
.collect();
Ok(users)
}
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
}
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.userid_password.insert(user_id.as_bytes(), b"")?;
Ok(())
}
}
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
))
})
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to notify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
} else {
self.userid_displayname.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the `avatar_url` of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> {
self.userid_avatarurl
.get(user_id.as_bytes())?
.map(|bytes| {
let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| {
warn!("Avatar URL in db is invalid: {}", e);
Error::bad_database("Avatar URL in db is invalid.")
})?;
let mxc_uri: OwnedMxcUri = s_bytes.into();
Ok(mxc_uri)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else {
self.userid_avatarurl.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash
.get(user_id.as_bytes())?
.map(|bytes| {
let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
Ok(s)
})
.transpose()
}
/// Sets a new blurhash or removes it if blurhash is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
} else {
self.userid_blurhash.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Adds a new device to a user.
fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id)? {
warn!("Called create_device for non-existent user {} in database", user_id);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
}
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(&Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: None, // TODO
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
})
.expect("Device::to_string never fails."),
)?;
self.set_token(user_id, device_id, token)?;
Ok(())
}
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.userdeviceid_token.remove(&userdeviceid)?;
self.token_userdeviceid.remove(&old_token)?;
}
// Remove todevice events
let mut prefix = userdeviceid.clone();
prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?;
}
// TODO: Remove onetimekeys
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?;
Ok(())
}
/// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
// All devices have metadata
Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into())
}),
)
}
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// should not be None, but we shouldn't assert either lol...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.token_userdeviceid.remove(&old_token)?;
// It will be removed from userdeviceid_token by the insert later
}
// Assign token to user device combination
self.userdeviceid_token
.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(())
}
fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
// All devices have metadata
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&key)?.is_none() {
warn!(
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.extend_from_slice(
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
);
self.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
)?;
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
})
}
fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.push(b'"'); // Annoying quotation mark
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
prefix.push(b':');
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys
.scan_prefix(prefix)
.next()
.map(|(key, value)| {
self.onetimekeyid_onetimekeys.remove(&key)?;
Ok((
serde_json::from_slice(
key.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
))
})
.transpose()
}
fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new();
for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
)
}) {
*counts.entry(algorithm?).or_default() += uint!(1);
}
Ok(counts)
}
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?;
self.mark_device_key_update(user_id)?;
Ok(())
}
fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained more than one key.",
));
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
self.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key)?;
}
// User-signing key
if let Some(user_signing_key) = user_signing_key {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained more than one key.",
));
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
self.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key)?;
}
if notify {
self.mark_device_key_update(user_id)?;
}
Ok(())
}
fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> {
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
&self
.keyid_key
.get(&key)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))?
.as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))?
.insert(signature.0, signature.1.into());
self.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
)?;
self.mark_device_key_update(target_id)?;
Ok(())
}
fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut start = prefix.clone();
start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes());
let to = to.unwrap_or(u64::MAX);
Box::new(
self.keychangeid_userid
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
if let Ok(c) = utils::u64_from_bytes(current) {
c <= to
} else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
false
}
} else {
warn!("BadDatabase: Could not parse keychangeid_userid");
false
}
})
.map(|(_, bytes)| {
UserId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
}),
)
}
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms
if services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{
continue;
}
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
Ok(())
}
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
})
}
fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Master key contained more than one key.",
));
}
let mut master_key_key = prefix.clone();
master_key_key.extend_from_slice(master_key_id.as_bytes());
Ok((master_key_key, master_key))
}
fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
Ok(Some(Raw::from_json(
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
)))
})
}
fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_masterkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_selfsigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
))
})
})
}
fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?;
Ok(())
}
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push(
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
}
Ok(events)
}
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
for (key, _) in self
.todeviceid_events
.iter_from(&last, true) // this includes last
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(key, _)| {
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
))
})
.filter_map(Result::ok)
.take_while(|&(_, count)| count <= until)
{
self.todeviceid_events.remove(&key)?;
}
Ok(())
}
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
)?;
Ok(())
}
/// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
}
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map(Some)
})
}
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
}),
)
}
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
Ok(filter_id)
}
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
if let Some(raw) = raw {
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
} else {
Ok(None)
}
}
}
/// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed.
/// If `utils::string_from_bytes(...)` returns an error, that username will be
/// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
None
} else {
match utils::string_from_bytes(username) {
Ok(u) => Some(u),
Err(e) => {
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
None
},
}
}
}
/// Calculate a new hash for the given password
fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(rand::thread_rng());
services()
.globals
.argon
.hash_password(password.as_bytes(), &salt)
.map(|it| it.to_string())
}

View file

@@ -1,6 +1,10 @@
use crate::Result;
use conduit::debug_info;
use ruma::api::client::error::ErrorKind;
use tracing::debug;
pub trait Data: Send + Sync {
use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result};
pub(crate) trait Data: Send + Sync {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,
@@ -21,7 +25,237 @@ pub trait Data: Send + Sync {
#[allow(dead_code)]
fn remove_url_preview(&self, url: &str) -> Result<()>;
fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData>;
}
impl Data for KeyValueDatabase {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
if let Some(user) = sender_user {
let key = mxc.as_bytes().to_vec();
let user = user.as_bytes().to_vec();
self.mediaid_user.insert(&key, &user)?;
}
Ok(key)
}
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
debug!("MXC db prefix: {prefix:?}");
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(&key)?;
}
for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) {
if key == mxc.as_bytes().to_vec() {
let user = string_from_bytes(&value).unwrap_or_default();
debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}");
self.mediaid_user.remove(&key)?;
}
}
Ok(())
}
/// Searches for all files with the given MXC
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
let keys: Vec<Vec<u8>> = self
.mediaid_file
.scan_prefix(prefix)
.map(|(key, _)| key)
.collect();
if keys.is_empty() {
return Err(Error::bad_database(
"Failed to find any keys in database with the provided MXC.",
));
}
debug!("Got the following keys: {:?}", keys);
Ok(keys)
}
fn search_file_metadata(
&self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xFF);
let (key, _) = self
.mediaid_file
.scan_prefix(prefix)
.next()
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
})
.transpose()?;
let content_disposition_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
)
};
Ok((content_disposition, content_type, key))
}
/// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc)
fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() }
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF);
value.extend_from_slice(
data.title
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.description
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.image
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
self.url_previews.insert(url.as_bytes(), &value)
}
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??;
let mut values = values.split(|&b| b == 0xFF);
let _ts = values.next();
/* if we ever decide to use timestamp, this is here.
match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
Some(0) => None,
x => x,
};*/
let title = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
Some(UrlPreviewData {
title,
description,
image,
image_size,
image_width,
image_height,
})
}
}
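set_url_preview/get_url_preview above pack the preview fields into a single 0xFF-delimited value and split it back apart on read. A small round trip of the string fields only; the numeric fields are omitted here because their big-endian bytes can themselves contain 0xFF, which a plain split cannot tell apart from a separator:

// Illustrative: empty string fields are stored as empty segments and read back as None.
fn encode(title: &str, description: &str, image: &str) -> Vec<u8> {
    let mut value = Vec::new();
    for (i, field) in [title, description, image].iter().enumerate() {
        if i > 0 {
            value.push(0xFF);
        }
        value.extend_from_slice(field.as_bytes());
    }
    value
}

fn decode(value: &[u8]) -> Vec<Option<String>> {
    value
        .split(|&b| b == 0xFF)
        .map(|part| match String::from_utf8(part.to_vec()).ok() {
            Some(s) if s.is_empty() => None,
            other => other,
        })
        .collect()
}

fn main() {
    let value = encode("Example title", "", "mxc://example.org/abc");
    assert_eq!(
        decode(&value),
        vec![
            Some("Example title".to_owned()),
            None,
            Some("mxc://example.org/abc".to_owned()),
        ]
    );
}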

View file

@@ -1,7 +1,7 @@
mod data;
use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime};
pub use data::Data;
use data::Data;
use image::imageops::FilterType;
use ruma::{OwnedMxcUri, OwnedUserId};
use serde::Serialize;
@@ -39,7 +39,7 @@ pub struct UrlPreviewData {
}
pub struct Service {
pub db: Arc<dyn Data>,
pub(super) db: Arc<dyn Data>,
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
}

View file

@@ -1,4 +1,3 @@
pub(crate) mod key_value;
pub mod pdu;
pub mod services;

View file

@@ -1,6 +1,12 @@
use conduit::debug_info;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::Result;
use crate::{
presence::Presence,
services,
utils::{self, user_id_from_bytes},
Error, KeyValueDatabase, Result,
};
pub trait Data: Send + Sync {
/// Returns the latest presence event for the given user.
@@ -19,3 +25,121 @@
/// with id `since`.
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a>;
}
impl Data for KeyValueDatabase {
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?))
})
.transpose()
} else {
Ok(None)
}
}
fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id)?;
let state_changed = match last_presence {
None => true,
Some(ref presence) => presence.1.content.presence != *presence_state,
};
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
None => 0,
Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
None => now,
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
};
// tighten for state flicker?
if !state_changed && last_active_ts <= last_last_active_ts {
debug_info!(
"presence spam {:?} last_active_ts:{:?} <= {:?}",
user_id,
last_active_ts,
last_last_active_ts
);
return Ok(());
}
let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) {
None
} else {
status_msg
};
let presence = Presence::new(
presence_state.to_owned(),
currently_active.unwrap_or(false),
last_active_ts,
status_msg,
);
let count = services().globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.insert(&key, &presence.to_json_bytes()?)?;
self.userid_presenceid
.insert(user_id.as_bytes(), &count.to_be_bytes())?;
if let Some((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id);
self.presenceid_presence.remove(&key)?;
}
Ok(())
}
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key)?;
self.userid_presenceid.remove(user_id.as_bytes())?;
}
Ok(())
}
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new(
self.presenceid_presence
.iter()
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> {
let (count, user_id) = presenceid_parse(&key)?;
Ok((user_id, count, presence_bytes))
})
.filter(move |(_, count, _)| *count > since),
)
}
}
#[inline]
fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> {
[count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat()
}
#[inline]
fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> {
let (count, user_id) = key.split_at(8);
let user_id = user_id_from_bytes(user_id)?;
let count = utils::u64_from_bytes(count)
.map_err(|_e| Error::bad_database("Invalid 'count' bytes in presence key"))?;
Ok((count, user_id))
}
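presenceid_key/presenceid_parse above fix the presenceid_presence key layout as 8 big-endian count bytes followed by the raw user-id bytes. A standalone round trip of that layout with illustrative values:

// Illustrative: count (8 bytes, big-endian) ++ user id bytes.
fn presence_key(count: u64, user_id: &str) -> Vec<u8> {
    [count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat()
}

fn parse_presence_key(key: &[u8]) -> (u64, &str) {
    let (count, user_id) = key.split_at(8);
    (
        u64::from_be_bytes(count.try_into().expect("key starts with 8 count bytes")),
        std::str::from_utf8(user_id).expect("user id bytes are valid UTF-8"),
    )
}

fn main() {
    let key = presence_key(17, "@alice:example.org");
    assert_eq!(parse_presence_key(&key), (17, "@alice:example.org"));
}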

View file

@@ -2,7 +2,7 @@ mod data;
use std::{sync::Arc, time::Duration};
pub use data::Data;
use data::Data;
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
events::presence::{PresenceEvent, PresenceEventContent},

View file

@@ -3,9 +3,9 @@ use ruma::{
UserId,
};
use crate::Result;
use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
pub(crate) trait Data: Send + Sync {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
@@ -14,3 +14,62 @@ pub trait Data: Send + Sync {
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
}
impl Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
Ok(())
},
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into)
},
}
}
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher
.get(&senderkey)?
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.transpose()
}
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
self.senderkey_pusher
.scan_prefix(prefix)
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.collect()
}
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next();
let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
Ok(push_key_string)
}))
}
}

View file

@@ -2,7 +2,7 @@ mod data;
use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut;
pub use data::Data;
use data::Data;
use ipaddress::IPAddress;
use ruma::{
api::{
@@ -25,7 +25,7 @@ use tracing::{info, trace, warn};
use crate::{debug_info, services, Error, PduEvent, Result};
pub struct Service {
pub db: Arc<dyn Data>,
pub(super) db: Arc<dyn Data>,
}
impl Service {

View file

@@ -1,6 +1,6 @@
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Creates or updates the alias to the given room id.
@ -20,3 +20,75 @@ pub trait Data: Send + Sync {
/// Returns all local aliases on the server
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
}
impl Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id;
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
}
self.alias_roomid.remove(alias.alias().as_bytes())?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
}
Ok(())
}
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
})
.transpose()
}
fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
}))
}
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart))
}),
)
}
}
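
// Illustrative sketch, not part of the commit itself: set_alias() above maintains
// two trees -- a forward map `alias -> room_id` and a reverse index keyed by
// `room_id ++ 0xFF ++ counter` so all aliases of a room come back from one prefix
// scan. A minimal in-memory stand-in using BTreeMap:
use std::collections::BTreeMap;

fn main() {
    let mut alias_roomid: BTreeMap<Vec<u8>, Vec<u8>> = BTreeMap::new();
    let mut aliasid_alias: BTreeMap<Vec<u8>, Vec<u8>> = BTreeMap::new();
    let mut counter: u64 = 0; // stands in for services().globals.next_count()

    let (alias, room) = ("#rust:example.org", "!abc:example.org");
    alias_roomid.insert(alias.as_bytes().to_vec(), room.as_bytes().to_vec());

    counter += 1;
    let mut aliasid = room.as_bytes().to_vec();
    aliasid.push(0xFF);
    aliasid.extend_from_slice(&counter.to_be_bytes());
    aliasid_alias.insert(aliasid, alias.as_bytes().to_vec());

    // Equivalent of scan_prefix(room ++ 0xFF): range over keys with that prefix.
    let mut prefix = room.as_bytes().to_vec();
    prefix.push(0xFF);
    let found: Vec<String> = aliasid_alias
        .range(prefix.clone()..)
        .take_while(|(k, _)| k.starts_with(&prefix))
        .map(|(_, v)| String::from_utf8_lossy(v).into_owned())
        .collect();
    assert_eq!(found, vec!["#rust:example.org".to_owned()]);
}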


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;


@ -1,8 +1,64 @@
use std::sync::Arc;
use std::{mem::size_of, sync::Arc};
use crate::Result;
use crate::{utils, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<[u64]>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
}
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain));
}
}
Ok(None)
}
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
}
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(())
}
}
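
// Illustrative sketch, not part of the commit itself: the shorteventid_authchain
// value is just the chain's u64 short ids concatenated as big-endian bytes; the
// helpers below mirror the flat_map / chunks_exact conversions used above.
use std::sync::Arc;

fn encode_chain(chain: &[u64]) -> Vec<u8> {
    chain.iter().flat_map(|s| s.to_be_bytes()).collect()
}

fn decode_chain(bytes: &[u8]) -> Arc<[u64]> {
    bytes
        .chunks_exact(std::mem::size_of::<u64>())
        .map(|chunk| u64::from_be_bytes(chunk.try_into().expect("chunk is 8 bytes")))
        .collect()
}

fn main() {
    let chain = [3_u64, 7, 42];
    assert_eq!(decode_chain(&encode_chain(&chain)).as_ref(), &chain);
}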


@ -4,7 +4,7 @@ use std::{
sync::Arc,
};
pub use data::Data;
use data::Data;
use ruma::{api::client::error::ErrorKind, EventId, RoomId};
use tracing::{debug, error, trace, warn};


@ -1,6 +1,6 @@
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Adds the room to the public room directory
@ -15,3 +15,23 @@ pub trait Data: Send + Sync {
/// Returns the unsorted public room directory
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
}
impl Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;


@ -1,6 +1,6 @@
use ruma::{DeviceId, RoomId, UserId};
use crate::Result;
use crate::{KeyValueDatabase, Result};
pub trait Data: Send + Sync {
fn lazy_load_was_sent_before(
@ -14,3 +14,53 @@ pub trait Data: Send + Sync {
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}
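
// Illustrative sketch, not part of the commit itself: lazy-loading state is a
// pure set-membership table -- the key is user ++ 0xFF ++ device ++ 0xFF ++ room
// ++ 0xFF ++ lazy-loaded user and the value is empty. Plain strings stand in for
// the ruma id types.
fn demo_lazyload_key(user: &str, device: &str, room: &str, ll_user: &str) -> Vec<u8> {
    let mut key = Vec::new();
    for (i, part) in [user, device, room, ll_user].iter().enumerate() {
        if i > 0 {
            key.push(0xFF);
        }
        key.extend_from_slice(part.as_bytes());
    }
    key
}

fn main() {
    let key = demo_lazyload_key("@a:x", "DEV", "!r:x", "@b:x");
    assert_eq!(key.iter().filter(|&&b| b == 0xFF).count(), 3);
}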


@ -4,7 +4,7 @@ use std::{
sync::Arc,
};
pub use data::Data;
use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;


@ -1,6 +1,7 @@
use ruma::{OwnedRoomId, RoomId};
use tracing::error;
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
fn exists(&self, room_id: &RoomId) -> Result<bool>;
@ -11,3 +12,75 @@ pub trait Data: Send + Sync {
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()>;
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
}
impl Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;


@ -19,27 +19,6 @@ pub mod timeline;
pub mod typing;
pub mod user;
pub trait Data:
alias::Data
+ auth_chain::Data
+ directory::Data
+ lazy_loading::Data
+ metadata::Data
+ outlier::Data
+ pdu_metadata::Data
+ read_receipt::Data
+ search::Data
+ short::Data
+ state::Data
+ state_accessor::Data
+ state_cache::Data
+ state_compressor::Data
+ timeline::Data
+ threads::Data
+ user::Data
{
}
pub struct Service {
pub alias: alias::Service,
pub auth_chain: auth_chain::Service,


@ -1,9 +1,34 @@
use ruma::{CanonicalJsonObject, EventId};
use crate::{PduEvent, Result};
use crate::{Error, KeyValueDatabase, PduEvent, Result};
pub trait Data: Send + Sync {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{CanonicalJsonObject, EventId};
use crate::{PduEvent, Result};


@ -1,8 +1,8 @@
use std::sync::Arc;
use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId};
use crate::{PduCount, PduEvent, Result};
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
pub trait Data: Send + Sync {
fn add_relation(&self, from: u64, to: u64) -> Result<()>;
@ -15,3 +15,77 @@ pub trait Data: Send + Sync {
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>;
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>;
}
impl Data for KeyValueDatabase {
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
}
fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX.saturating_sub(x).saturating_sub(1)
},
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
}
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
}
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
}
}
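
// Illustrative sketch, not part of the commit itself: tofrom_relation keys are
// the big-endian target count followed by the big-endian source count, so a
// reverse iteration starting just below `until` walks relations newest-first
// while staying inside the `target` prefix. These helpers mirror that layout.
use std::mem::size_of;

fn demo_relation_key(to: u64, from: u64) -> Vec<u8> {
    let mut key = to.to_be_bytes().to_vec();
    key.extend_from_slice(&from.to_be_bytes());
    key
}

fn demo_from_of(key: &[u8]) -> u64 {
    u64::from_be_bytes(key[size_of::<u64>()..].try_into().expect("8-byte suffix"))
}

fn main() {
    let key = demo_relation_key(10, 7);
    assert!(key.starts_with(&10_u64.to_be_bytes()));
    assert_eq!(demo_from_of(&key), 7);
}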


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},


@ -1,10 +1,12 @@
use std::mem;
use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw,
OwnedUserId, RoomId, UserId,
CanonicalJsonObject, OwnedUserId, RoomId, UserId,
};
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
@ -32,3 +34,118 @@ pub trait Data: Send + Sync {
#[allow(dead_code)]
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
}
impl Data for KeyValueDatabase {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
)?;
Ok(())
}
fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
}
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
}
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
}
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
}
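
// Illustrative sketch, not part of the commit itself: readreceiptid_readreceipt
// keys are room ++ 0xFF ++ big-endian count ++ 0xFF ++ user, which is what lets
// readreceipts_since() slice the count and user id out by byte offset. Strings
// stand in for the ruma id types.
use std::mem::size_of;

fn demo_receipt_key(room: &str, count: u64, user: &str) -> Vec<u8> {
    let mut key = room.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(&count.to_be_bytes());
    key.push(0xFF);
    key.extend_from_slice(user.as_bytes());
    key
}

fn demo_parse_receipt_key(key: &[u8], room: &str) -> (u64, String) {
    let prefix_len = room.len() + 1;
    let count = u64::from_be_bytes(
        key[prefix_len..prefix_len + size_of::<u64>()]
            .try_into()
            .expect("8-byte count"),
    );
    let user = String::from_utf8_lossy(&key[prefix_len + size_of::<u64>() + 1..]).into_owned();
    (count, user)
}

fn main() {
    let key = demo_receipt_key("!r:x", 5, "@a:x");
    assert_eq!(demo_parse_receipt_key(&key, "!r:x"), (5, "@a:x".to_owned()));
}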


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
use crate::{services, Result};


@ -1,6 +1,6 @@
use ruma::RoomId;
use crate::Result;
use crate::{services, utils, KeyValueDatabase, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
@ -9,3 +9,62 @@ pub trait Data: Send + Sync {
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>;
}
impl Data for KeyValueDatabase {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new())
});
self.tokenids.insert_batch(&mut batch)
}
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.map(str::to_lowercase)
.collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
}
}
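
// Illustrative sketch, not part of the commit itself: index_pdu() above tokenizes
// the message body into lowercase alphanumeric words (dropping words longer than
// 50 bytes) and writes one tokenids key per word: shortroomid ++ word ++ 0xFF ++
// pdu_id. The tokenizer itself can be reproduced with std alone:
fn demo_tokenize(body: &str) -> Vec<String> {
    body.split_terminator(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .filter(|word| word.len() <= 50)
        .map(str::to_lowercase)
        .collect()
}

fn main() {
    assert_eq!(demo_tokenize("Hello, Matrix world!"), ["hello", "matrix", "world"]);
}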


@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::RoomId;
use crate::Result;


@ -1,8 +1,9 @@
use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use tracing::warn;
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
@ -24,3 +25,161 @@ pub trait Data: Send + Sync {
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>;
}
impl Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
Ok(short)
}
fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.iter()
.enumerate()
{
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
),
None => {
let short = services().globals.next_count()?;
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
},
}
}
Ok(ret)
}
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
Ok(short)
}
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
}
}
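
// Illustrative sketch, not part of the commit itself: the get_or_create_short*
// helpers above all follow the same pattern -- look the value up in a forward
// tree, otherwise allocate the next global counter and write the forward (and
// usually the reverse) mapping. A HashMap-based stand-in for that pattern:
use std::collections::HashMap;

struct DemoShortIds {
    forward: HashMap<String, u64>,
    reverse: HashMap<u64, String>,
    next: u64, // stands in for services().globals.next_count()
}

impl DemoShortIds {
    fn get_or_create(&mut self, id: &str) -> u64 {
        if let Some(short) = self.forward.get(id) {
            return *short;
        }
        self.next += 1;
        self.forward.insert(id.to_owned(), self.next);
        self.reverse.insert(self.next, id.to_owned());
        self.next
    }
}

fn main() {
    let mut ids = DemoShortIds { forward: HashMap::new(), reverse: HashMap::new(), next: 0 };
    let a = ids.get_or_create("$event_a");
    assert_eq!(ids.get_or_create("$event_a"), a); // idempotent
    assert_ne!(ids.get_or_create("$event_b"), a); // new ids keep counting up
}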


@ -1,7 +1,7 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::Result;


@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::Result;
use crate::{utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Returns the last state hash key added to the db for the given room.
@ -31,3 +31,70 @@ pub trait Data: Send + Sync {
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(
&self,
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}


@ -4,7 +4,7 @@ use std::{
sync::Arc,
};
pub use data::Data;
use data::Data;
use ruma::{
api::client::error::ErrorKind,
events::{


@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{PduEvent, Result};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
#[async_trait]
pub trait Data: Send + Sync {
@ -46,3 +46,162 @@ pub trait Data: Send + Sync {
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>>;
}
#[async_trait]
impl Data for KeyValueDatabase {
#[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
#[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let (_, eventid) = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
}
/// Returns the full room state.
#[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
}


@ -4,7 +4,7 @@ use std::{
sync::{Arc, Mutex},
};
pub use data::Data;
use data::Data;
use lru_cache::LruCache;
use ruma::{
events::{


@ -1,15 +1,21 @@
use std::{collections::HashSet, sync::Arc};
use itertools::Itertools;
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{service::appservice::RegistrationInfo, Result};
use crate::{
appservice::RegistrationInfo,
services, user_is_local,
utils::{self},
Error, KeyValueDatabase, Result,
};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
pub trait Data: Send + Sync {
@ -91,3 +97,609 @@ pub trait Data: Send + Sync {
#[allow(dead_code)]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
}
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
let mut real_users = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) {
real_users.insert(joined);
}
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
self.our_real_users_cache
.write()
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
if let Some(users) = maybe {
Ok(users)
} else {
self.update_joined_count(room_id)?;
Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
))
}
}
#[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self))]
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))]
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator over all joined members of a room.
#[tracing::instrument(skip(self))]
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self))]
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self))]
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid invitedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self))]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self))]
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_leftstate is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_leftstate is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
self.roomid_inviteviaservers
.get(room_id.as_bytes())?
.map(|servers| {
// The writers above store the server list joined with 0xFF, not JSON.
servers
.split(|&b| b == 0xFF)
.map(|bytes| {
ServerName::parse(
utils::string_from_bytes(bytes).map_err(|e| {
error!("Invalid server bytes in roomid_inviteviaservers: {e}");
Error::bad_database("Invalid server bytes in roomid_inviteviaservers.")
})?,
)
.map_err(|_| Error::bad_database("Invalid server name in roomid_inviteviaservers."))
})
.collect::<Result<Vec<_>>>()
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
}
}
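
// Illustrative sketch, not part of the commit itself: membership is indexed twice
// above -- once as userroom_id (user ++ 0xFF ++ room, for "which rooms is this
// user in") and once as roomuser_id (room ++ 0xFF ++ user, for "who is in this
// room") -- so both directions are a single prefix scan. Strings stand in for
// the ruma id types.
fn demo_membership_keys(user: &str, room: &str) -> (Vec<u8>, Vec<u8>) {
    let mut userroom_id = user.as_bytes().to_vec();
    userroom_id.push(0xFF);
    userroom_id.extend_from_slice(room.as_bytes());

    let mut roomuser_id = room.as_bytes().to_vec();
    roomuser_id.push(0xFF);
    roomuser_id.extend_from_slice(user.as_bytes());

    (userroom_id, roomuser_id)
}

fn main() {
    let (ur, ru) = demo_membership_keys("@a:x", "!r:x");
    assert!(ur.starts_with(b"@a:x\xff"));
    assert!(ru.starts_with(b"!r:x\xff"));
}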


@ -1,6 +1,6 @@
use std::{collections::HashSet, sync::Arc};
pub use data::Data;
use data::Data;
use itertools::Itertools;
use ruma::{
events::{


@ -1,7 +1,7 @@
use std::{collections::HashSet, sync::Arc};
use std::{collections::HashSet, mem::size_of, sync::Arc};
use super::CompressedStateEvent;
use crate::Result;
use crate::{utils, Error, KeyValueDatabase, Result};
pub struct StateDiff {
pub parent: Option<u64>,
@ -13,3 +13,60 @@ pub trait Data: Send + Sync {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff>;
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
} else {
None
};
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
}
}
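
// Illustrative sketch, not part of the commit itself: the shortstatehash_statediff
// value begins with the parent shortstatehash (0 meaning "no parent"), then the
// added 16-byte compressed state events, then -- only if anything was removed --
// an 8-byte zero separator followed by the removed entries. A round trip of that
// layout, with [u8; 16] standing in for CompressedStateEvent:
use std::mem::size_of;

fn encode_diff(parent: Option<u64>, added: &[[u8; 16]], removed: &[[u8; 16]]) -> Vec<u8> {
    let mut value = parent.unwrap_or(0).to_be_bytes().to_vec();
    for new in added {
        value.extend_from_slice(new);
    }
    if !removed.is_empty() {
        value.extend_from_slice(&0_u64.to_be_bytes());
        for gone in removed {
            value.extend_from_slice(gone);
        }
    }
    value
}

fn decode_diff(value: &[u8]) -> (Option<u64>, Vec<[u8; 16]>, Vec<[u8; 16]>) {
    let parent = u64::from_be_bytes(value[..size_of::<u64>()].try_into().unwrap());
    let parent = (parent != 0).then_some(parent);

    let (mut added, mut removed, mut add_mode) = (Vec::new(), Vec::new(), true);
    let mut i = size_of::<u64>();
    while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
        if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
            // Zero shortstatekey marks the start of the removed section.
            add_mode = false;
            i += size_of::<u64>();
            continue;
        }
        let entry: [u8; 16] = v.try_into().unwrap();
        if add_mode { added.push(entry) } else { removed.push(entry) };
        i += 2 * size_of::<u64>();
    }
    (parent, added, removed)
}

fn main() {
    let (a, r) = ([1_u8; 16], [2_u8; 16]);
    assert_eq!(decode_diff(&encode_diff(Some(7), &[a], &[r])), (Some(7), vec![a], vec![r]));
}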


@ -1,11 +1,11 @@
pub mod data;
mod data;
use std::{
collections::HashSet,
mem::size_of,
sync::{Arc, Mutex},
};
pub use data::Data;
use data::Data;
use lru_cache::LruCache;
use ruma::{EventId, RoomId};


@ -1,6 +1,8 @@
use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{PduEvent, Result};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
@ -12,3 +14,71 @@ pub trait Data: Send + Sync {
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>;
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>>;
}
impl Data for KeyValueDatabase {
fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&until.saturating_sub(1).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
}
}
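
// Illustrative sketch, not part of the commit itself: threadid_userids stores the
// participant list as user ids joined with 0xFF; update_participants() and
// get_participants() above are a join/split pair, mirrored here with strings.
fn demo_join_participants(users: &[&str]) -> Vec<u8> {
    users
        .iter()
        .map(|user| user.as_bytes())
        .collect::<Vec<_>>()
        .join(&[0xFF][..])
}

fn demo_split_participants(bytes: &[u8]) -> Vec<String> {
    bytes
        .split(|&b| b == 0xFF)
        .map(|part| String::from_utf8_lossy(part).into_owned())
        .collect()
}

fn main() {
    let packed = demo_join_participants(&["@a:x", "@b:x"]);
    assert_eq!(demo_split_participants(&packed), ["@a:x", "@b:x"]);
}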


@ -2,7 +2,7 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
pub use data::Data;
use data::Data;
use ruma::{
api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads},
events::relation::BundledThread,


@ -1,9 +1,10 @@
use std::sync::Arc;
use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use tracing::error;
use super::PduCount;
use crate::{PduEvent, Result};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
pub trait Data: Send + Sync {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount>;
@@ -66,3 +67,292 @@ pub trait Data: Send + Sync {
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()>;
}
impl Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.unwrap()
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
},
hash_map::Entry::Occupied(o) => Ok(*o.get()),
}
}
/// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
}
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
/// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always valid"),
)?;
self.lasttimelinecount_cache
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
}
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
}
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
}
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
self.userroomid_notificationcount
.increment_batch(&mut notifies_batch.into_iter())?;
self.userroomid_highlightcount
.increment_batch(&mut highlights_batch.into_iter())?;
Ok(())
}
}
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
} else {
Ok(PduCount::Normal(last_u64))
}
}
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
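Together, pdu_count and count_to_id pin down the pdu_id suffix format: normal events append their 8-byte count directly, while backfilled events append an 8-byte zero marker followed by u64::MAX - count so they sort before every normal count. A sketch of the encoding side of that scheme, with a hypothetical helper name:
// Hypothetical inverse of pdu_count(): build the suffix appended to the room prefix.
fn encode_count_suffix(count: PduCount) -> Vec<u8> {
    match count {
        PduCount::Normal(c) => c.to_be_bytes().to_vec(),
        PduCount::Backfilled(c) => {
            let mut suffix = 0_u64.to_be_bytes().to_vec(); // marker read back by pdu_count()
            suffix.extend_from_slice(&u64::MAX.saturating_sub(c).to_be_bytes());
            suffix
        },
    }
}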

View file

@@ -1,11 +1,11 @@
pub mod data;
mod data;
use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::Arc,
};
pub use data::Data;
use data::Data;
use rand::prelude::SliceRandom;
use ruma::{
api::{client::error::ErrorKind, federation},
@@ -195,7 +195,7 @@ impl Service {
state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<Vec<u8>> {
// Coalesce database writes for the remainder of this scope.
let _cork = services().globals.db.cork_and_flush()?;
let _cork = services().globals.cork_and_flush()?;
let shortroomid = services()
.rooms

View file

@@ -1,6 +1,6 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
@@ -20,3 +20,137 @@ pub trait Data: Send + Sync {
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>>;
}
impl Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
}
fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
}
}
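get_shared_rooms works because every userroomid_joined prefix scan yields room ids in the same byte order, which lets utils::common_elements intersect the per-user iterators with a plain Ord::cmp. A rough two-iterator sketch of that kind of intersection (not the actual common_elements implementation):
// Sketch: intersect two already-sorted key iterators, as common_elements is used above.
fn intersect_sorted<A, B>(a: A, b: B) -> Vec<Vec<u8>>
where
    A: Iterator<Item = Vec<u8>>,
    B: Iterator<Item = Vec<u8>>,
{
    let (mut a, mut b) = (a.peekable(), b.peekable());
    let mut shared = Vec::new();
    while let (Some(x), Some(y)) = (a.peek(), b.peek()) {
        match x.cmp(y) {
            std::cmp::Ordering::Less => { a.next(); },
            std::cmp::Ordering::Greater => { b.next(); },
            std::cmp::Ordering::Equal => {
                shared.push(a.next().expect("peeked above"));
                b.next();
            },
        }
    }
    shared
}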

View file

@@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::Result;

View file

@@ -1,7 +1,7 @@
use ruma::ServerName;
use ruma::{ServerName, UserId};
use super::{Destination, SendingEvent};
use crate::Result;
use crate::{services, utils, Error, KeyValueDatabase, Result};
type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>;
type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>;
@@ -23,3 +23,188 @@ pub trait Data: Send + Sync {
fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>;
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64>;
}
impl Data for KeyValueDatabase {
fn active_requests<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
}
fn active_requests_for<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a> {
let prefix = destination.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
}
Ok(())
}
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap();
}
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap();
}
Ok(())
}
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
}
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events {
if key.is_empty() {
continue;
}
let value = if let SendingEvent::Edu(value) = &e {
&**value
} else {
&[]
};
self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?;
}
Ok(())
}
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
}
}
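The implementation above keeps two trees per destination: servernameevent_data holds queued requests and servercurrentevent_data holds the ones currently being sent; mark_as_active moves entries from the former to the latter. A condensed sketch of that lifecycle using only the trait methods defined above; db, dest and event are placeholders, and the actual sender loop lives in the Service, not here:
// Sketch of the queue -> active -> delete lifecycle for one destination.
fn send_one(db: &dyn Data, dest: &Destination, event: SendingEvent) -> Result<()> {
    let keys = db.queue_requests(&[(dest, event)])?; // lands in servernameevent_data
    let queued: Vec<_> = db.queued_requests(dest).filter_map(Result::ok).collect();
    db.mark_as_active(&queued)?; // moved into servercurrentevent_data
    // ... the transaction would be sent to the destination here ...
    for key in keys {
        db.delete_active_request(key)?; // cleared once the destination acknowledged
    }
    Ok(())
}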
#[tracing::instrument(skip(key))]
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Appservice(server),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
(
Destination::Push(user_id, pushkey_string),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value)
},
)
} else {
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Normal(
ServerName::parse(server)
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
})
}
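parse_servercurrentevent mirrors the prefixes produced by Destination::get_prefix: appservice keys start with '+' and contain appservice_id 0xFF event, push keys start with '$' and contain user 0xFF pushkey 0xFF event, and plain federation keys are servername 0xFF event. A hypothetical sketch of the inverse (key-prefix construction) under those assumptions; it is not the actual get_prefix implementation:
// Hypothetical inverse of parse_servercurrentevent: build the key prefix for a destination.
fn destination_prefix(dest: &Destination) -> Vec<u8> {
    let mut prefix = match dest {
        Destination::Appservice(id) => {
            let mut p = b"+".to_vec();
            p.extend_from_slice(id.as_bytes());
            p
        },
        Destination::Push(user_id, pushkey) => {
            let mut p = b"$".to_vec();
            p.extend_from_slice(user_id.as_bytes());
            p.push(0xFF);
            p.extend_from_slice(pushkey.as_bytes());
            p
        },
        Destination::Normal(server) => server.as_bytes().to_vec(),
    };
    prefix.push(0xFF);
    prefix
}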

View file

@ -1,6 +1,6 @@
use std::{fmt::Debug, sync::Arc};
pub use data::Data;
use data::Data;
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
@@ -81,7 +81,7 @@ impl Service {
pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> {
let dest = Destination::Push(user.to_owned(), pushkey);
let event = SendingEvent::Pdu(pdu_id.to_owned());
let _cork = services().globals.db.cork()?;
let _cork = services().globals.cork()?;
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
self.dispatch(Msg {
dest,
@@ -94,7 +94,7 @@ impl Service {
pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> {
let dest = Destination::Appservice(appservice_id);
let event = SendingEvent::Pdu(pdu_id);
let _cork = services().globals.db.cork()?;
let _cork = services().globals.cork()?;
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
self.dispatch(Msg {
dest,
@@ -121,7 +121,7 @@ impl Service {
.into_iter()
.map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned())))
.collect::<Vec<_>>();
let _cork = services().globals.db.cork()?;
let _cork = services().globals.cork()?;
let keys = self.db.queue_requests(
&requests
.iter()
@@ -143,7 +143,7 @@ impl Service {
pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> {
let dest = Destination::Normal(server.to_owned());
let event = SendingEvent::Edu(serialized);
let _cork = services().globals.db.cork()?;
let _cork = services().globals.cork()?;
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
self.dispatch(Msg {
dest,
@@ -170,7 +170,7 @@ impl Service {
.into_iter()
.map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone())))
.collect::<Vec<_>>();
let _cork = services().globals.db.cork()?;
let _cork = services().globals.cork()?;
let keys = self.db.queue_requests(
&requests
.iter()

View file

@@ -100,7 +100,7 @@ impl Service {
fn handle_response_ok(
&self, dest: &Destination, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus,
) {
let _cork = services().globals.db.cork();
let _cork = services().globals.cork();
self.db
.delete_all_active_requests_for(dest)
.expect("all active requests deleted");
@@ -173,7 +173,7 @@
return Ok(None);
}
let _cork = services().globals.db.cork();
let _cork = services().globals.cork();
let mut events = Vec::new();
// Must retry any previous transaction for this remote.
@@ -187,7 +187,7 @@
}
// Compose the next transaction
let _cork = services().globals.db.cork();
let _cork = services().globals.cork();
if !new_events.is_empty() {
self.db.mark_as_active(&new_events)?;
for (e, _) in new_events {

View file

@@ -1,8 +1,8 @@
use ruma::{DeviceId, TransactionId, UserId};
use crate::Result;
use crate::{KeyValueDatabase, Result};
pub trait Data: Send + Sync {
pub(crate) trait Data: Send + Sync {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()>;
@@ -11,3 +11,32 @@ pub trait Data: Send + Sync {
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>>;
}
impl Data for KeyValueDatabase {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?;
Ok(())
}
fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction
self.userdevicetxnid_response.get(&key)
}
}
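Both methods above build the same userdevicetxnid_response key: user_id 0xFF device_id (empty when no device is given) 0xFF txn_id. A small sketch of that shared key builder, with a hypothetical name:
// Hypothetical helper showing the key layout shared by add_txnid and existing_txnid above.
fn txnid_key(user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
    key.push(0xFF);
    key.extend_from_slice(txn_id.as_bytes());
    key
}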

View file

@@ -2,13 +2,13 @@ mod data;
use std::sync::Arc;
pub use data::Data;
use data::Data;
use ruma::{DeviceId, TransactionId, UserId};
use crate::Result;
pub struct Service {
pub db: Arc<dyn Data>,
pub(super) db: Arc<dyn Data>,
}
impl Service {

View file

@@ -1,8 +1,11 @@
use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId};
use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId,
};
use crate::Result;
use crate::{Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
pub(crate) trait Data: Send + Sync {
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()>;
@@ -15,3 +18,65 @@ pub trait Data: Send + Sync {
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo>;
}
impl Data for KeyValueDatabase {
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(ToOwned::to_owned)
}
fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
}
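UIAA state is split across two stores above: the initial request body lives in the in-memory userdevicesessionid_uiaarequest map keyed by (user, device, session), while the evolving UiaaInfo is persisted under a user 0xFF device 0xFF session key. A hedged usage sketch of the trait; db, user_id, device_id, session, request and uiaainfo are placeholders:
// Sketch of a UIAA round-trip through the Data trait above.
fn uiaa_roundtrip(
    db: &dyn Data, user_id: &UserId, device_id: &DeviceId, session: &str,
    request: &CanonicalJsonValue, uiaainfo: &UiaaInfo,
) -> Result<()> {
    db.set_uiaa_request(user_id, device_id, session, request)?;
    db.update_uiaa_session(user_id, device_id, session, Some(uiaainfo))?;
    let _info = db.get_uiaa_session(user_id, device_id, session)?;
    db.update_uiaa_session(user_id, device_id, session, None)?; // clears the session
    Ok(())
}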

View file

@@ -4,7 +4,7 @@ use std::sync::Arc;
use argon2::{PasswordHash, PasswordVerifier};
use conduit::{utils, Error, Result};
pub use data::Data;
use data::Data;
use ruma::{
api::client::{
error::ErrorKind,
@@ -19,7 +19,7 @@ use crate::services;
pub const SESSION_ID_LENGTH: usize = 32;
pub struct Service {
pub db: Arc<dyn Data>,
pub(super) db: Arc<dyn Data>,
}
impl Service {

View file

@@ -1,14 +1,17 @@
use std::collections::BTreeMap;
use std::{collections::BTreeMap, mem::size_of};
use argon2::{password_hash::SaltString, PasswordHasher};
use ruma::{
api::client::{device::Device, filter::FilterDefinition},
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::AnyToDeviceEvent,
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
use crate::Result;
use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result};
pub trait Data: Send + Sync {
/// Check if a user has an account on this homeserver.
@@ -144,3 +147,887 @@
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>>;
}
impl Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
/// Check if account is deactivated
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self
.userid_password
.get(user_id.as_bytes())?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
.is_empty())
}
/// Returns the number of users registered on this server.
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
/// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
let device_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
Ok(Some((
UserId::parse(
utils::string_from_bytes(user_bytes)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
utils::string_from_bytes(device_bytes)
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
)))
})
}
/// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
}))
}
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of its password is
/// greater than zero.
fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
.iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
.collect();
Ok(users)
}
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
}
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.userid_password.insert(user_id.as_bytes(), b"")?;
Ok(())
}
}
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
))
})
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to notify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
} else {
self.userid_displayname.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the `avatar_url` of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> {
self.userid_avatarurl
.get(user_id.as_bytes())?
.map(|bytes| {
let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| {
warn!("Avatar URL in db is invalid: {}", e);
Error::bad_database("Avatar URL in db is invalid.")
})?;
let mxc_uri: OwnedMxcUri = s_bytes.into();
Ok(mxc_uri)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else {
self.userid_avatarurl.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash
.get(user_id.as_bytes())?
.map(|bytes| {
let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
Ok(s)
})
.transpose()
}
/// Sets a new blurhash or removes it if blurhash is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
} else {
self.userid_blurhash.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Adds a new device to a user.
fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id)? {
warn!("Called create_device for non-existent user {} in database", user_id);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
}
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(&Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: None, // TODO
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
})
.expect("Device::to_string never fails."),
)?;
self.set_token(user_id, device_id, token)?;
Ok(())
}
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.userdeviceid_token.remove(&userdeviceid)?;
self.token_userdeviceid.remove(&old_token)?;
}
// Remove todevice events
let mut prefix = userdeviceid.clone();
prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?;
}
// TODO: Remove onetimekeys
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?;
Ok(())
}
/// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
// All devices have metadata
Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into())
}),
)
}
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// should not be None, but we shouldn't assert either...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.token_userdeviceid.remove(&old_token)?;
// It will be removed from userdeviceid_token by the insert later
}
// Assign token to user device combination
self.userdeviceid_token
.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(())
}
fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
// All devices have metadata
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&key)?.is_none() {
warn!(
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.extend_from_slice(
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
);
self.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
)?;
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
})
}
fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.push(b'"'); // Annoying quotation mark
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
prefix.push(b':');
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys
.scan_prefix(prefix)
.next()
.map(|(key, value)| {
self.onetimekeyid_onetimekeys.remove(&key)?;
Ok((
serde_json::from_slice(
key.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
))
})
.transpose()
}
fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new();
for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
)
}) {
*counts.entry(algorithm?).or_default() += uint!(1);
}
Ok(counts)
}
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?;
self.mark_device_key_update(user_id)?;
Ok(())
}
fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained more than one key.",
));
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
self.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key)?;
}
// User-signing key
if let Some(user_signing_key) = user_signing_key {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained more than one key.",
));
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
self.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key)?;
}
if notify {
self.mark_device_key_update(user_id)?;
}
Ok(())
}
fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> {
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
&self
.keyid_key
.get(&key)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))?
.as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))?
.insert(signature.0, signature.1.into());
self.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
)?;
self.mark_device_key_update(target_id)?;
Ok(())
}
fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut start = prefix.clone();
start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes());
let to = to.unwrap_or(u64::MAX);
Box::new(
self.keychangeid_userid
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
if let Ok(c) = utils::u64_from_bytes(current) {
c <= to
} else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
false
}
} else {
warn!("BadDatabase: Could not parse keychangeid_userid");
false
}
})
.map(|(_, bytes)| {
UserId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
}),
)
}
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms
if services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{
continue;
}
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
Ok(())
}
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
})
}
fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Master key contained more than one key.",
));
}
let mut master_key_key = prefix.clone();
master_key_key.extend_from_slice(master_key_id.as_bytes());
Ok((master_key_key, master_key))
}
fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
Ok(Some(Raw::from_json(
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
)))
})
}
fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_masterkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_selfsigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
))
})
})
}
fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?;
Ok(())
}
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push(
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
}
Ok(events)
}
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
for (key, _) in self
.todeviceid_events
.iter_from(&last, true) // this includes last
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(key, _)| {
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
))
})
.filter_map(Result::ok)
.take_while(|&(_, count)| count <= until)
{
self.todeviceid_events.remove(&key)?;
}
Ok(())
}
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
)?;
Ok(())
}
/// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
}
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map(Some)
})
}
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
}),
)
}
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
Ok(filter_id)
}
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
if let Some(raw) = raw {
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
} else {
Ok(None)
}
}
}
/// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed.
/// If `utils::string_from_bytes(...)` returns an error, that username will be
/// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
None
} else {
match utils::string_from_bytes(username) {
Ok(u) => Some(u),
Err(e) => {
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
None
},
}
}
}
/// Calculate a new hash for the given password
fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(rand::thread_rng());
services()
.globals
.argon
.hash_password(password.as_bytes(), &salt)
.map(|it| it.to_string())
}
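calculate_password_hash delegates to the globally configured Argon2 hasher; verification on login then goes through argon2::PasswordVerifier against the stored PHC string. A minimal, self-contained sketch of that hash/verify pair using the argon2 crate defaults rather than the server's configured instance (services().globals.argon):
// Sketch: hash a password and verify it again with argon2 defaults.
use argon2::{
    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
    Argon2,
};

fn hash_and_verify(password: &str) -> Result<bool, argon2::password_hash::Error> {
    let salt = SaltString::generate(rand::thread_rng());
    let hash = Argon2::default()
        .hash_password(password.as_bytes(), &salt)?
        .to_string();
    let parsed = PasswordHash::new(&hash)?;
    Ok(Argon2::default()
        .verify_password(password.as_bytes(), &parsed)
        .is_ok())
}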

View file

@@ -5,7 +5,7 @@ use std::{
sync::{Arc, Mutex},
};
pub use data::Data;
use data::Data;
use ruma::{
api::client::{
device::Device,