diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 59df5e7f..823adb46 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,6 +1,6 @@ use std::{ collections::{BTreeMap, HashMap}, - sync::Arc, + sync::{Arc, RwLock}, }; use conduit::{trace, utils, Error, Result}; @@ -33,6 +33,7 @@ pub struct Data { readreceiptid_readreceipt: Arc, userid_lastonetimekeyupdate: Arc, pub(super) db: Arc, + counter: RwLock, } impl Data { @@ -52,18 +53,41 @@ impl Data { readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), db: db.clone(), + counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), } } pub fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.global.increment(COUNTER)?) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) + let mut lock = self.counter.write().expect("locked"); + let counter: &mut u64 = &mut lock; + debug_assert!( + *counter == Self::stored_count(&self.global).expect("database failure"), + "counter mismatch" + ); + + *counter = counter.wrapping_add(1); + self.global.insert(COUNTER, &counter.to_be_bytes())?; + + Ok(*counter) } - pub fn current_count(&self) -> Result { - self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) - }) + #[inline] + pub fn current_count(&self) -> u64 { + let lock = self.counter.read().expect("locked"); + let counter: &u64 = &lock; + debug_assert!( + *counter == Self::stored_count(&self.global).expect("database failure"), + "counter mismatch" + ); + + *counter + } + + fn stored_count(global: &Arc) -> Result { + global + .get(COUNTER)? + .as_deref() + .map_or(Ok(0_u64), utils::u64_from_bytes) } pub fn last_check_for_updates_id(&self) -> Result { diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 0cf0fccf..979f671f 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -140,11 +140,11 @@ impl Service { /// Returns this server's keypair. pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } - #[tracing::instrument(skip(self))] + #[inline] pub fn next_count(&self) -> Result { self.db.next_count() } - #[tracing::instrument(skip(self))] - pub fn current_count(&self) -> Result { self.db.current_count() } + #[inline] + pub fn current_count(&self) -> Result { Ok(self.db.current_count()) } #[tracing::instrument(skip(self))] pub fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() }