diff --git a/src/database/keyval.rs b/src/database/keyval.rs index c9d25977..a288f184 100644 --- a/src/database/keyval.rs +++ b/src/database/keyval.rs @@ -3,10 +3,6 @@ use serde::Deserialize; use crate::de; -pub(crate) type OwnedKeyVal = (Vec, Vec); -pub(crate) type OwnedKey = Vec; -pub(crate) type OwnedVal = Vec; - pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); pub type Key<'a, T = &'a Slice> = T; pub type Val<'a, T = &'a Slice> = T; @@ -72,10 +68,6 @@ where de::from_slice::(val) } -#[inline] -#[must_use] -pub fn to_owned(kv: KeyVal<'_>) -> OwnedKeyVal { (kv.0.to_owned(), kv.1.to_owned()) } - #[inline] pub fn key(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } diff --git a/src/database/map/get.rs b/src/database/map/get.rs index 71489402..72382e36 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -3,14 +3,12 @@ use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; use arrayvec::ArrayVec; use conduit::{err, implement, Result}; use futures::future::ready; +use rocksdb::DBPinnableSlice; use serde::Serialize; -use crate::{ - keyval::{OwnedKey, OwnedVal}, - ser, - util::{map_err, or_else}, - Handle, -}; +use crate::{ser, util, Handle}; + +type RocksdbResult<'a> = Result>, rocksdb::Error>; /// Fetch a value from the database into cache, returning a reference-handle /// asynchronously. The key is serialized into an allocated buffer to perform @@ -68,17 +66,17 @@ pub fn get_blocking(&self, key: &K) -> Result> where K: AsRef<[u8]> + ?Sized + Debug, { - self.db + let res = self .db - .get_pinned_cf_opt(&self.cf(), key, &self.read_options) - .map_err(map_err)? - .map(Handle::from) - .ok_or(err!(Request(NotFound("Not found in database")))) + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options); + + into_result_handle(res) } #[implement(super::Map)] #[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] -pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec> +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec>> where I: Iterator + ExactSizeIterator + Send + Debug, K: AsRef<[u8]> + Sized + Debug + 'a, @@ -87,19 +85,18 @@ where // comparator**. const SORTED: bool = false; - let mut ret: Vec> = Vec::with_capacity(keys.len()); let read_options = &self.read_options; - for res in self - .db + self.db .db .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => or_else(e).expect("database multiget error"), - } - } - - ret + .into_iter() + .map(into_result_handle) + .collect() +} + +fn into_result_handle(result: RocksdbResult<'_>) -> Result> { + result + .map_err(util::map_err)? + .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 66da3948..825ee109 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use conduit::{err, implement, utils, Error, Result}; +use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; @@ -69,41 +69,26 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { #[implement(Service)] pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { - let mut ret: Vec = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::>(); - - for (i, short) in self - .db + self.db .eventid_shorteventid - .get_batch_blocking(keys.iter()) - .iter() + .get_batch_blocking(event_ids.iter()) + .into_iter() .enumerate() - { - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short) - .map_err(|_| Error::bad_database("Invalid shorteventid in db.")) - .unwrap(), - ), - None => { + .map(|(i, result)| match result { + Ok(ref short) => utils::u64_from_u8(short), + Err(_) => { let short = self.services.globals.next_count().unwrap(); self.db .eventid_shorteventid - .insert(keys[i], &short.to_be_bytes()); + .insert(event_ids[i], &short.to_be_bytes()); self.db .shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i]); + .insert(&short.to_be_bytes(), event_ids[i]); - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); + short }, - } - } - - ret + }) + .collect() } #[implement(Service)]