fix counter increment race
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
2e2cf08bb2
commit
46423cab4f
2 changed files with 34 additions and 10 deletions
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, HashMap},
|
collections::{BTreeMap, HashMap},
|
||||||
sync::Arc,
|
sync::{Arc, RwLock},
|
||||||
};
|
};
|
||||||
|
|
||||||
use conduit::{trace, utils, Error, Result};
|
use conduit::{trace, utils, Error, Result};
|
||||||
|
@ -33,6 +33,7 @@ pub struct Data {
|
||||||
readreceiptid_readreceipt: Arc<Map>,
|
readreceiptid_readreceipt: Arc<Map>,
|
||||||
userid_lastonetimekeyupdate: Arc<Map>,
|
userid_lastonetimekeyupdate: Arc<Map>,
|
||||||
pub(super) db: Arc<Database>,
|
pub(super) db: Arc<Database>,
|
||||||
|
counter: RwLock<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Data {
|
impl Data {
|
||||||
|
@ -52,18 +53,41 @@ impl Data {
|
||||||
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
|
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
|
||||||
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
|
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
|
||||||
db: db.clone(),
|
db: db.clone(),
|
||||||
|
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next_count(&self) -> Result<u64> {
|
pub fn next_count(&self) -> Result<u64> {
|
||||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
let mut lock = self.counter.write().expect("locked");
|
||||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
let counter: &mut u64 = &mut lock;
|
||||||
|
debug_assert!(
|
||||||
|
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||||
|
"counter mismatch"
|
||||||
|
);
|
||||||
|
|
||||||
|
*counter = counter.wrapping_add(1);
|
||||||
|
self.global.insert(COUNTER, &counter.to_be_bytes())?;
|
||||||
|
|
||||||
|
Ok(*counter)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn current_count(&self) -> Result<u64> {
|
#[inline]
|
||||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
pub fn current_count(&self) -> u64 {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
let lock = self.counter.read().expect("locked");
|
||||||
})
|
let counter: &u64 = &lock;
|
||||||
|
debug_assert!(
|
||||||
|
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||||
|
"counter mismatch"
|
||||||
|
);
|
||||||
|
|
||||||
|
*counter
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stored_count(global: &Arc<Map>) -> Result<u64> {
|
||||||
|
global
|
||||||
|
.get(COUNTER)?
|
||||||
|
.as_deref()
|
||||||
|
.map_or(Ok(0_u64), utils::u64_from_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn last_check_for_updates_id(&self) -> Result<u64> {
|
pub fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||||
|
|
|
@ -140,11 +140,11 @@ impl Service {
|
||||||
/// Returns this server's keypair.
|
/// Returns this server's keypair.
|
||||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
|
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[inline]
|
||||||
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[inline]
|
||||||
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
|
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) }
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
|
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue