fix counter increment race

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-02 07:56:45 +00:00
parent 2e2cf08bb2
commit 46423cab4f
2 changed files with 34 additions and 10 deletions

View file

@ -1,6 +1,6 @@
use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
sync::{Arc, RwLock},
};
use conduit::{trace, utils, Error, Result};
@ -33,6 +33,7 @@ pub struct Data {
readreceiptid_readreceipt: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>,
pub(super) db: Arc<Database>,
counter: RwLock<u64>,
}
impl Data {
@ -52,18 +53,41 @@ impl Data {
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
db: db.clone(),
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
}
}
pub fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
let mut lock = self.counter.write().expect("locked");
let counter: &mut u64 = &mut lock;
debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"),
"counter mismatch"
);
*counter = counter.wrapping_add(1);
self.global.insert(COUNTER, &counter.to_be_bytes())?;
Ok(*counter)
}
pub fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
#[inline]
pub fn current_count(&self) -> u64 {
let lock = self.counter.read().expect("locked");
let counter: &u64 = &lock;
debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"),
"counter mismatch"
);
*counter
}
fn stored_count(global: &Arc<Map>) -> Result<u64> {
global
.get(COUNTER)?
.as_deref()
.map_or(Ok(0_u64), utils::u64_from_bytes)
}
pub fn last_check_for_updates_id(&self) -> Result<u64> {

View file

@ -140,11 +140,11 @@ impl Service {
/// Returns this server's keypair.
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
#[tracing::instrument(skip(self))]
#[inline]
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
#[tracing::instrument(skip(self))]
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
#[inline]
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) }
#[tracing::instrument(skip(self))]
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }