diff --git a/Cargo.toml b/Cargo.toml index c93c6736..19ccbb00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -769,7 +769,7 @@ perf = "warn" ################### #restriction = "warn" -#arithmetic_side_effects = "warn" # TODO +arithmetic_side_effects = "warn" #as_conversions = "warn" # TODO assertions_on_result_states = "warn" dbg_macro = "warn" diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 3efe283b..53009566 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -58,7 +58,7 @@ pub(super) async fn parse_pdu(body: Vec<&str>) -> Result match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { Ok(hash) => { diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index 39eb7c47..92e3de47 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -191,7 +191,10 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo )); } - let rooms_s = body.clone().drain(1..body.len() - 1).collect::>(); + let rooms_s = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::>(); let admin_room_alias = &services().globals.admin_alias; diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 6dc60713..884e1d29 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -23,7 +23,7 @@ pub(super) async fn list(_body: Vec<&str>) -> Result { match services().users.list_local_users() { Ok(users) => { let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += &users.join("\n"); + plain_msg += users.join("\n").as_str(); plain_msg += "\n```"; Ok(RoomMessageEventContent::notice_markdown(plain_msg)) @@ -195,7 +195,10 @@ pub(super) async fn deactivate_all( )); } - let usernames = body.clone().drain(1..body.len() - 1).collect::>(); + let usernames = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::>(); let mut user_ids: Vec = Vec::with_capacity(usernames.len()); let mut admins = Vec::new(); diff --git a/src/core/utils/html.rs b/src/core/utils/html.rs index 3b44a31b..938e50ec 100644 --- a/src/core/utils/html.rs +++ b/src/core/utils/html.rs @@ -26,7 +26,7 @@ impl fmt::Display for Escape<'_> { fmt.write_str(s)?; // NOTE: we only expect single byte characters here - which is fine as long as // we only match single byte characters - last = i + 1; + last = i.saturating_add(1); } if last < s.len() { diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index a77f8e26..f00b9055 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -16,7 +16,15 @@ macro_rules! checked { #[cfg(not(debug_assertions))] #[macro_export] macro_rules! validated { - ($($input:tt)*) => { Ok($($input)*) } + ($($input:tt)*) => { + //#[allow(clippy::arithmetic_side_effects)] { + //Some($($input)*) + // .ok_or_else(|| $crate::Error::Arithmetic("this error should never been seen")) + //} + + //NOTE: remove me when stmt_expr_attributes is stable + $crate::checked!($($input)*) + } } #[cfg(debug_assertions)] diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs index 1ded8a6d..b80671eb 100644 --- a/src/core/utils/rand.rs +++ b/src/core/utils/rand.rs @@ -15,7 +15,11 @@ pub fn string(length: usize) -> String { #[inline] #[must_use] -pub fn timepoint_secs(range: Range) -> SystemTime { SystemTime::now() + secs(range) } +pub fn timepoint_secs(range: Range) -> SystemTime { + SystemTime::now() + .checked_add(secs(range)) + .expect("range does not overflow SystemTime") +} #[must_use] pub fn secs(range: Range) -> Duration { diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index 239e27e9..add15861 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -65,7 +65,7 @@ fn common_prefix_none() { #[test] fn checked_add() { - use utils::math::checked; + use crate::checked; let a = 1234; let res = checked!(a + 1).unwrap(); @@ -75,9 +75,9 @@ fn checked_add() { #[test] #[should_panic(expected = "overflow")] fn checked_add_overflow() { - use utils::math::checked; + use crate::checked; - let a: u64 = u64::MAX; + let a = u64::MAX; let res = checked!(a + 1).expect("overflow"); assert_eq!(res, 0); } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index ecb827fe..083aec3c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -65,7 +65,10 @@ impl Data { "counter mismatch" ); - *counter = counter.wrapping_add(1); + *counter = counter + .checked_add(1) + .expect("counter must not overflow u64"); + self.global.insert(COUNTER, &counter.to_be_bytes())?; Ok(*counter) diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index e171cb9d..3948d1f5 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -836,11 +836,14 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc, _config: &C for (mut key, value) in roomuserid_joined.iter() { iter_count = iter_count.saturating_add(1); debug_info!(%iter_count); - let first_sep_index = key.iter().position(|&i| i == 0xFF).unwrap(); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); if key .iter() - .get(first_sep_index..=first_sep_index + 1) + .get(first_sep_index..=first_sep_index.saturating_add(1)) .copied() .collect_vec() == vec![0xFF, 0xFF] diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index caa75d4e..3cb8fda8 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,10 +1,10 @@ mod data; mod tests; -use std::{collections::HashMap, io::Cursor, path::PathBuf, sync::Arc, time::SystemTime}; +use std::{collections::HashMap, io::Cursor, num::Saturating as Sat, path::PathBuf, sync::Arc, time::SystemTime}; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_error, error, utils, Error, Result, Server}; +use conduit::{checked, debug, debug_error, error, utils, Error, Result, Server}; use data::Data; use image::imageops::FilterType; use ruma::{OwnedMxcUri, OwnedUserId}; @@ -305,36 +305,20 @@ impl Service { image.resize_to_fill(width, height, FilterType::CatmullRom) } else { let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - // - // https://github.com/image-rs/image/blob/6edf8ae492c4bb1dacb41da88681ea74dab1bab3/src/math/utils.rs#L5-L11 - // Calculates the width and height an image should be - // resized to. This preserves aspect ratio, and based - // on the `fill` parameter will either fill the - // dimensions to fit inside the smaller constraint - // (will overflow the specified bounds on one axis to - // preserve aspect ratio), or will shrink so that both - // dimensions are completely contained within the given - // `width` and `height`, with empty space on one axis. - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); + let ratio = Sat(original_width) * Sat(height); + let nratio = Sat(width) * Sat(original_height); let use_width = nratio <= ratio; let intermediate = if use_width { - u64::from(original_height) * u64::from(width) / u64::from(original_width) + Sat(original_height) * Sat(checked!(width / original_width)?) } else { - u64::from(original_width) * u64::from(height) / u64::from(original_height) + Sat(original_width) * Sat(checked!(height / original_height)?) }; + if use_width { - if u32::try_from(intermediate).is_ok() { - (width, intermediate as u32) - } else { - ((u64::from(width) * u64::from(u32::MAX) / intermediate) as u32, u32::MAX) - } - } else if u32::try_from(intermediate).is_ok() { - (intermediate as u32, height) + (width, intermediate.0) } else { - (u32::MAX, (u64::from(height) * u64::from(u32::MAX) / intermediate) as u32) + (intermediate.0, height) } }; diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 584f1a6d..f5400379 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -3,7 +3,7 @@ mod data; use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{debug, error, utils, Error, Result}; +use conduit::{checked, debug, error, utils, Error, Result}; use data::Data; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -79,12 +79,16 @@ pub struct Service { timer_receiver: Mutex>, handler_join: Mutex>>, timeout_remote_users: bool, + idle_timeout: u64, + offline_timeout: u64, } #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; + let idle_timeout_s = config.presence_idle_timeout_s; + let offline_timeout_s = config.presence_offline_timeout_s; let (timer_sender, timer_receiver) = loole::unbounded(); Ok(Arc::new(Self { db: Data::new(args.db), @@ -92,6 +96,8 @@ impl crate::Service for Service { timer_receiver: Mutex::new(timer_receiver), handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, + idle_timeout: checked!(idle_timeout_s * 1_000)?, + offline_timeout: checked!(offline_timeout_s * 1_000)?, })) } @@ -219,7 +225,7 @@ impl Service { loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - Some(user_id) = presence_timers.next() => process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, event = receiver.recv_async() => match event { Err(_e) => return Ok(()), Ok((user_id, timeout)) => { @@ -230,6 +236,36 @@ impl Service { } } } + + fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + let mut presence_state = PresenceState::Offline; + let mut last_active_ago = None; + let mut status_msg = None; + + let presence_event = self.get_presence(user_id)?; + + if let Some(presence_event) = presence_event { + presence_state = presence_event.content.presence; + last_active_ago = presence_event.content.last_active_ago; + status_msg = presence_event.content.status_msg; + } + + let new_state = match (&presence_state, last_active_ago.map(u64::from)) { + (PresenceState::Online, Some(ago)) if ago >= self.idle_timeout => Some(PresenceState::Unavailable), + (PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout => Some(PresenceState::Offline), + _ => None, + }; + + debug!( + "Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}" + ); + + if let Some(new_state) = new_state { + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + } + + Ok(()) + } } async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId { @@ -237,36 +273,3 @@ async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId user_id } - -fn process_presence_timer(user_id: &OwnedUserId) -> Result<()> { - let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000; - let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000; - - let mut presence_state = PresenceState::Offline; - let mut last_active_ago = None; - let mut status_msg = None; - - let presence_event = services().presence.get_presence(user_id)?; - - if let Some(presence_event) = presence_event { - presence_state = presence_event.content.presence; - last_active_ago = presence_event.content.last_active_ago; - status_msg = presence_event.content.status_msg; - } - - let new_state = match (&presence_state, last_active_ago.map(u64::from)) { - (PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable), - (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline), - _ => None, - }; - - debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"); - - if let Some(new_state) = new_state { - services() - .presence - .set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; - } - - Ok(()) -} diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index ca04b1e5..18cdb70f 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, trace, warn, Error, Result}; +use conduit::{debug, error, trace, validated, warn, Error, Result}; use data::Data; use ruma::{api::client::error::ErrorKind, EventId, RoomId}; @@ -43,20 +43,20 @@ impl Service { #[tracing::instrument(skip_all, name = "auth_chain")] pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { - const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? + const NUM_BUCKETS: u64 = 50; //TODO: change possible w/o disrupting db? const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); let started = std::time::Instant::now(); - let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, short) in services() + let mut buckets = [BUCKET; NUM_BUCKETS as usize]; + for (i, &short) in services() .rooms .short .multi_get_or_create_shorteventid(starting_events)? .iter() .enumerate() { - let bucket = short % NUM_BUCKETS as u64; - buckets[bucket as usize].insert((*short, starting_events[i])); + let bucket = validated!(short % NUM_BUCKETS)?; + buckets[bucket as usize].insert((short, starting_events[i])); } debug!( diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 395b70f2..0f7919dd 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -191,7 +191,7 @@ impl Service { e.insert((Instant::now(), 1)); }, hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); + *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); }, }; }, @@ -1072,7 +1072,7 @@ impl Service { let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i = 0; + let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { if let Some((time, tries)) = services() .globals @@ -1094,7 +1094,7 @@ impl Service { continue; } - i += 1; + i = i.saturating_add(1); if i % 100 == 0 { tokio::task::yield_now().await; } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index e35c969d..cdb2fc29 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -205,7 +205,7 @@ impl Service { if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { for relation in relations.flatten() { if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1 + 1)); + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } pdus.push(relation); diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 17acb0b3..06eaf655 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -76,10 +76,12 @@ impl Data { .iter_from(&first_possible_edu, false) .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { - let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + size_of::()]) + let count_offset = prefix.len().saturating_add(size_of::()); + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id_offset = count_offset.saturating_add(1); let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + size_of::() + 1..]) + utils::string_from_bytes(&k[user_id_offset..]) .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, ) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index bf6cc873..6924c50e 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,7 @@ use std::{ sync::Arc, }; -use conduit::debug_info; +use conduit::{checked, debug_info}; use lru_cache::LruCache; use ruma::{ api::{ @@ -508,7 +508,8 @@ impl Service { } // We have reached the room after where we last left off - if parents.len() + 1 == short_room_ids.len() { + let parents_len = parents.len(); + if checked!(parents_len + 1)? == short_room_ids.len() { populate_results = true; } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 61c7d6e6..33773001 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -38,11 +38,12 @@ impl Data { let mut added = HashSet::new(); let mut removed = HashSet::new(); - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { + let stride = size_of::(); + let mut i = stride; + while let Some(v) = value.get(i..checked!(i + 2 * stride)?) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i += size_of::(); + i = checked!(i + stride)?; continue; } if add_mode { @@ -50,7 +51,7 @@ impl Data { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i += 2 * size_of::(); + i = checked!(i + 2 * stride)?; } Ok(StateDiff { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 97f3fb80..ca076fc9 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -7,7 +7,7 @@ use std::{ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{utils, Result}; +use conduit::{checked, utils, Result}; use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; @@ -169,12 +169,14 @@ impl Service { statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); + let statediffnew_len = statediffnew.len(); + let statediffremoved_len = statediffremoved.len(); + let diffsum = checked!(statediffnew_len + statediffremoved_len)?; if parent_states.len() > 3 { // Number of layers // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); + let parent = parent_states.pop().expect("parent must have a state"); let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); @@ -226,10 +228,12 @@ impl Service { // 1. We add the current diff on top of the parent layer. // 2. We replace a layer above - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); + let parent = parent_states.pop().expect("parent must have a state"); + let parent_2_len = parent.2.len(); + let parent_3_len = parent.3.len(); + let parent_diff = checked!(parent_2_len + parent_3_len)?; - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 29539847..c4a1a294 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,6 +1,6 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; @@ -31,7 +31,7 @@ impl Data { .to_vec(); let mut current = prefix.clone(); - current.extend_from_slice(&(until - 1).to_be_bytes()); + current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); Ok(Box::new( self.threadid_userids diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index f3cefe21..dd2686b0 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -64,7 +64,7 @@ impl Service { .and_then(|relations| serde_json::from_value::(relations.clone().into()).ok()) { // Thread already existed - relations.count += uint!(1); + relations.count = relations.count.saturating_add(uint!(1)); relations.latest_event = pdu.to_message_like_event(); let content = serde_json::to_value(relations).expect("to_value always works"); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 0d4d945e..ec975b99 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{error, utils, Error, Result}; +use conduit::{checked, error, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; @@ -281,10 +281,12 @@ impl Data { /// Returns the `count` of this pdu's id. pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + let stride = size_of::(); + let pdu_id_len = pdu_id.len(); + let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; let second_last_u64 = - utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()]); + utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); if matches!(second_last_u64, Ok(0)) { Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index ba987dbd..70f4423c 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, info, utils, utils::mutex_map, warn, Error, Result}; +use conduit::{debug, error, info, utils, utils::mutex_map, validated, warn, Error, Result}; use data::Data; use itertools::Itertools; use rand::prelude::SliceRandom; @@ -670,7 +670,7 @@ impl Service { .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) .max() .unwrap_or_else(|| uint!(0)) - + uint!(1); + .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); @@ -1240,10 +1240,11 @@ impl Service { let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let max = u64::MAX; let count = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); + pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 0f4fa17a..0fb0d9dc 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -93,7 +93,7 @@ impl Service { statuses.entry(dest).and_modify(|e| { *e = match e { TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n + 1, Instant::now()), + TransactionStatus::Retrying(ref n) => TransactionStatus::Failed(n.saturating_add(1), Instant::now()), TransactionStatus::Failed(..) => panic!("Request that was not even running failed?!"), } }); diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 3ba54a93..302eae9d 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -463,7 +463,8 @@ impl Data { .algorithm(), ) }) { - *counts.entry(algorithm?).or_default() += uint!(1); + let count: &mut UInt = counts.entry(algorithm?).or_default(); + *count = count.saturating_add(uint!(1)); } Ok(counts) @@ -814,7 +815,7 @@ impl Data { .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) + utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::())..key.len()]) .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, )) }) @@ -928,10 +929,12 @@ impl Data { /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - let expires_in = services().globals.config.openid_token_ttl; - let expires_at = utils::millis_since_unix_epoch().saturating_add(expires_in * 1000); + use std::num::Saturating as Sat; - let mut value = expires_at.to_be_bytes().to_vec(); + let expires_in = services().globals.config.openid_token_ttl; + let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); + + let mut value = expires_at.0.to_be_bytes().to_vec(); value.extend_from_slice(user_id.as_bytes()); self.openidtoken_expiresatuserid