use Arc<[u64]> rather than Arc<HashSet<u64>> for auth_chain_cache value.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-04-08 09:01:28 -07:00 committed by June
parent 2cc72de80e
commit c42209c0b3
4 changed files with 20 additions and 21 deletions

View file

@ -1,9 +1,9 @@
use std::{collections::HashSet, mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use crate::{database::KeyValueDatabase, service, utils, Result}; use crate::{database::KeyValueDatabase, service, utils, Result};
impl service::rooms::auth_chain::Data for KeyValueDatabase { impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache // Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result))); return Ok(Some(Arc::clone(result)));
@ -19,12 +19,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
chain chain
.chunks_exact(size_of::<u64>()) .chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.collect() .collect::<Arc<[u64]>>()
}); });
if let Some(chain) = chain { if let Some(chain) = chain {
let chain = Arc::new(chain);
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.auth_chain_cache
.lock() .lock()
@ -38,7 +36,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
Ok(None) Ok(None)
} }
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db // Only persist single events in db
if key.len() == 1 { if key.len() == 1 {
self.shorteventid_authchain.insert( self.shorteventid_authchain.insert(

View file

@ -181,7 +181,7 @@ pub struct KeyValueDatabase {
//pub pusher: pusher::PushData, //pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>, pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>, pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>, pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>, pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,

View file

@ -1,8 +1,8 @@
use std::{collections::HashSet, sync::Arc}; use std::sync::Arc;
use crate::Result; use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<HashSet<u64>>>>; fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<[u64]>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()>; fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()>;
} }

View file

@ -15,15 +15,6 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
self.db.get_cached_eventid_authchain(key)
}
#[tracing::instrument(skip(self))]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
self.db.cache_auth_chain(key, auth_chain)
}
pub async fn get_auth_chain<'a>( pub async fn get_auth_chain<'a>(
&self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>, &self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
@ -81,7 +72,7 @@ impl Service {
services() services()
.rooms .rooms
.auth_chain .auth_chain
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; .cache_auth_chain(vec![sevent_id], &auth_chain)?;
debug!( debug!(
event_id = ?event_id, event_id = ?event_id,
chain_length = ?auth_chain.len(), chain_length = ?auth_chain.len(),
@ -105,7 +96,7 @@ impl Service {
services() services()
.rooms .rooms
.auth_chain .auth_chain
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; .cache_auth_chain(chunk_key, &chunk_cache)?;
full_auth_chain.extend(chunk_cache.iter()); full_auth_chain.extend(chunk_cache.iter());
} }
@ -154,4 +145,14 @@ impl Service {
Ok(found) Ok(found)
} }
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key)
}
#[tracing::instrument(skip(self))]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
self.db
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
}
} }