diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 97fc4fd8..d7aaf0ce 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,5 +1,9 @@ use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; -use ruma::api::client::{error::ErrorKind, r0::context::get_context}; +use ruma::{ + api::client::{error::ErrorKind, r0::context::get_context}, + events::EventType, +}; +use std::collections::HashSet; use std::convert::TryFrom; #[cfg(feature = "conduit_bin")] @@ -21,6 +25,7 @@ pub async fn get_context_route( body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); if !db.rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( @@ -29,6 +34,8 @@ pub async fn get_context_route( )); } + let mut lazy_loaded = HashSet::new(); + let base_pdu_id = db .rooms .get_pdu_id(&body.event_id)? @@ -45,8 +52,18 @@ pub async fn get_context_route( .ok_or(Error::BadRequest( ErrorKind::NotFound, "Base event not found.", - ))? - .to_room_event(); + ))?; + + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &body.room_id, + &base_event.sender, + )? { + lazy_loaded.insert(base_event.sender.clone()); + } + + let base_event = base_event.to_room_event(); let events_before: Vec<_> = db .rooms @@ -60,6 +77,17 @@ pub async fn get_context_route( .filter_map(|r| r.ok()) // Remove buggy events .collect(); + for (_, event) in &events_before { + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &body.room_id, + &event.sender, + )? { + lazy_loaded.insert(event.sender.clone()); + } + } + let start_token = events_before .last() .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) @@ -82,6 +110,17 @@ pub async fn get_context_route( .filter_map(|r| r.ok()) // Remove buggy events .collect(); + for (_, event) in &events_after { + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &body.room_id, + &event.sender, + )? { + lazy_loaded.insert(event.sender.clone()); + } + } + let end_token = events_after .last() .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) @@ -98,12 +137,15 @@ pub async fn get_context_route( resp.events_before = events_before; resp.event = Some(base_event); resp.events_after = events_after; - resp.state = db // TODO: State at event - .rooms - .room_state_full(&body.room_id)? - .values() - .map(|pdu| pdu.to_state_event()) - .collect(); + resp.state = Vec::new(); + for ll_id in &lazy_loaded { + if let Some(member_event) = + db.rooms + .room_state_get(&body.room_id, &EventType::RoomMember, ll_id.as_str())? + { + resp.state.push(member_event.to_state_event()); + } + } Ok(resp.into()) } diff --git a/src/client_server/filter.rs b/src/client_server/filter.rs index dfb53770..f8845f1e 100644 --- a/src/client_server/filter.rs +++ b/src/client_server/filter.rs @@ -1,32 +1,47 @@ -use crate::{utils, ConduitResult}; -use ruma::api::client::r0::filter::{self, create_filter, get_filter}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; +use ruma::api::client::{ + error::ErrorKind, + r0::filter::{create_filter, get_filter}, +}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// -/// TODO: Loads a filter that was previously created. -#[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/user/<_>/filter/<_>"))] -#[tracing::instrument] -pub async fn get_filter_route() -> ConduitResult { - // TODO - Ok(get_filter::Response::new(filter::IncomingFilterDefinition { - event_fields: None, - event_format: filter::EventFormat::default(), - account_data: filter::IncomingFilter::default(), - room: filter::IncomingRoomFilter::default(), - presence: filter::IncomingFilter::default(), - }) - .into()) +/// Loads a filter that was previously created. +/// +/// - A user can only access their own filters +#[cfg_attr( + feature = "conduit_bin", + get("/_matrix/client/r0/user/<_>/filter/<_>", data = "") +)] +#[tracing::instrument(skip(db, body))] +pub async fn get_filter_route( + db: DatabaseGuard, + body: Ruma>, +) -> ConduitResult { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let filter = match db.users.get_filter(sender_user, &body.filter_id)? { + Some(filter) => filter, + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), + }; + + Ok(get_filter::Response::new(filter).into()) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` /// -/// TODO: Creates a new filter to be used by other endpoints. -#[cfg_attr(feature = "conduit_bin", post("/_matrix/client/r0/user/<_>/filter"))] -#[tracing::instrument] -pub async fn create_filter_route() -> ConduitResult { - // TODO - Ok(create_filter::Response::new(utils::random_string(10)).into()) +/// Creates a new filter to be used by other endpoints. +#[cfg_attr( + feature = "conduit_bin", + post("/_matrix/client/r0/user/<_>/filter", data = "") +)] +#[tracing::instrument(skip(db, body))] +pub async fn create_filter_route( + db: DatabaseGuard, + body: Ruma>, +) -> ConduitResult { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + Ok(create_filter::Response::new(db.users.create_filter(sender_user, &body.filter)?).into()) } diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 60c756a3..7e1e6a76 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -6,7 +6,11 @@ use ruma::{ }, events::EventType, }; -use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; +use std::{ + collections::{BTreeMap, HashSet}, + convert::TryInto, + sync::Arc, +}; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -117,6 +121,7 @@ pub async fn get_message_events_route( body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); if !db.rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( @@ -133,9 +138,18 @@ pub async fn get_message_events_route( let to = body.to.as_ref().map(|t| t.parse()); + db.rooms + .lazy_load_confirm_delivery(&sender_user, &sender_device, &body.room_id, from)?; + // Use limit or else 10 let limit = body.limit.try_into().map_or(10_usize, |l: u32| l as usize); + let next_token; + + let mut resp = get_message_events::Response::new(); + + let mut lazy_loaded = HashSet::new(); + match body.dir { get_message_events::Direction::Forward => { let events_after: Vec<_> = db @@ -152,7 +166,18 @@ pub async fn get_message_events_route( .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` .collect(); - let end_token = events_after.last().map(|(count, _)| count.to_string()); + for (_, event) in &events_after { + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &body.room_id, + &event.sender, + )? { + lazy_loaded.insert(event.sender.clone()); + } + } + + next_token = events_after.last().map(|(count, _)| count).copied(); let events_after: Vec<_> = events_after .into_iter() @@ -161,11 +186,8 @@ pub async fn get_message_events_route( let mut resp = get_message_events::Response::new(); resp.start = Some(body.from.to_owned()); - resp.end = end_token; + resp.end = next_token.map(|count| count.to_string()); resp.chunk = events_after; - resp.state = Vec::new(); - - Ok(resp.into()) } get_message_events::Direction::Backward => { let events_before: Vec<_> = db @@ -182,20 +204,49 @@ pub async fn get_message_events_route( .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` .collect(); - let start_token = events_before.last().map(|(count, _)| count.to_string()); + for (_, event) in &events_before { + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &body.room_id, + &event.sender, + )? { + lazy_loaded.insert(event.sender.clone()); + } + } + + next_token = events_before.last().map(|(count, _)| count).copied(); let events_before: Vec<_> = events_before .into_iter() .map(|(_, pdu)| pdu.to_room_event()) .collect(); - let mut resp = get_message_events::Response::new(); resp.start = Some(body.from.to_owned()); - resp.end = start_token; + resp.end = next_token.map(|count| count.to_string()); resp.chunk = events_before; - resp.state = Vec::new(); - - Ok(resp.into()) } } + + resp.state = Vec::new(); + for ll_id in &lazy_loaded { + if let Some(member_event) = + db.rooms + .room_state_get(&body.room_id, &EventType::RoomMember, ll_id.as_str())? + { + resp.state.push(member_event.to_state_event()); + } + } + + if let Some(next_token) = next_token { + db.rooms.lazy_load_mark_sent( + &sender_user, + &sender_device, + &body.room_id, + lazy_loaded, + next_token, + ); + } + + Ok(resp.into()) } diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 9ba3b7fb..7e55b4a3 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,6 +1,10 @@ use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use ruma::{ - api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, + api::client::r0::{ + filter::{IncomingFilterDefinition, LazyLoadOptions}, + sync::sync_events, + uiaa::UiaaResponse, + }, events::{ room::member::{MembershipState, RoomMemberEventContent}, AnySyncEphemeralRoomEvent, EventType, @@ -36,13 +40,15 @@ use rocket::{get, tokio}; /// Calling this endpoint with a `since` parameter from a previous `next_batch` returns: /// For joined rooms: /// - Some of the most recent events of each timeline that happened after since -/// - If user joined the room after since: All state events and device list updates in that room +/// - If user joined the room after since: All state events (unless lazy loading is activated) and +/// all device list updates in that room /// - If the user was already in the room: A list of all events that are in the state now, but were /// not in the state at `since` /// - If the state we send contains a member event: Joined and invited member counts, heroes /// - Device list updates that happened after `since` /// - If there are events in the timeline we send or the user send updated his read mark: Notification counts /// - EDUs that are active now (read receipts, typing updates, presence) +/// - TODO: Allow multiple sync streams to support Pantalaimon /// /// For invited rooms: /// - If the user was invited after `since`: A subset of the state of the room at the point of the invite @@ -77,34 +83,32 @@ pub async fn sync_events_route( Entry::Vacant(v) => { let (tx, rx) = tokio::sync::watch::channel(None); + v.insert((body.since.clone(), rx.clone())); + tokio::spawn(sync_helper_wrapper( Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), - body.since.clone(), - body.full_state, - body.timeout, + body, tx, )); - v.insert((body.since.clone(), rx)).1.clone() + rx } Entry::Occupied(mut o) => { if o.get().0 != body.since { let (tx, rx) = tokio::sync::watch::channel(None); + o.insert((body.since.clone(), rx.clone())); + tokio::spawn(sync_helper_wrapper( Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), - body.since.clone(), - body.full_state, - body.timeout, + body, tx, )); - o.insert((body.since.clone(), rx.clone())); - rx } else { o.get().1.clone() @@ -135,18 +139,16 @@ async fn sync_helper_wrapper( db: Arc, sender_user: Box, sender_device: Box, - since: Option, - full_state: bool, - timeout: Option, + body: sync_events::IncomingRequest, tx: Sender>>, ) { + let since = body.since.clone(); + let r = sync_helper( Arc::clone(&db), sender_user.clone(), sender_device.clone(), - since.clone(), - full_state, - timeout, + body, ) .await; @@ -179,9 +181,7 @@ async fn sync_helper( db: Arc, sender_user: Box, sender_device: Box, - since: Option, - full_state: bool, - timeout: Option, + body: sync_events::IncomingRequest, // bool = caching allowed ) -> Result<(sync_events::Response, bool), Error> { // TODO: match body.set_presence { @@ -193,8 +193,26 @@ async fn sync_helper( let next_batch = db.globals.current_count()?; let next_batch_string = next_batch.to_string(); + // Load filter + let filter = match body.filter { + None => IncomingFilterDefinition::default(), + Some(sync_events::IncomingFilter::FilterDefinition(filter)) => filter, + Some(sync_events::IncomingFilter::FilterId(filter_id)) => db + .users + .get_filter(&sender_user, &filter_id)? + .unwrap_or_default(), + }; + + let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { + LazyLoadOptions::Enabled { + include_redundant_members: redundant, + } => (true, redundant), + _ => (false, false), + }; + let mut joined_rooms = BTreeMap::new(); - let since = since + let since = body + .since .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); @@ -264,6 +282,14 @@ async fn sync_helper( // limited unless there are events in non_timeline_pdus let limited = non_timeline_pdus.next().is_some(); + let mut timeline_users = HashSet::new(); + for (_, event) in &timeline_pdus { + timeline_users.insert(event.sender.as_str().to_owned()); + } + + db.rooms + .lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?; + // Database queries: let current_shortstatehash = db @@ -344,14 +370,58 @@ async fn sync_helper( state_events, ) = if since_shortstatehash.is_none() { // Probably since = 0, we will do an initial sync + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; - let state_events: Vec<_> = current_state_ids - .iter() - .map(|(_, id)| db.rooms.get_pdu(id)) - .filter_map(|r| r.ok().flatten()) - .collect(); + + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + + if event_type != EventType::RoomMember { + let pdu = match db.rooms.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + } + }; + state_events.push(pdu); + } else if !lazy_load_enabled + || body.full_state + || timeline_users.contains(&state_key) + { + let pdu = match db.rooms.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + } + }; + lazy_loaded.insert( + UserId::parse(state_key.as_ref()) + .expect("they are in timeline_users, so they should be correct"), + ); + state_events.push(pdu); + } + } + + // Reset lazy loading because this is an initial sync + db.rooms + .lazy_load_reset(&sender_user, &sender_device, &room_id)?; + + // The state_events above should contain all timeline_users, let's mark them as lazy + // loaded. + db.rooms.lazy_load_mark_sent( + &sender_user, + &sender_device, + &room_id, + lazy_loaded, + next_batch, + ); ( heroes, @@ -387,20 +457,67 @@ async fn sync_helper( let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; - let state_events = if joined_since_last_sync { - current_state_ids - .iter() - .map(|(_, id)| db.rooms.get_pdu(id)) - .filter_map(|r| r.ok().flatten()) - .collect::>() - } else { - current_state_ids - .iter() - .filter(|(key, id)| since_state_ids.get(key) != Some(id)) - .map(|(_, id)| db.rooms.get_pdu(id)) - .filter_map(|r| r.ok().flatten()) - .collect() - }; + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + for (key, id) in current_state_ids { + if body.full_state || since_state_ids.get(&key) != Some(&id) { + let pdu = match db.rooms.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + } + }; + + if pdu.kind == EventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + } + Err(e) => error!("Invalid state key for member event: {}", e), + } + } + + state_events.push(pdu); + } + } + + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } + + if !db.rooms.lazy_load_was_sent_before( + &sender_user, + &sender_device, + &room_id, + &event.sender, + )? || lazy_load_send_redundant + { + if let Some(member_event) = db.rooms.room_state_get( + &room_id, + &EventType::RoomMember, + event.sender.as_str(), + )? { + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); + } + } + } + + db.rooms.lazy_load_mark_sent( + &sender_user, + &sender_device, + &room_id, + lazy_loaded, + next_batch, + ); let encrypted_room = db .rooms @@ -765,7 +882,7 @@ async fn sync_helper( }; // TODO: Retry the endpoint instead of returning (waiting for #118) - if !full_state + if !body.full_state && response.rooms.is_empty() && response.presence.is_empty() && response.account_data.is_empty() @@ -774,7 +891,7 @@ async fn sync_helper( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = timeout.unwrap_or_default(); + let mut duration = body.timeout.unwrap_or_default(); if duration.as_secs() > 30 { duration = Duration::from_secs(30); } diff --git a/src/database.rs b/src/database.rs index bd05798b..edb9d8af 100644 --- a/src/database.rs +++ b/src/database.rs @@ -249,6 +249,7 @@ impl Database { userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, }, uiaa: uiaa::Uiaa { @@ -289,6 +290,8 @@ impl Database { userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + lazyloadedids: builder.open_tree("lazyloadedids")?, + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, @@ -324,6 +327,7 @@ impl Database { statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), our_real_users_cache: RwLock::new(HashMap::new()), appservice_in_room_cache: RwLock::new(HashMap::new()), + lazy_load_waiting: Mutex::new(HashMap::new()), stateinfo_cache: Mutex::new(LruCache::new(1000)), }, account_data: account_data::AccountData { diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index c0fda301..2f15f727 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,10 +1,11 @@ use super::super::Config; +use super::{DatabaseEngine, Tree}; use crate::{utils, Result}; -use std::{future::Future, pin::Pin, sync::Arc}; use std::{ collections::{hash_map, HashMap}, - sync::RwLock}; -use super::{DatabaseEngine, Tree}; + sync::RwLock, +}; +use std::{future::Future, pin::Pin, sync::Arc}; use tokio::sync::watch; pub struct Engine { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index fb9ecbf0..ff18cd47 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -28,7 +28,7 @@ use ruma::{ push::{Action, Ruleset, Tweak}, serde::{CanonicalJsonObject, CanonicalJsonValue, Raw}, state_res::{self, RoomVersion, StateMap}, - uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + uint, DeviceId, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::to_raw_value; @@ -79,6 +79,8 @@ pub struct Rooms { pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, + pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId + pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 pub(super) userroomid_highlightcount: Arc, // HightlightCount = u64 @@ -117,6 +119,8 @@ pub struct Rooms { pub(super) shortstatekey_cache: Mutex>, pub(super) our_real_users_cache: RwLock, Arc>>>>, pub(super) appservice_in_room_cache: RwLock, HashMap>>, + pub(super) lazy_load_waiting: + Mutex, Box, Box, u64), HashSet>>>, pub(super) stateinfo_cache: Mutex< LruCache< u64, @@ -3453,4 +3457,94 @@ impl Rooms { Ok(()) } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_was_sent_before( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ll_user: &UserId, + ) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&device_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(&room_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(&ll_user.as_bytes()); + Ok(self.lazyloadedids.get(&key)?.is_some()) + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_mark_sent( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + lazy_load: HashSet>, + count: u64, + ) { + self.lazy_load_waiting.lock().unwrap().insert( + ( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + count, + ), + lazy_load, + ); + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_confirm_delivery( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + since: u64, + ) -> Result<()> { + if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + since, + )) { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&device_id.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(&room_id.as_bytes()); + prefix.push(0xff); + + for ll_id in user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(&ll_id.as_bytes()); + self.lazyloadedids.insert(&key, &[])?; + } + } + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_reset( + &self, + user_id: &Box, + device_id: &Box, + room_id: &Box, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&device_id.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(&room_id.as_bytes()); + prefix.push(0xff); + + for (key, _) in self.lazyloadedids.scan_prefix(prefix) { + self.lazyloadedids.remove(&key)?; + } + + Ok(()) + } } diff --git a/src/database/users.rs b/src/database/users.rs index d4bf4890..becfc736 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -1,6 +1,9 @@ use crate::{utils, Error, Result}; use ruma::{ - api::client::{error::ErrorKind, r0::device::Device}, + api::client::{ + error::ErrorKind, + r0::{device::Device, filter::IncomingFilterDefinition}, + }, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::{AnyToDeviceEvent, EventType}, identifiers::MxcUri, @@ -31,6 +34,8 @@ pub struct Users { pub(super) userid_selfsigningkeyid: Arc, pub(super) userid_usersigningkeyid: Arc, + pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count } @@ -990,4 +995,45 @@ impl Users { // TODO: Unhook 3PID Ok(()) } + + /// Creates a new sync filter. Returns the filter id. + #[tracing::instrument(skip(self))] + pub fn create_filter( + &self, + user_id: &UserId, + filter: &IncomingFilterDefinition, + ) -> Result { + let filter_id = utils::random_string(4); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + self.userfilterid_filter.insert( + &key, + &serde_json::to_vec(&filter).expect("filter is valid json"), + )?; + + Ok(filter_id) + } + + #[tracing::instrument(skip(self))] + pub fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + let raw = self.userfilterid_filter.get(&key)?; + + if let Some(raw) = raw { + serde_json::from_slice(&raw) + .map_err(|_| Error::bad_database("Invalid filter event in db.")) + } else { + Ok(None) + } + } }