continuwuity/src/database/mod.rs
strawberry 39aef8d1b9 dont build sha2, opentelemetry, or zstd code if unused
reduces unnecessary crates being compiled. splits them
into features.

i have yet to see anyone use conduit's opentelemetry
stuff, and realistically those people who do
performance benchmarking and measurements will be
building stuff anyways so they can just enable this
feature.

Signed-off-by: strawberry <strawberry@puppygock.gay>
2024-03-10 13:03:28 -04:00

1162 lines
41 KiB
Rust

pub(crate) mod abstraction;
pub(crate) mod key_value;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fs::{self},
io::Write,
mem::size_of,
path::Path,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
use itertools::Itertools;
use lru_cache::LruCache;
use rand::thread_rng;
use ruma::{
api::appservice::Registration,
events::{
push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use serde::Deserialize;
use tokio::{
sync::mpsc,
time::{interval, Instant},
};
use tracing::{debug, error, info, warn};
use crate::{
service::rooms::{edus::presence::presence_handler, timeline::PduCount},
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
};
pub struct KeyValueDatabase {
db: Arc<dyn KeyValueDatabaseEngine>,
//pub globals: globals::Globals,
pub(super) global: Arc<dyn KvTree>,
pub(super) server_signingkeys: Arc<dyn KvTree>,
//pub users: users::Users,
pub(super) userid_password: Arc<dyn KvTree>,
pub(super) userid_displayname: Arc<dyn KvTree>,
pub(super) userid_avatarurl: Arc<dyn KvTree>,
pub(super) userid_blurhash: Arc<dyn KvTree>,
pub(super) userdeviceid_token: Arc<dyn KvTree>,
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: Arc<dyn KvTree>,
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub(super) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub(super) userid_masterkeyid: Arc<dyn KvTree>,
pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>,
pub(super) userid_usersigningkeyid: Arc<dyn KvTree>,
pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
//pub uiaa: uiaa::Uiaa,
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
//pub edus: RoomEdus,
pub(super) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub(super) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub(super) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
pub(super) typingid_userid: Arc<dyn KvTree>, // TypingId = RoomId + TimeoutTime + Count
pub(super) roomid_lasttypingupdate: Arc<dyn KvTree>, // LastRoomTypingUpdate = Count
pub(super) roomuserid_presence: Arc<dyn KvTree>,
//pub rooms: rooms::Rooms,
pub(super) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub(super) eventid_pduid: Arc<dyn KvTree>,
pub(super) roomid_pduleaves: Arc<dyn KvTree>,
pub(super) alias_roomid: Arc<dyn KvTree>,
pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub(super) publicroomids: Arc<dyn KvTree>,
pub(super) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub(super) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub(super) userroomid_joined: Arc<dyn KvTree>,
pub(super) roomuserid_joined: Arc<dyn KvTree>,
pub(super) roomid_joinedcount: Arc<dyn KvTree>,
pub(super) roomid_invitedcount: Arc<dyn KvTree>,
pub(super) roomuseroncejoinedids: Arc<dyn KvTree>,
pub(super) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub(super) userroomid_leftstate: Arc<dyn KvTree>,
pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
pub(super) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub(super) bannedroomids: Arc<dyn KvTree>, // Rooms where local users are not allowed to join
pub(super) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub(super) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub(super) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
pub(super) roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64
/// Remember the current state hash of a room.
pub(super) roomid_shortstatehash: Arc<dyn KvTree>,
pub(super) roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub(super) shorteventid_shortstatehash: Arc<dyn KvTree>,
/// StateKey = EventType + StateKey, ShortStateKey = Count
pub(super) statekey_shortstatekey: Arc<dyn KvTree>,
pub(super) shortstatekey_statekey: Arc<dyn KvTree>,
pub(super) roomid_shortroomid: Arc<dyn KvTree>,
pub(super) shorteventid_eventid: Arc<dyn KvTree>,
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
* (shortstatekey+shorteventid++) + 0_u64 +
* (shortstatekey+shorteventid--) */
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
pub(super) softfailedeventids: Arc<dyn KvTree>,
/// ShortEventId + ShortEventId -> ().
pub(super) tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(super) referencedevents: Arc<dyn KvTree>,
//pub account_data: account_data::AccountData,
pub(super) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub(super) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub media: media::Media,
pub(super) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
pub(super) url_previews: Arc<dyn KvTree>,
//pub key_backups: key_backups::KeyBackups,
pub(super) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(super) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub transaction_ids: transaction_ids::TransactionIds,
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
* (/send) */
//pub sending: sending::Sending,
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub(super) servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
* PduId / Id (for edus), Data = EDU content */
pub(super) servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
* / Id (for edus), Data = EDU content */
//pub appservice: appservice::Appservice,
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
//pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) cached_registrations: Arc<RwLock<HashMap<String, Registration>>>,
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
pub(super) presence_timer_sender: Arc<mpsc::UnboundedSender<(OwnedUserId, Duration)>>,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponseEntry {
id: u64,
date: String,
message: String,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>,
}
impl KeyValueDatabase {
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database_path);
let sqlite_exists = path.join("conduit.db").exists();
let rocksdb_exists = path.join("IDENTITY").exists();
let mut count = 0;
if sqlite_exists {
count += 1;
}
if rocksdb_exists {
count += 1;
}
if count > 1 {
warn!("Multiple databases at database_path detected");
return Ok(());
}
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
));
}
Ok(())
}
/// Load an existing database or create a new one.
pub async fn load_or_create(config: Config) -> Result<()> {
Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
debug!("Database path does not exist, assuming this is a new setup and creating it");
fs::create_dir_all(&config.database_path).map_err(|e| {
error!("Failed to create database path: {e}");
Error::BadConfig(
"Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \
create the database folder yourself or allow conduwuit the permissions to create directories and \
files.",
)
})?;
}
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
"sqlite" => {
debug!("Got sqlite database backend");
#[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "sqlite")]
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
},
"rocksdb" => {
debug!("Got rocksdb database backend");
#[cfg(not(feature = "rocksdb"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "rocksdb")]
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
},
_ => {
return Err(Error::BadConfig(
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
));
},
};
let (presence_sender, presence_receiver) = mpsc::unbounded_channel();
let db_raw = Box::new(Self {
db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
typingid_userid: builder.open_tree("typingid_userid")?,
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
roomuserid_presence: builder.open_tree("roomuserid_presence")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
disabledroomids: builder.open_tree("disabledroomids")?,
bannedroomids: builder.open_tree("bannedroomids")?,
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
url_previews: builder.open_tree("url_previews")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
pdu_cache: Mutex::new(LruCache::new(
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
)),
auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)),
shorteventid_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
eventidshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
shortstatekey_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
statekeyshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
presence_timer_sender: Arc::new(presence_sender),
});
let db = Box::leak(db_raw);
let services_raw = Box::new(Services::build(db, config)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services().users.count()? > 0 {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is valid");
if !services().users.exists(&conduit_user)? {
error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one \
first.",
));
}
}
// If the database has any data, perform data migrations before starting
// do not increment the db version if the user is not using sha256_media
let latest_database_version = if cfg!(feature = "sha256_media") {
14
} else {
13
};
if services().users.count()? > 0 {
// MIGRATIONS
if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = parts.next().expect("split always returns one element");
let servername = match parts.next() {
Some(s) => s,
None => {
error!("Migration: Invalid roomserverid in db.");
continue;
},
};
let mut serverroomid = servername.to_vec();
serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?;
}
services().globals.bump_database_version(1)?;
warn!("Migration: 0 -> 1 finished");
}
if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just ""
for (userid, password) in db.userid_password.iter() {
let salt = SaltString::generate(thread_rng());
let empty_pass = services()
.globals
.argon
.hash_password(b"", &salt)
.expect("our own password to be properly hashed");
let empty_hashed_password =
services().globals.argon.verify_password(&password, &empty_pass).is_ok();
if empty_hashed_password {
db.userid_password.insert(&userid, b"")?;
}
}
services().globals.bump_database_version(2)?;
warn!("Migration: 1 -> 2 finished");
}
if services().globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.mediaid_file.iter() {
if content.is_empty() {
continue;
}
#[allow(deprecated)]
let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.mediaid_file.insert(&key, &[])?;
}
services().globals.bump_database_version(3)?;
warn!("Migration: 2 -> 3 finished");
}
if services().globals.database_version()? < 4 {
// Add federated users to services() as deactivated
for our_user in services().users.iter() {
let our_user = our_user?;
if services().users.is_deactivated(&our_user)? {
continue;
}
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != services().globals.server_name() {
info!(?user, "Migration: creating user");
services().users.create(&user, None)?;
}
}
}
}
services().globals.bump_database_version(4)?;
warn!("Migration: 3 -> 4 finished");
}
if services().globals.database_version()? < 5 {
// Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id);
key.push(0xFF);
key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
}
services().globals.bump_database_version(5)?;
warn!("Migration: 4 -> 5 finished");
}
if services().globals.database_version()? < 6 {
// Set room member count
for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?;
}
services().globals.bump_database_version(6)?;
warn!("Migration: 5 -> 6 finished");
}
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let mut counter = 0;
let mut handle_state = |current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
counter += 1;
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash)
},
)?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew =
current_state.difference(&parent_stateinfo.1).copied().collect::<HashSet<_>>();
let statediffremoved =
parent_stateinfo.1.difference(&current_state).copied().collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
services().rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash =
utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
}
}
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
services().globals.bump_database_version(7)?;
warn!("Migration: 6 -> 7 finished");
}
if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
// Update pduids db layout
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(count);
Some((new_key, v))
});
db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.eventid_pduid.insert_batch(&mut batch2)?;
services().globals.bump_database_version(8)?;
warn!("Migration: 7 -> 8 finished");
}
if services().globals.database_version()? < 9 {
// Update tokenids db layout
let mut iter = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap();
let short_room_id =
db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(word);
new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count);
Some((new_key, Vec::new()))
})
.peekable();
while iter.peek().is_some() {
db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?;
debug!("Inserted smaller batch");
}
info!("Deleting starts");
let batch2: Vec<_> = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if key.starts_with(b"!") {
Some(key)
} else {
None
}
})
.collect();
for key in batch2 {
db.tokenids.remove(&key)?;
}
services().globals.bump_database_version(9)?;
warn!("Migration: 8 -> 9 finished");
}
if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(std::result::Result::ok) {
services().users.mark_device_key_update(&user_id)?;
}
services().globals.bump_database_version(10)?;
warn!("Migration: 9 -> 10 finished");
}
if services().globals.database_version()? < 11 {
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?;
services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished");
}
if services().globals.database_version()? < 12 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
{
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let rules_list = &mut account_data.content.global;
//content rule
{
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let rule = rules_list.content.get(content_rule_transformation[0]);
if rule.is_some() {
let mut rule = rule.unwrap().clone();
content_rule_transformation[1].clone_into(&mut rule.rule_id);
rules_list.content.shift_remove(content_rule_transformation[0]);
rules_list.content.insert(rule);
}
}
//underride rules
{
let underride_rule_transformation = [
[".m.rules.call", ".m.rule.call"],
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
[".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
[".m.rules.message", ".m.rule.message"],
[".m.rules.encrypted", ".m.rule.encrypted"],
];
for transformation in underride_rule_transformation {
let rule = rules_list.underride.get(transformation[0]);
if let Some(rule) = rule {
let mut rule = rule.clone();
transformation[1].clone_into(&mut rule.rule_id);
rules_list.underride.shift_remove(transformation[0]);
rules_list.underride.insert(rule);
}
}
}
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(12)?;
warn!("Migration: 11 -> 12 finished");
}
// This migration can be reused as-is anytime the server-default rules are
// updated.
if services().globals.database_version()? < 13 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
{
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let user_default_rules = Ruleset::server_default(&user);
account_data.content.global.update_with_server_default(user_default_rules);
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(13)?;
warn!("Migration: 12 -> 13 finished");
}
#[cfg(feature = "sha256_media")]
{
if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
warn!(
"sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names"
);
// Move old media files to new names
for (key, _) in db.mediaid_file.iter() {
let old_path = services().globals.get_media_file(&key);
debug!("Old file path: {old_path:?}");
let path = services().globals.get_media_file_new(&key);
debug!("New file path: {path:?}");
// move the file to the new location
if old_path.exists() {
tokio::fs::rename(&old_path, &path).await?;
}
}
services().globals.bump_database_version(14)?;
warn!("Migration: 13 -> 14 finished");
}
}
assert_eq!(
services().globals.database_version().unwrap(),
latest_database_version,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services().globals.database_version().unwrap(),
latest_database_version
);
{
let patterns = &services().globals.config.forbidden_usernames;
if !patterns.is_empty() {
for user in services().users.iter() {
let user_id = user?;
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
warn!(
"User {} matches the following forbidden username patterns: {}",
user_id.to_string(),
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
);
}
}
}
}
{
let patterns = &services().globals.config.forbidden_room_names;
if !patterns.is_empty() {
for address in services().rooms.metadata.iter_ids() {
let room_id = address?;
let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases {
let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias,
&room_id,
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
);
}
}
}
}
}
info!(
"Loaded {} database with version {}",
services().globals.config.database_backend,
latest_database_version
);
} else {
services().globals.bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run
services().admin.create_admin_room().await?;
warn!(
"Created new {} database with version {}",
services().globals.config.database_backend,
latest_database_version
);
}
// Inserting registrations into cache
for appservice in services().appservice.all()? {
services().appservice.registration_info.write().await.insert(
appservice.0,
appservice.1.try_into().expect("Should be validated on registration"),
);
}
services().admin.start_handler();
// Set emergency access for the conduit user
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!"
);
services().admin.send_message(RoomMessageEventContent::text_plain(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!",
));
}
},
Err(e) => {
error!("Could not set the configured emergency password for the conduit user: {}", e);
},
};
services().sending.start_handler();
Self::start_cleanup_task().await;
if services().globals.allow_check_for_updates() {
Self::start_check_for_updates_task();
}
if services().globals.allow_local_presence() {
Self::start_presence_handler(presence_receiver).await;
}
Ok(())
}
pub fn flush(&self) -> Result<()> {
let start = std::time::Instant::now();
let res = self.db.flush();
debug!("flush: took {:?}", start.elapsed());
res
}
#[tracing::instrument]
fn start_check_for_updates_task() {
tokio::spawn(async move {
let timer_interval = Duration::from_secs(60 * 60);
let mut i = interval(timer_interval);
loop {
i.tick().await;
let _ = Self::try_handle_updates().await;
}
});
}
async fn try_handle_updates() -> Result<()> {
let response =
services().globals.default_client().get("https://pupbrain.dev/check-for-updates/stable").send().await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
error!("Bad check for updates response: {e}");
Error::BadServerResponse("Bad version check response")
})?;
let mut last_update_id = services().globals.last_check_for_updates_id()?;
for update in response.updates {
last_update_id = last_update_id.max(update.id);
if update.id > services().globals.last_check_for_updates_id()? {
error!("{}", update.message);
services().admin.send_message(RoomMessageEventContent::text_plain(format!(
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
update.date, update.message
)));
}
}
services().globals.update_check_for_updates_id(last_update_id)?;
Ok(())
}
fn perform_cleanup() {
let start = Instant::now();
if let Err(e) = services().globals.cleanup() {
error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e);
} else {
debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed());
}
}
#[tracing::instrument]
async fn start_cleanup_task() {
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
tokio::spawn(async move {
let mut i = interval(timer_interval);
#[cfg(unix)]
let mut hangup = signal(SignalKind::hangup()).unwrap();
let mut ctrl_c = signal(SignalKind::interrupt()).unwrap();
let mut terminate = signal(SignalKind::terminate()).unwrap();
loop {
#[cfg(unix)]
tokio::select! {
_ = i.tick() => {
debug!(target: "database-cleanup", "Timer ticked");
}
_ = hangup.recv() => {
debug!(target: "database-cleanup","Received SIGHUP");
}
_ = ctrl_c.recv() => {
debug!(target: "database-cleanup", "Received Ctrl+C");
}
_ = terminate.recv() => {
debug!(target: "database-cleanup","Received SIGTERM");
}
};
#[cfg(not(unix))]
{
i.tick().await;
debug!(target: "database-cleanup", "Timer ticked")
}
Self::perform_cleanup();
}
});
}
async fn start_presence_handler(presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>) {
tokio::spawn(async move {
match presence_handler(presence_timer_receiver).await {
Ok(()) => warn!("Presence maintenance task finished"),
Err(e) => error!("Presence maintenance task finished with error: {e}"),
}
});
}
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId");
services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
&conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
}