fix counter increment race
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
2e2cf08bb2
commit
46423cab4f
2 changed files with 34 additions and 10 deletions
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::Arc,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use conduit::{trace, utils, Error, Result};
|
||||
|
@ -33,6 +33,7 @@ pub struct Data {
|
|||
readreceiptid_readreceipt: Arc<Map>,
|
||||
userid_lastonetimekeyupdate: Arc<Map>,
|
||||
pub(super) db: Arc<Database>,
|
||||
counter: RwLock<u64>,
|
||||
}
|
||||
|
||||
impl Data {
|
||||
|
@ -52,18 +53,41 @@ impl Data {
|
|||
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
|
||||
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
|
||||
db: db.clone(),
|
||||
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
let mut lock = self.counter.write().expect("locked");
|
||||
let counter: &mut u64 = &mut lock;
|
||||
debug_assert!(
|
||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||
"counter mismatch"
|
||||
);
|
||||
|
||||
*counter = counter.wrapping_add(1);
|
||||
self.global.insert(COUNTER, &counter.to_be_bytes())?;
|
||||
|
||||
Ok(*counter)
|
||||
}
|
||||
|
||||
pub fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
#[inline]
|
||||
pub fn current_count(&self) -> u64 {
|
||||
let lock = self.counter.read().expect("locked");
|
||||
let counter: &u64 = &lock;
|
||||
debug_assert!(
|
||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||
"counter mismatch"
|
||||
);
|
||||
|
||||
*counter
|
||||
}
|
||||
|
||||
fn stored_count(global: &Arc<Map>) -> Result<u64> {
|
||||
global
|
||||
.get(COUNTER)?
|
||||
.as_deref()
|
||||
.map_or(Ok(0_u64), utils::u64_from_bytes)
|
||||
}
|
||||
|
||||
pub fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
|
|
|
@ -140,11 +140,11 @@ impl Service {
|
|||
/// Returns this server's keypair.
|
||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
#[inline]
|
||||
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
|
||||
#[inline]
|
||||
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
|
||||
|
|
Loading…
Add table
Reference in a new issue