diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 7ce3b5bd..21a9ef2b 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -12,7 +12,7 @@ use std::{ time::Duration, }; use tokio::sync::watch::Sender; -use tracing::{error, warn}; +use tracing::error; #[cfg(feature = "conduit_bin")] use rocket::{get, tokio}; diff --git a/src/database.rs b/src/database.rs index 31c76290..8a177163 100644 --- a/src/database.rs +++ b/src/database.rs @@ -279,6 +279,8 @@ impl Database { statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, @@ -292,7 +294,7 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, referencedevents: builder.open_tree("referencedevents")?, pdu_cache: Mutex::new(LruCache::new(100_000)), - auth_chain_cache: Mutex::new(LruCache::new(100_000)), + auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), eventidshort_cache: Mutex::new(LruCache::new(1_000_000)), shortstatekey_cache: Mutex::new(LruCache::new(1_000_000)), diff --git a/src/database/rooms.rs b/src/database/rooms.rs index ad8ffc84..161be60e 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -86,6 +86,8 @@ pub struct Rooms { pub(super) statehash_shortstatehash: Arc, pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) + pub(super) shorteventid_authchain: Arc, + /// RoomId + EventId -> outlier PDU. /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. pub(super) eventid_outlierpdu: Arc, @@ -94,7 +96,7 @@ pub struct Rooms { pub(super) referencedevents: Arc, pub(super) pdu_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, HashSet>>, + pub(super) auth_chain_cache: Mutex, Arc>>>, pub(super) shorteventid_cache: Mutex>, pub(super) eventidshort_cache: Mutex>, pub(super) statekeyshort_cache: Mutex>, @@ -3204,7 +3206,64 @@ impl Rooms { } #[tracing::instrument(skip(self))] - pub fn auth_chain_cache(&self) -> std::sync::MutexGuard<'_, LruCache, HashSet>> { - self.auth_chain_cache.lock().unwrap() + pub fn get_auth_chain_from_cache<'a>( + &'a self, + key: &[u64], + ) -> Result>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } + + // Check DB cache + if key.len() == 1 { + if let Some(chain) = + self.shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| { + utils::u64_from_bytes(chunk).expect("byte length is correct") + }) + .collect() + }) + { + let chain = Arc::new(chain); + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } + + #[tracing::instrument(skip(self))] + pub fn cache_auth_chain(&self, key: Vec, chain: Arc>) -> Result<()> { + // Persist in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &chain + .iter() + .map(|s| s.to_be_bytes().to_vec()) + .flatten() + .collect::>(), + )?; + } + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(key.clone(), chain); + + Ok(()) } } diff --git a/src/server_server.rs b/src/server_server.rs index 526ed518..65fd4a85 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -51,7 +51,7 @@ use ruma::{ ServerSigningKeyId, UserId, }; use std::{ - collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet}, + collections::{btree_map, hash_map, BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, fmt::Debug, future::Future, @@ -1975,44 +1975,85 @@ fn get_auth_chain( starting_events: Vec, db: &Database, ) -> Result + '_> { - let mut full_auth_chain = HashSet::new(); + const NUM_BUCKETS: usize = 50; - const NUM_BUCKETS: usize = 100; - - let mut buckets = vec![HashSet::new(); NUM_BUCKETS]; + let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; for id in starting_events { - let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; - let bucket_id = (short % NUM_BUCKETS as u64) as usize; - buckets[bucket_id].insert((short, id)); + if let Some(pdu) = db.rooms.get_pdu(&id)? { + for auth_event in &pdu.auth_events { + let short = db + .rooms + .get_or_create_shorteventid(&auth_event, &db.globals)?; + let bucket_id = (short % NUM_BUCKETS as u64) as usize; + buckets[bucket_id].insert((short, auth_event.clone())); + } + } } - let mut cache = db.rooms.auth_chain_cache(); + let mut full_auth_chain = HashSet::new(); + let mut hits = 0; + let mut misses = 0; for chunk in buckets { - let chunk_key = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = cache.get_mut(&chunk_key) { - full_auth_chain.extend(cached.iter().cloned()); + if chunk.is_empty() { continue; } + // The code below will only get the auth chains, not the events in the chunk. So let's add + // them first + full_auth_chain.extend(chunk.iter().map(|(id, _)| id)); + + let chunk_key = chunk + .iter() + .map(|(short, _)| short) + .copied() + .collect::>(); + if let Some(cached) = db.rooms.get_auth_chain_from_cache(&chunk_key)? { + hits += 1; + full_auth_chain.extend(cached.iter().cloned()); + continue; + } + misses += 1; + let mut chunk_cache = HashSet::new(); + let mut hits2 = 0; + let mut misses2 = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = cache.get_mut(&[sevent_id][..]) { + if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { + hits2 += 1; chunk_cache.extend(cached.iter().cloned()); } else { - drop(cache); - let auth_chain = get_auth_chain_inner(&event_id, db)?; - cache = db.rooms.auth_chain_cache(); - cache.insert(vec![sevent_id], auth_chain.clone()); - chunk_cache.extend(auth_chain); + misses2 += 1; + let auth_chain = Arc::new(get_auth_chain_inner(&event_id, db)?); + db.rooms + .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + println!( + "cache missed event {} with auth chain len {}", + event_id, + auth_chain.len() + ); + chunk_cache.extend(auth_chain.iter()); }; } - cache.insert(chunk_key, chunk_cache.clone()); - full_auth_chain.extend(chunk_cache); + println!( + "chunk missed with len {}, event hits2: {}, misses2: {}", + chunk_cache.len(), + hits2, + misses2 + ); + let chunk_cache = Arc::new(chunk_cache); + db.rooms + .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + full_auth_chain.extend(chunk_cache.iter()); } - drop(cache); + println!( + "total: {}, chunk hits: {}, misses: {}", + full_auth_chain.len(), + hits, + misses + ); Ok(full_auth_chain .into_iter()