Refactor server_keys service/interface and related callsites

Signed-off-by: Jason Volk <jason@zemos.net>
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Jason Volk 2024-10-11 18:57:59 +00:00 committed by strawberry
parent d82ea331cf
commit c0939c3e9a
30 changed files with 1025 additions and 1378 deletions

View file

@ -1,16 +1,9 @@
use std::{
collections::BTreeMap,
sync::{Arc, RwLock},
};
use std::sync::{Arc, RwLock};
use conduit::{trace, utils, utils::rand, Error, Result, Server};
use database::{Database, Deserialized, Json, Map};
use conduit::{trace, utils, Result, Server};
use database::{Database, Deserialized, Map};
use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt};
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use ruma::{DeviceId, UserId};
use crate::{rooms, Dep};
@ -25,7 +18,6 @@ pub struct Data {
pduid_pdu: Arc<Map>,
keychangeid_userid: Arc<Map>,
roomusertype_roomuserdataid: Arc<Map>,
server_signingkeys: Arc<Map>,
readreceiptid_readreceipt: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>,
counter: RwLock<u64>,
@ -56,7 +48,6 @@ impl Data {
pduid_pdu: db["pduid_pdu"].clone(),
keychangeid_userid: db["keychangeid_userid"].clone(),
roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(),
server_signingkeys: db["server_signingkeys"].clone(),
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
@ -205,107 +196,6 @@ impl Data {
Ok(())
}
/// Load this server's Ed25519 signing keypair from the `global` tree,
/// generating and persisting a fresh one when none is stored.
///
/// Stored layout: `<version bytes> 0xFF <DER keypair bytes>`.
pub fn load_keypair(&self) -> Result<Ed25519KeyPair> {
    // Fallback when no "keypair" entry exists: generate a new pair,
    // serialize it as version||0xFF||DER, and persist it.
    let generate = |_| {
        let keypair = Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)");

        // Random 8-char string becomes the key version prefix.
        let mut value = rand::string(8).as_bytes().to_vec();
        value.push(0xFF);
        value.extend_from_slice(&keypair);

        self.global.insert(b"keypair", &value);
        value
    };

    // Use stored bytes when present; otherwise generate-and-store.
    let keypair_bytes: Vec<u8> = self
        .global
        .get_blocking(b"keypair")
        .map_or_else(generate, Into::into);

    // Split on the 0xFF delimiter into (version, DER key).
    let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);

    utils::string_from_bytes(
        // 1. version
        parts
            .next()
            .expect("splitn always returns at least one element"),
    )
    .map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
    .and_then(|version| {
        // 2. key
        parts
            .next()
            .ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
            .map(|key| (version, key))
    })
    .and_then(|(version, key)| {
        Ed25519KeyPair::from_der(key, version)
            .map_err(|_| Error::bad_database("Private or public keys are invalid."))
    })
}
/// Delete the persisted "keypair" entry from the `global` tree.
#[inline]
pub fn remove_keypair(&self) -> Result<()> {
    self.global.remove(b"keypair");
    Ok(())
}
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored
/// in room version > 4
///
/// Remove the outdated keys and insert the new ones.
///
/// This doesn't actually check that the keys provided are newer than the
/// old set.
pub async fn add_signing_key(
    &self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> BTreeMap<OwnedServerSigningKeyId, VerifyKey> {
    // (timo) Not atomic, but this is not critical
    let mut keys: ServerSigningKeys = self
        .server_signingkeys
        .get(origin)
        .await
        .deserialized()
        .unwrap_or_else(|_| {
            // Just insert "now", it doesn't matter
            ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
        });

    // Merge the newly received keys over whatever was stored before.
    keys.verify_keys.extend(new_keys.verify_keys);
    keys.old_verify_keys.extend(new_keys.old_verify_keys);

    self.server_signingkeys.raw_put(origin, Json(&keys));

    // Return the union of current and old keys, with old keys wrapped
    // back into plain VerifyKey entries.
    let mut tree = keys.verify_keys;
    tree.extend(
        keys.old_verify_keys
            .into_iter()
            .map(|old| (old.0, VerifyKey::new(old.1.key))),
    );

    tree
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub async fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
    // No stored record for this server is not an error: report no keys.
    let keys = match self.signing_keys_for(origin).await {
        Ok(keys) => keys,
        Err(_) => return Ok(BTreeMap::new()),
    };

    // Merge current and historical keys into one map; historical entries
    // are wrapped back into plain VerifyKey values.
    let mut merged = keys.verify_keys;
    for (key_id, old_key) in keys.old_verify_keys {
        merged.insert(key_id, VerifyKey::new(old_key.key));
    }

    Ok(merged)
}
/// Fetch the stored `ServerSigningKeys` record for `origin`; errors when
/// no record exists or deserialization fails.
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
    self.server_signingkeys.get(origin).await.deserialized()
}
pub async fn database_version(&self) -> u64 {
self.global
.get(b"version")

View file

@ -2,7 +2,7 @@ mod data;
pub(super) mod migrations;
use std::{
collections::{BTreeMap, HashMap},
collections::HashMap,
fmt::Write,
sync::{Arc, RwLock},
time::Instant,
@ -13,13 +13,8 @@ use data::Data;
use ipaddress::IPAddress;
use regex::RegexSet;
use ruma::{
api::{
client::discovery::discover_support::ContactRole,
federation::discovery::{ServerSigningKeys, VerifyKey},
},
serde::Base64,
DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId,
RoomVersionId, ServerName, UserId,
api::client::discovery::discover_support::ContactRole, DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName,
OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId,
};
use tokio::sync::Mutex;
use url::Url;
@ -31,7 +26,6 @@ pub struct Service {
pub config: Config,
pub cidr_range_denylist: Vec<IPAddress>,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
pub stable_room_versions: Vec<RoomVersionId>,
pub unstable_room_versions: Vec<RoomVersionId>,
@ -50,16 +44,6 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let db = Data::new(&args);
let config = &args.server.config;
let keypair = db.load_keypair();
let keypair = match keypair {
Ok(k) => k,
Err(e) => {
error!("Keypair invalid. Deleting...");
db.remove_keypair()?;
return Err(e);
},
};
let jwt_decoding_key = config
.jwt_secret
@ -115,7 +99,6 @@ impl crate::Service for Service {
db,
config: config.clone(),
cidr_range_denylist,
keypair: Arc::new(keypair),
jwt_decoding_key,
stable_room_versions,
unstable_room_versions,
@ -175,9 +158,6 @@ impl crate::Service for Service {
}
impl Service {
/// Returns this server's keypair.
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
#[inline]
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
@ -224,8 +204,6 @@ impl Service {
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first }
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
pub fn turn_password(&self) -> &String { &self.config.turn_password }
@ -302,28 +280,6 @@ impl Service {
}
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub async fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let mut keys = self.db.verify_keys_for(origin).await?;
if origin == self.server_name() {
keys.insert(
format!("ed25519:{}", self.keypair().version())
.try_into()
.expect("found invalid server signing keys in DB"),
VerifyKey {
key: Base64::new(self.keypair.public_key().to_vec()),
},
);
}
Ok(keys)
}
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
self.db.signing_keys_for(origin).await
}
pub fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client }
pub fn well_known_server(&self) -> &Option<OwnedServerName> { &self.config.well_known.server }

View file

@ -28,12 +28,10 @@ use ruma::{
StateEventType, TimelineEventType,
},
int,
serde::Base64,
state_res::{self, EventTypeExt, RoomVersion, StateMap},
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
ServerName, UserId,
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
RoomId, RoomVersionId, ServerName, UserId,
};
use tokio::sync::RwLock;
use super::state_compressor::CompressedStateEvent;
use crate::{globals, rooms, sending, server_keys, Dep};
@ -129,11 +127,10 @@ impl Service {
/// 13. Use state resolution to find new room state
/// 14. Check if the event passes auth based on the "current state" of the
/// room, if not soft fail it
#[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")]
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
@ -177,7 +174,7 @@ impl Service {
let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?;
let (incoming_pdu, val) = self
.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map)
.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false)
.boxed()
.await?;
@ -200,7 +197,6 @@ impl Service {
&create_event,
room_id,
&room_version_id,
pub_key_map,
incoming_pdu.prev_events.clone(),
)
.await?;
@ -212,7 +208,6 @@ impl Service {
origin,
event_id,
room_id,
pub_key_map,
&mut eventid_info,
&create_event,
&first_pdu_in_room,
@ -250,7 +245,7 @@ impl Service {
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = self
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id)
.await;
self.federation_handletime
@ -264,12 +259,11 @@ impl Service {
#[allow(clippy::type_complexity)]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(
skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room),
skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room),
name = "prev"
)]
pub async fn handle_prev_pdu<'a>(
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
) -> Result<()> {
@ -318,7 +312,7 @@ impl Service {
.expect("locked")
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map)
self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id)
.await?;
self.federation_handletime
@ -338,8 +332,7 @@ impl Service {
#[allow(clippy::too_many_arguments)]
async fn handle_outlier_pdu<'a>(
&self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId,
mut value: BTreeMap<String, CanonicalJsonValue>, auth_events_known: bool,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
mut value: CanonicalJsonObject, auth_events_known: bool,
) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> {
// 1. Remove unsigned field
value.remove("unsigned");
@ -349,14 +342,13 @@ impl Service {
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let room_version_id = Self::get_room_version_id(create_event)?;
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {event_id}: {e}");
return Err!(Request(InvalidParam("Signature verification failed")));
},
let mut val = match self
.services
.server_keys
.verify_event(&value, Some(&room_version_id))
.await
{
Ok(ruma::signatures::Verified::All) => value,
Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
@ -371,11 +363,13 @@ impl Service {
obj
},
Ok(ruma::signatures::Verified::All) => value,
Err(e) => {
return Err!(Request(InvalidParam(debug_error!(
"Signature verification failed for {event_id}: {e}"
))))
},
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
@ -404,7 +398,6 @@ impl Service {
create_event,
room_id,
&room_version_id,
pub_key_map,
),
)
.await;
@ -487,7 +480,7 @@ impl Service {
pub async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
origin: &ServerName, room_id: &RoomId,
) -> Result<Option<Vec<u8>>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
@ -526,14 +519,7 @@ impl Service {
if state_at_incoming_event.is_none() {
state_at_incoming_event = self
.fetch_state(
origin,
create_event,
room_id,
&room_version_id,
pub_key_map,
&incoming_pdu.event_id,
)
.fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id)
.await?;
}
@ -1021,10 +1007,10 @@ impl Service {
/// Call /state_ids to find out what the state at this pdu is. We trust the
/// server's response to some extent, but we still do a lot of checks
/// on the events
#[tracing::instrument(skip(self, pub_key_map, create_event, room_version_id))]
#[tracing::instrument(skip(self, create_event, room_version_id))]
async fn fetch_state(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids");
let res = self
@ -1048,7 +1034,7 @@ impl Service {
.collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map)
.fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id)
.boxed()
.await;
@ -1102,7 +1088,7 @@ impl Service {
/// d. TODO: Ask other servers over federation?
pub async fn fetch_and_handle_outliers<'a>(
&self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
room_version_id: &'a RoomVersionId,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| match self
.services
@ -1222,22 +1208,6 @@ impl Service {
events_with_auth_events.push((id, None, events_in_reverse_order));
}
// We go through all the signatures we see on the PDUs and their unresolved
// dependencies and fetch the corresponding signing keys
self.services
.server_keys
.fetch_required_signing_keys(
events_with_auth_events
.iter()
.flat_map(|(_id, _local_pdu, events)| events)
.map(|(_event_id, event)| event),
pub_key_map,
)
.await
.unwrap_or_else(|e| {
warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}");
});
let mut pdus = Vec::with_capacity(events_with_auth_events.len());
for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
// a. Look in the main timeline (pduid_pdu tree)
@ -1266,16 +1236,8 @@ impl Service {
}
}
match Box::pin(self.handle_outlier_pdu(
origin,
create_event,
&next_id,
room_id,
value.clone(),
true,
pub_key_map,
))
.await
match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true))
.await
{
Ok((pdu, json)) => {
if next_id == *id {
@ -1296,7 +1258,7 @@ impl Service {
#[tracing::instrument(skip_all)]
async fn fetch_prev(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, initial_set: Vec<Arc<EventId>>,
initial_set: Vec<Arc<EventId>>,
) -> Result<(
Vec<Arc<EventId>>,
HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
@ -1311,14 +1273,7 @@ impl Service {
while let Some(prev_event_id) = todo_outlier_stack.pop() {
if let Some((pdu, mut json_opt)) = self
.fetch_and_handle_outliers(
origin,
&[prev_event_id.clone()],
create_event,
room_id,
room_version_id,
pub_key_map,
)
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
.boxed()
.await
.pop()

View file

@ -16,7 +16,7 @@ use conduit::{
};
use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt};
use ruma::{
api::{client::error::ErrorKind, federation},
api::federation,
canonical_json::to_canonical_value,
events::{
push_rules::PushRulesEvent,
@ -30,14 +30,12 @@ use ruma::{
GlobalAccountDataEventType, StateEventType, TimelineEventType,
},
push::{Action, Ruleset, Tweak},
serde::Base64,
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName,
RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use self::data::Data;
pub use self::data::PdusIterItem;
@ -784,21 +782,15 @@ impl Service {
to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
);
match ruma::signatures::hash_and_sign_event(
self.services.globals.server_name().as_str(),
self.services.globals.keypair(),
&mut pdu_json,
&room_version_id,
) {
Ok(()) => {},
Err(e) => {
return match e {
ruma::signatures::Error::PduSize => {
Err(Error::BadRequest(ErrorKind::TooLarge, "Message is too long"))
},
_ => Err(Error::BadRequest(ErrorKind::Unknown, "Signing event failed")),
}
},
if let Err(e) = self
.services
.server_keys
.hash_and_sign_event(&mut pdu_json, &room_version_id)
{
return match e {
Error::Signatures(ruma::signatures::Error::PduSize) => Err!(Request(TooLarge("Message is too long"))),
_ => Err!(Request(Unknown("Signing event failed"))),
};
}
// Generate event id
@ -1106,9 +1098,8 @@ impl Service {
.await;
match response {
Ok(response) => {
let pub_key_map = RwLock::new(BTreeMap::new());
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).await {
warn!("Failed to add backfilled pdu in room {room_id}: {e}");
}
}
@ -1124,11 +1115,8 @@ impl Service {
Ok(())
}
#[tracing::instrument(skip(self, pdu, pub_key_map))]
pub async fn backfill_pdu(
&self, origin: &ServerName, pdu: Box<RawJsonValue>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
#[tracing::instrument(skip(self, pdu))]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?;
// Lock so we cannot backfill the same pdu twice at the same time
@ -1146,14 +1134,9 @@ impl Service {
return Ok(());
}
self.services
.server_keys
.fetch_required_signing_keys([&value], pub_key_map)
.await?;
self.services
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map)
.handle_incoming_pdu(origin, &room_id, &event_id, value, false)
.await?;
let value = self

View file

@ -17,7 +17,7 @@ use tokio::sync::Mutex;
use self::data::Data;
pub use self::dest::Destination;
use crate::{account_data, client, globals, presence, pusher, resolver, rooms, users, Dep};
use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_keys, users, Dep};
pub struct Service {
server: Arc<Server>,
@ -41,6 +41,7 @@ struct Services {
account_data: Dep<account_data::Service>,
appservice: Dep<crate::appservice::Service>,
pusher: Dep<pusher::Service>,
server_keys: Dep<server_keys::Service>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@ -78,6 +79,7 @@ impl crate::Service for Service {
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<crate::appservice::Service>("appservice"),
pusher: args.depend::<pusher::Service>("pusher"),
server_keys: args.depend::<server_keys::Service>("server_keys"),
},
db: Data::new(&args),
sender,

View file

@ -1,8 +1,8 @@
use std::{fmt::Debug, mem};
use conduit::{
debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error,
Result,
debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY,
Err, Error, Result,
};
use http::{header::AUTHORIZATION, HeaderValue};
use ipaddress::IPAddress;
@ -18,7 +18,7 @@ use ruma::{
};
use crate::{
globals, resolver,
resolver,
resolver::{actual::ActualDest, cache::CachedDest},
};
@ -75,7 +75,7 @@ impl super::Service {
.try_into_http_request::<Vec<u8>>(&actual.string, SATIR, &VERSIONS)
.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?;
sign_request::<T>(&self.services.globals, dest, &mut http_request);
self.sign_request::<T>(dest, &mut http_request);
let request = Request::try_from(http_request)?;
self.validate_url(request.url())?;
@ -178,7 +178,8 @@ where
Err(e.into())
}
fn sign_request<T>(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>)
#[implement(super::Service)]
fn sign_request<T>(&self, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>)
where
T: OutgoingRequest + Debug + Send,
{
@ -200,11 +201,13 @@ where
.to_string()
.into(),
);
req_map.insert("origin".to_owned(), globals.server_name().as_str().into());
req_map.insert("origin".to_owned(), self.services.globals.server_name().to_string().into());
req_map.insert("destination".to_owned(), dest.as_str().into());
let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json)
self.services
.server_keys
.sign_json(&mut req_json)
.expect("our request json is what ruma expects");
let req_json: serde_json::Map<String, serde_json::Value> =
@ -231,7 +234,12 @@ where
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)),
HeaderValue::from(&XMatrix::new(
self.services.globals.server_name().to_owned(),
dest.to_owned(),
key,
sig,
)),
);
}
}

View file

@ -0,0 +1,175 @@
use std::{
borrow::Borrow,
collections::{BTreeMap, BTreeSet},
};
use conduit::{debug, debug_warn, error, implement, result::FlatOk, warn};
use futures::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName,
OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
};
use serde_json::value::RawValue as RawJsonValue;
use super::key_exists;
type Batch = BTreeMap<OwnedServerName, Vec<OwnedServerSigningKeyId>>;
/// Acquire every signing key referenced by the `signatures` objects of the
/// given raw events, via the staged acquire_pubkeys() pipeline.
#[implement(super::Service)]
pub async fn acquire_events_pubkeys<'a, I>(&self, events: I)
where
    I: Iterator<Item = &'a Box<RawJsonValue>> + Send,
{
    // NOTE: shadows the module-level `Batch` alias with a set-valued map so
    // duplicate key ids referenced by multiple events collapse.
    type Batch = BTreeMap<OwnedServerName, BTreeSet<OwnedServerSigningKeyId>>;
    type Signatures = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, String>>;

    // Collect every (server, key id) pair mentioned in any event's
    // "signatures" field; events without a parsable field are skipped.
    let mut batch = Batch::new();
    events
        .cloned()
        .map(Raw::<CanonicalJsonObject>::from_json)
        .map(|event| event.get_field::<Signatures>("signatures"))
        .filter_map(FlatOk::flat_ok)
        .flat_map(IntoIterator::into_iter)
        .for_each(|(server, sigs)| {
            batch.entry(server).or_default().extend(sigs.into_keys());
        });

    // Re-borrow into the (server, key-id iterator) shape the next stage expects.
    let batch = batch
        .iter()
        .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));

    self.acquire_pubkeys(batch).await;
}
/// Ensure the requested signing keys are available locally, trying three
/// stages in order: local storage, the origin servers themselves, then the
/// configured trusted notaries. Failures are logged, not returned.
#[implement(super::Service)]
pub async fn acquire_pubkeys<'a, S, K>(&self, batch: S)
where
    S: Iterator<Item = (&'a ServerName, K)> + Send + Clone,
    K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone,
{
    // Counted up-front (iterators are Clone) for the summary log lines.
    let requested_servers = batch.clone().count();
    let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count();

    debug!("acquire {requested_keys} keys from {requested_servers}");

    // Stage 1: anything already verifiable locally is dropped from the batch.
    let missing = self.acquire_locals(batch).await;
    let missing_keys = keys_count(&missing);
    let missing_servers = missing.len();
    if missing_servers == 0 {
        return;
    }

    debug!("missing {missing_keys} keys for {missing_servers} servers locally");

    // Stage 2: ask each origin server directly for its own keys.
    let missing = self.acquire_origins(missing.into_iter()).await;
    let missing_keys = keys_count(&missing);
    let missing_servers = missing.len();
    if missing_servers == 0 {
        return;
    }

    debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable");

    // Stage 3: fall back to the trusted notary servers.
    let missing = self.acquire_notary(missing.into_iter()).await;
    let missing_keys = keys_count(&missing);
    let missing_servers = missing.len();
    if missing_keys > 0 {
        debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries");
        warn!("did not obtain {missing_keys} of {requested_keys} keys; some events may not be accepted");
    }
}
#[implement(super::Service)]
async fn acquire_locals<'a, S, K>(&self, batch: S) -> Batch
where
S: Iterator<Item = (&'a ServerName, K)> + Send,
K: Iterator<Item = &'a ServerSigningKeyId> + Send,
{
let mut missing = Batch::new();
for (server, key_ids) in batch {
for key_id in key_ids {
if !self.verify_key_exists(server, key_id).await {
missing
.entry(server.into())
.or_default()
.push(key_id.into());
}
}
}
missing
}
/// Query every origin server concurrently for its own keys; returns the
/// (server, key ids) pairs that remained unsatisfied.
#[implement(super::Service)]
async fn acquire_origins<I>(&self, batch: I) -> Batch
where
    I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
{
    // Fire off one request per origin and drain results as they complete.
    let mut futures: FuturesUnordered<_> = batch
        .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids))
        .collect();

    let mut still_missing = Batch::new();
    while let Some((origin, remaining)) = futures.next().await {
        if remaining.is_empty() {
            continue;
        }

        still_missing.insert(origin, remaining);
    }

    still_missing
}
/// Ask `origin` directly for its signing keys; returns the origin together
/// with whichever of `key_ids` the response did not satisfy.
#[implement(super::Service)]
async fn acquire_origin(
    &self, origin: OwnedServerName, mut key_ids: Vec<OwnedServerSigningKeyId>,
) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) {
    if let Ok(server_keys) = self.server_request(&origin).await {
        // Cache everything the server returned, then keep only the ids
        // still absent from its response.
        self.add_signing_keys(server_keys.clone()).await;
        key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
    }

    (origin, key_ids)
}
/// Last-resort stage: query every configured trusted notary for the keys
/// still missing; returns whatever remains unobtained after all notaries.
#[implement(super::Service)]
async fn acquire_notary<I>(&self, batch: I) -> Batch
where
    I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send,
{
    let mut missing: Batch = batch.collect();
    for notary in self.services.globals.trusted_servers() {
        let missing_keys = keys_count(&missing);
        let missing_servers = missing.len();
        debug!("Asking notary {notary} for {missing_keys} missing keys from {missing_servers} servers");

        // Re-borrow the entire remaining batch for this notary's request;
        // each notary is asked only for what previous ones didn't supply.
        let batch = missing
            .iter()
            .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));

        match self.batch_notary_request(notary, batch).await {
            Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
            Ok(results) => {
                for server_keys in results {
                    self.acquire_notary_result(&mut missing, server_keys).await;
                }
            },
        }
    }

    missing
}
/// Fold one notary response into the `missing` batch: cache the returned
/// keys, drop the now-satisfied ids, and remove the server entirely once
/// none of its ids remain missing.
#[implement(super::Service)]
async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSigningKeys) {
    let server = &server_keys.server_name;
    self.add_signing_keys(server_keys.clone()).await;

    if let Some(key_ids) = missing.get_mut(server) {
        // Keep only the ids the response did NOT provide, matching the
        // predicate used by acquire_origin(). The previous `key_exists`
        // predicate was inverted: it retained the satisfied ids and
        // discarded the still-missing ones.
        key_ids.retain(|key_id| !key_exists(&server_keys, key_id));
        if key_ids.is_empty() {
            missing.remove(server);
        }
    }
}
/// Total number of key ids across all servers in the batch.
fn keys_count(batch: &Batch) -> usize { batch.values().map(Vec::len).sum() }

View file

@ -0,0 +1,86 @@
use std::borrow::Borrow;
use conduit::{implement, Err, Result};
use ruma::{api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName, ServerSigningKeyId};
use super::{extract_key, PubKeyMap, PubKeys};
/// Resolve the public keys required to verify `object` under the given room
/// version, as determined by `ruma::signatures::required_keys`.
#[implement(super::Service)]
pub async fn get_event_keys(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> Result<PubKeyMap> {
    use ruma::signatures::required_keys;

    let required = match required_keys(object, version) {
        Ok(required) => required,
        Err(e) => return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")),
    };

    // Re-borrow the owned map into the (server, key-id iterator) batch shape.
    let batch = required
        .iter()
        .map(|(s, ids)| (s.borrow(), ids.iter().map(Borrow::borrow)));

    Ok(self.get_pubkeys(batch).await)
}
/// Resolve a batch of (server, key ids) into a `PubKeyMap`, looking up
/// each server's keys in turn via get_pubkeys_for().
#[implement(super::Service)]
pub async fn get_pubkeys<'a, S, K>(&self, batch: S) -> PubKeyMap
where
    S: Iterator<Item = (&'a ServerName, K)> + Send,
    K: Iterator<Item = &'a ServerSigningKeyId> + Send,
{
    let mut resolved = PubKeyMap::new();
    for (origin, key_ids) in batch {
        let server_keys = self.get_pubkeys_for(origin, key_ids).await;
        resolved.insert(origin.into(), server_keys);
    }

    resolved
}
/// Resolve each requested key id for `origin` into its public key.
/// Ids that cannot be resolved are omitted from the result.
#[implement(super::Service)]
pub async fn get_pubkeys_for<'a, I>(&self, origin: &ServerName, key_ids: I) -> PubKeys
where
    I: Iterator<Item = &'a ServerSigningKeyId> + Send,
{
    let mut found = PubKeys::new();
    for key_id in key_ids {
        match self.get_verify_key(origin, key_id).await {
            Ok(verify_key) => {
                found.insert(key_id.into(), verify_key.key);
            },
            // Unresolvable ids are simply left out.
            Err(_) => {},
        }
    }

    found
}
/// Obtain a single verify key for `origin`, trying in order: locally known
/// keys, a direct request to the origin, then each trusted notary. Remote
/// responses are cached via add_signing_keys() as a side effect.
#[implement(super::Service)]
pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
    // 1. Keys already known for this origin.
    if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) {
        return Ok(result);
    }

    // 2. Ask the origin server itself; cache whatever it returns.
    if let Ok(server_key) = self.server_request(origin).await {
        self.add_signing_keys(server_key.clone()).await;
        if let Some(result) = extract_key(server_key, key_id) {
            return Ok(result);
        }
    }

    // 3. Fall back to the trusted notaries, caching every response
    //    before searching it for the requested id.
    for notary in self.services.globals.trusted_servers() {
        if let Ok(server_keys) = self.notary_request(notary, origin).await {
            for server_key in &server_keys {
                self.add_signing_keys(server_key.clone()).await;
            }

            for server_key in server_keys {
                if let Some(result) = extract_key(server_key, key_id) {
                    return Ok(result);
                }
            }
        }
    }

    Err!(BadServerResponse(debug_error!(
        ?key_id,
        ?origin,
        "Failed to fetch federation signing-key"
    )))
}

View file

@ -0,0 +1,64 @@
use std::sync::Arc;
use conduit::{debug, debug_info, err, error, utils, utils::string_from_bytes, Result};
use database::Database;
use ruma::{api::federation::discovery::VerifyKey, serde::Base64, signatures::Ed25519KeyPair};
use super::VerifyKeys;
/// Load (or lazily create) this server's Ed25519 keypair and derive the
/// `VerifyKeys` map advertising its public key under `ed25519:<version>`.
pub(super) fn init(db: &Arc<Database>) -> Result<(Box<Ed25519KeyPair>, VerifyKeys)> {
    let keypair = load(db).inspect_err(|_e| {
        // A corrupt entry would fail on every startup; delete it so the
        // next run can regenerate a fresh keypair.
        error!("Keypair invalid. Deleting...");
        remove(db);
    })?;

    let verify_key = VerifyKey {
        key: Base64::new(keypair.public_key().to_vec()),
    };

    let id = format!("ed25519:{}", keypair.version());
    let verify_keys: VerifyKeys = [(id.try_into()?, verify_key)].into();

    Ok((keypair, verify_keys))
}
/// Read the `(version, DER)` keypair record from the `global` tree,
/// creating one if absent, and parse it into an `Ed25519KeyPair`.
///
/// Stored layout: `<version bytes> 0xFF <DER bytes>`.
fn load(db: &Arc<Database>) -> Result<Box<Ed25519KeyPair>> {
    let (version, key) = db["global"]
        .get_blocking(b"keypair")
        .map(|ref val| {
            // database deserializer is having trouble with this so it's manual for now
            let mut elems = val.split(|&b| b == b'\xFF');
            let vlen = elems.next().expect("invalid keypair entry").len();
            let ver = string_from_bytes(&val[..vlen]).expect("invalid keypair version");
            // Everything after the 0xFF delimiter is the DER-encoded key.
            let der = val[vlen.saturating_add(1)..].to_vec();
            debug!("Found existing Ed25519 keypair: {ver:?}");
            (ver, der)
        })
        .or_else(|e| {
            // Only "not found" is acceptable here; any other fetch error
            // would indicate a broken database.
            assert!(e.is_not_found(), "unexpected error fetching keypair");
            create(db)
        })?;

    let key =
        Ed25519KeyPair::from_der(&key, version).map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?;

    Ok(Box::new(key))
}
/// Generate a brand-new Ed25519 keypair, persist it under "keypair" in the
/// `global` tree, and return its (version, DER bytes) pair.
fn create(db: &Arc<Database>) -> Result<(String, Vec<u8>)> {
    let keypair =
        Ed25519KeyPair::generate().map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?;

    // An 8-character random string serves as the key version.
    let id = utils::rand::string(8);
    debug_info!("Generated new Ed25519 keypair: {id:?}");

    let entry: (String, Vec<u8>) = (id, keypair.to_vec());
    db["global"].raw_put(b"keypair", &entry);

    Ok(entry)
}
/// Drop the persisted server keypair from the database, if any.
#[inline]
fn remove(db: &Arc<Database>) { db["global"].remove(b"keypair"); }

View file

@ -1,31 +1,30 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::Arc,
time::{Duration, SystemTime},
};
mod acquire;
mod get;
mod keypair;
mod request;
mod sign;
mod verify;
use conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result};
use futures::{stream::FuturesUnordered, StreamExt};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use conduit::{implement, utils::time::timepoint_from_now, Result};
use database::{Deserialized, Json, Map};
use ruma::{
api::federation::{
discovery::{
get_remote_server_keys,
get_remote_server_keys_batch::{self, v2::QueryCriteria},
get_server_keys,
},
membership::create_join_event,
},
serde::Base64,
CanonicalJsonObject, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId,
RoomVersionId, ServerName,
api::federation::discovery::{ServerSigningKeys, VerifyKey},
serde::Raw,
signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet},
MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard};
use crate::{globals, sending, Dep};
pub struct Service {
keypair: Box<Ed25519KeyPair>,
verify_keys: VerifyKeys,
minimum_valid: Duration,
services: Services,
db: Data,
}
struct Services {
@ -33,546 +32,135 @@ struct Services {
sending: Dep<sending::Service>,
}
struct Data {
server_signingkeys: Arc<Map>,
}
pub type VerifyKeys = BTreeMap<OwnedServerSigningKeyId, VerifyKey>;
pub type PubKeyMap = PublicKeyMap;
pub type PubKeys = PublicKeySet;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let minimum_valid = Duration::from_secs(3600);
let (keypair, verify_keys) = keypair::init(args.db)?;
Ok(Arc::new(Self {
keypair,
verify_keys,
minimum_valid,
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
},
db: Data {
server_signingkeys: args.db["server_signingkeys"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
pub async fn fetch_required_signing_keys<'a, E>(
&'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()>
where
E: IntoIterator<Item = &'a BTreeMap<String, CanonicalJsonValue>> + Send,
{
let mut server_key_ids = HashMap::new();
for event in events {
for (signature_server, signature) in event
.get("signatures")
.ok_or(err!(BadServerResponse("No signatures in server response pdu.")))?
.as_object()
.ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))?
{
let signature_object = signature.as_object().ok_or(err!(BadServerResponse(
"Invalid signatures content object in server response pdu.",
)))?;
/// Borrow this server's long-term Ed25519 signing keypair.
#[implement(Service)]
#[inline]
pub fn keypair(&self) -> &Ed25519KeyPair { &self.keypair }
for signature_id in signature_object.keys() {
server_key_ids
.entry(signature_server.clone())
.or_insert_with(HashSet::new)
.insert(signature_id.clone());
}
}
}
#[implement(Service)]
async fn add_signing_keys(&self, new_keys: ServerSigningKeys) {
let origin = &new_keys.server_name;
if server_key_ids.is_empty() {
// Nothing to do, can exit early
trace!("server_key_ids is empty, not fetching any keys");
return Ok(());
}
// (timo) Not atomic, but this is not critical
let mut keys: ServerSigningKeys = self
.db
.server_signingkeys
.get(origin)
.await
.deserialized()
.unwrap_or_else(|_| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
trace!(
"Fetch keys for {}",
server_key_ids
.keys()
.cloned()
.collect::<Vec<_>>()
.join(", ")
);
let mut server_keys: FuturesUnordered<_> = server_key_ids
.into_iter()
.map(|(signature_server, signature_ids)| async {
let fetch_res = self
.fetch_signing_keys_for_server(
signature_server.as_str().try_into().map_err(|e| {
(
signature_server.clone(),
err!(BadServerResponse(
"Invalid servername in signatures of server response pdu: {e:?}"
)),
)
})?,
signature_ids.into_iter().collect(), // HashSet to Vec
)
.await;
match fetch_res {
Ok(keys) => Ok((signature_server, keys)),
Err(e) => {
debug_error!(
"Signature verification failed: Could not fetch signing key for {signature_server}: {e}",
);
Err((signature_server, e))
},
}
})
.collect();
while let Some(fetch_res) = server_keys.next().await {
match fetch_res {
Ok((signature_server, keys)) => {
pub_key_map
.write()
.await
.insert(signature_server.clone(), keys);
},
Err((signature_server, e)) => {
debug_warn!("Failed to fetch keys for {signature_server}: {e:?}");
},
}
}
Ok(())
}
// Gets a list of servers for which we don't have the signing key yet. We go
// over the PDUs and either cache the key or add it to the list that needs to be
// retrieved.
async fn get_server_keys_from_cache(
&self, pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
_room_version: &RoomVersionId,
pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
debug_error!("Invalid PDU in server response: {pdu:#?}");
err!(BadServerResponse(error!("Invalid PDU in server response: {e:?}")))
})?;
let signatures = value
.get("signatures")
.ok_or(err!(BadServerResponse("No signatures in server response pdu.")))?
.as_object()
.ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))?;
for (signature_server, signature) in signatures {
let signature_object = signature.as_object().ok_or(err!(BadServerResponse(
"Invalid signatures content object in server response pdu.",
)))?;
let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>();
let contains_all_ids =
|keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| {
err!(BadServerResponse(
"Invalid servername in signatures of server response pdu: {e:?}"
))
})?;
if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) {
continue;
}
debug!("Loading signing keys for {origin}");
let result: BTreeMap<_, _> = self
.services
.globals
.verify_keys_for(origin)
.await?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if !contains_all_ids(&result) {
debug_warn!("Signing key not loaded for {origin}");
servers.insert(origin.to_owned(), BTreeMap::new());
}
pub_key_map.insert(origin.to_string(), result);
}
Ok(())
}
/// Batch requests homeserver signing keys from trusted notary key servers
/// (`trusted_servers` config option)
async fn batch_request_signing_keys(
&self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
for server in self.services.globals.trusted_servers() {
debug!("Asking batch signing keys from trusted server {server}");
match self
.services
.sending
.send_federation_request(
server,
get_remote_server_keys_batch::v2::Request {
server_keys: servers.clone(),
},
)
.await
{
Ok(keys) => {
debug!("Got signing keys: {keys:?}");
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
Err(e) => {
warn!(
"Received error {e} while fetching keys from trusted server {server}: {:#?}",
k.into_json()
);
continue;
},
};
// TODO: Check signature from trusted server?
servers.remove(&k.server_name);
let result = self
.services
.globals
.db
.add_signing_key(&k.server_name, k.clone())
.await
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>();
pkm.insert(k.server_name.to_string(), result);
}
},
Err(e) => error!(
"Failed sending batched key request to trusted key server {server} for the remote servers \
{servers:?}: {e}"
),
}
}
Ok(())
}
/// Requests multiple homeserver signing keys from individual servers (not
/// trused notary servers)
async fn request_signing_keys(
&self, servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
debug!("Asking individual servers for signing keys: {servers:?}");
let mut futures: FuturesUnordered<_> = servers
.into_keys()
.map(|server| async move {
(
self.services
.sending
.send_federation_request(&server, get_server_keys::v2::Request::new())
.await,
server,
)
})
.collect();
while let Some(result) = futures.next().await {
debug!("Received new Future result");
if let (Ok(get_keys_response), origin) = result {
debug!("Result is from {origin}");
if let Ok(key) = get_keys_response.server_key.deserialize() {
let result: BTreeMap<_, _> = self
.services
.globals
.db
.add_signing_key(&origin, key)
.await
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map.write().await.insert(origin.to_string(), result);
}
}
debug!("Done handling Future result");
}
Ok(())
}
pub async fn fetch_join_signing_keys(
&self, event: &create_join_event::v2::Response, room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new();
{
let mut pkm = pub_key_map.write().await;
// Try to fetch keys, failure is okay. Servers we couldn't find in the cache
// will be added to `servers`
for pdu in event
.room_state
.state
.iter()
.chain(&event.room_state.auth_chain)
{
if let Err(error) = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await
{
debug!(%error, "failed to get server keys from cache");
};
}
drop(pkm);
};
if servers.is_empty() {
trace!("We had all keys cached locally, not fetching any keys from remote servers");
return Ok(());
}
if self.services.globals.query_trusted_key_servers_first() {
info!(
"query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \
homeserver signing keys."
);
self.batch_request_signing_keys(servers.clone(), pub_key_map)
.await?;
if servers.is_empty() {
debug!("Trusted server supplied all signing keys, no more keys to fetch");
return Ok(());
}
debug!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}");
self.request_signing_keys(servers.clone(), pub_key_map)
.await?;
} else {
debug!("query_trusted_key_servers_first is set to false, querying individual homeservers first");
self.request_signing_keys(servers.clone(), pub_key_map)
.await?;
if servers.is_empty() {
debug!("Individual homeservers supplied all signing keys, no more keys to fetch");
return Ok(());
}
debug!("Remaining servers left the individual homeservers did not provide: {servers:?}");
self.batch_request_signing_keys(servers.clone(), pub_key_map)
.await?;
}
debug!("Search for signing keys done");
/*if servers.is_empty() {
warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}");
}*/
Ok(())
}
/// Search the DB for the signing keys of the given server, if we don't have
/// them fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub async fn fetch_signing_keys_for_server(
&self, origin: &ServerName, signature_ids: Vec<String>,
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let mut result: BTreeMap<_, _> = self
.services
.globals
.verify_keys_for(origin)
.await?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if contains_all_ids(&result) {
trace!("We have all homeserver signing keys locally for {origin}, not fetching any remotely");
return Ok(result);
}
// i didnt split this out into their own functions because it's relatively small
if self.services.globals.query_trusted_key_servers_first() {
info!(
"query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \
keys"
);
for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = self
.services
.sending
.send_federation_request(
server,
get_remote_server_keys::v2::Request::new(
origin.to_owned(),
MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime too large"),
)
.expect("time is valid"),
),
)
.await
.ok()
.map(|resp| {
resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
}) {
debug!("Got signing keys: {:?}", server_keys);
for k in server_keys {
self.services
.globals
.db
.add_signing_key(origin, k.clone())
.await;
result.extend(
k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
}
if contains_all_ids(&result) {
return Ok(result);
}
}
}
debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = self
.services
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())
.await;
result.extend(
server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) {
return Ok(result);
}
}
} else {
info!("query_trusted_key_servers_first is set to false, querying {origin} first");
debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = self
.services
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())
.await;
result.extend(
server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) {
return Ok(result);
}
}
for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = self
.services
.sending
.send_federation_request(
server,
get_remote_server_keys::v2::Request::new(
origin.to_owned(),
MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime too large"),
)
.expect("time is valid"),
),
)
.await
.ok()
.map(|resp| {
resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
}) {
debug!("Got signing keys: {server_keys:?}");
for k in server_keys {
self.services
.globals
.db
.add_signing_key(origin, k.clone())
.await;
result.extend(
k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
}
if contains_all_ids(&result) {
return Ok(result);
}
}
}
}
Err!(BadServerResponse(warn!("Failed to find public key for server {origin:?}")))
}
keys.verify_keys.extend(new_keys.verify_keys);
keys.old_verify_keys.extend(new_keys.old_verify_keys);
self.db.server_signingkeys.raw_put(origin, Json(&keys));
}
/// Check whether any key with `key_id` is stored for `origin`, in either the
/// current `verify_keys` or the retired `old_verify_keys` set.
///
/// The stored JSON is probed through `Raw` field access rather than fully
/// deserialized, so a missing or malformed entry simply reports `false`.
#[implement(Service)]
async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool {
	// Only the key ids are needed; leave the key bodies as raw JSON.
	type KeysMap<'a> = BTreeMap<&'a ServerSigningKeyId, &'a RawJsonValue>;

	let Ok(keys) = self
		.db
		.server_signingkeys
		.get(origin)
		.await
		.deserialized::<Raw<ServerSigningKeys>>()
	else {
		return false;
	};

	if let Ok(Some(verify_keys)) = keys.get_field::<KeysMap<'_>>("verify_keys") {
		if verify_keys.contains_key(key_id) {
			return true;
		}
	}

	if let Ok(Some(old_verify_keys)) = keys.get_field::<KeysMap<'_>>("old_verify_keys") {
		if old_verify_keys.contains_key(key_id) {
			return true;
		}
	}

	false
}
#[implement(Service)]
pub async fn verify_keys_for(&self, origin: &ServerName) -> VerifyKeys {
let mut keys = self
.signing_keys_for(origin)
.await
.map(|keys| merge_old_keys(keys).verify_keys)
.unwrap_or(BTreeMap::new());
if self.services.globals.server_is_ours(origin) {
keys.extend(self.verify_keys.clone().into_iter());
}
keys
}
/// Fetch the raw stored `ServerSigningKeys` record for `origin` from the
/// `server_signingkeys` tree; errors when no entry exists or it fails to
/// deserialize.
#[implement(Service)]
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
	self.db.server_signingkeys.get(origin).await.deserialized()
}
/// Earliest `valid_until_ts` we will accept when requesting remote keys:
/// now plus the service's `minimum_valid` window.
#[implement(Service)]
fn minimum_valid_ts(&self) -> MilliSecondsSinceUnixEpoch {
	let timepoint = timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow");
	MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
}
/// Promote every retired key into `verify_keys` (as a plain `VerifyKey`),
/// leaving `old_verify_keys` itself untouched.
fn merge_old_keys(mut keys: ServerSigningKeys) -> ServerSigningKeys {
	let promoted: Vec<_> = keys
		.old_verify_keys
		.iter()
		.map(|(key_id, old)| (key_id.clone(), VerifyKey::new(old.key.clone())))
		.collect();

	keys.verify_keys.extend(promoted);
	keys
}
/// Take the key identified by `key_id` out of `keys`, preferring the current
/// set and falling back to the retired set (re-wrapped as a `VerifyKey`).
fn extract_key(mut keys: ServerSigningKeys, key_id: &ServerSigningKeyId) -> Option<VerifyKey> {
	if let Some(key) = keys.verify_keys.remove(key_id) {
		return Some(key);
	}

	let old = keys.old_verify_keys.remove(key_id)?;
	Some(VerifyKey::new(old.key))
}
/// Whether `key_id` appears in either the current or the retired key set.
fn key_exists(keys: &ServerSigningKeys, key_id: &ServerSigningKeyId) -> bool {
	if keys.verify_keys.contains_key(key_id) {
		return true;
	}

	keys.old_verify_keys.contains_key(key_id)
}

View file

@ -0,0 +1,97 @@
use std::collections::BTreeMap;
use conduit::{implement, Err, Result};
use ruma::{
api::federation::discovery::{
get_remote_server_keys,
get_remote_server_keys_batch::{self, v2::QueryCriteria},
get_server_keys, ServerSigningKeys,
},
OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
};
/// Ask the `notary` server for the signing keys of every `(server, key_ids)`
/// pair in `batch`, in a single batched request.
///
/// Keys which fail to deserialize are silently dropped from the result.
#[implement(super::Service)]
pub(super) async fn batch_notary_request<'a, S, K>(
	&self, notary: &ServerName, batch: S,
) -> Result<Vec<ServerSigningKeys>>
where
	S: Iterator<Item = (&'a ServerName, K)> + Send,
	K: Iterator<Item = &'a ServerSigningKeyId> + Send,
{
	use get_remote_server_keys_batch::v2::Request;
	type RumaBatch = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>;

	// Every requested key carries the same freshness requirement.
	let criteria = QueryCriteria {
		minimum_valid_until_ts: Some(self.minimum_valid_ts()),
	};

	let server_keys = batch.fold(RumaBatch::new(), |mut server_keys, (server, key_ids)| {
		let entry = server_keys.entry(server.into()).or_default();
		for key_id in key_ids {
			entry.insert(key_id.into(), criteria.clone());
		}
		server_keys
	});

	debug_assert!(!server_keys.is_empty(), "empty batch request to notary");

	let request = Request {
		server_keys,
	};

	let response = self
		.services
		.sending
		.send_federation_request(notary, request)
		.await?;

	Ok(response
		.server_keys
		.into_iter()
		.filter_map(|key| key.deserialize().ok())
		.collect())
}
/// Ask the `notary` server for all of `target`'s signing keys.
///
/// Keys which fail to deserialize are silently dropped from the result.
#[implement(super::Service)]
pub async fn notary_request(&self, notary: &ServerName, target: &ServerName) -> Result<Vec<ServerSigningKeys>> {
	use get_remote_server_keys::v2::Request;

	let request = Request {
		server_name: target.into(),
		minimum_valid_until_ts: self.minimum_valid_ts(),
	};

	let response = self
		.services
		.sending
		.send_federation_request(notary, request)
		.await?;

	Ok(response
		.server_keys
		.into_iter()
		.filter_map(|key| key.deserialize().ok())
		.collect())
}
#[implement(super::Service)]
pub async fn server_request(&self, target: &ServerName) -> Result<ServerSigningKeys> {
use get_server_keys::v2::Request;
let server_signing_key = self
.services
.sending
.send_federation_request(target, Request::new())
.await
.map(|response| response.server_key)
.and_then(|key| key.deserialize().map_err(Into::into))?;
if server_signing_key.server_name != target {
return Err!(BadServerResponse(debug_warn!(
requested = ?target,
response = ?server_signing_key.server_name,
"Server responded with bogus server_name"
)));
}
Ok(server_signing_key)
}

View file

@ -0,0 +1,18 @@
use conduit::{implement, Result};
use ruma::{CanonicalJsonObject, RoomVersionId};
/// Sign `object` in place with this server's Ed25519 key, under our own
/// server name.
#[implement(super::Service)]
pub fn sign_json(&self, object: &mut CanonicalJsonObject) -> Result {
	let origin = self.services.globals.server_name();
	ruma::signatures::sign_json(origin.as_str(), self.keypair(), object).map_err(Into::into)
}
/// Hash `object` and sign the result in place, following the hashing and
/// signing rules of `room_version`.
#[implement(super::Service)]
pub fn hash_and_sign_event(&self, object: &mut CanonicalJsonObject, room_version: &RoomVersionId) -> Result {
	let origin = self.services.globals.server_name();
	ruma::signatures::hash_and_sign_event(origin.as_str(), self.keypair(), object, room_version).map_err(Into::into)
}

View file

@ -0,0 +1,33 @@
use conduit::{implement, pdu::gen_event_id_canonical_json, Err, Result};
use ruma::{signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId};
use serde_json::value::RawValue as RawJsonValue;
/// Compute the event id for `pdu`, verify its signatures, and return the
/// canonical JSON with the `event_id` field inserted.
///
/// Fails with `BadServerResponse` when signature verification rejects the
/// event.
#[implement(super::Service)]
pub async fn validate_and_add_event_id(
	&self, pdu: &RawJsonValue, room_version: &RoomVersionId,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
	let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
	if let Err(e) = self.verify_event(&value, Some(room_version)).await {
		return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}")));
	}

	// The reference hash becomes the event's id in the returned object.
	value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into()));

	Ok((event_id, value))
}
/// Verify all signatures on `event`, resolving the required server keys
/// first. When `room_version` is absent, v11 rules are assumed.
#[implement(super::Service)]
pub async fn verify_event(
	&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>,
) -> Result<Verified> {
	let version = match room_version {
		Some(version) => version,
		None => &RoomVersionId::V11,
	};

	let keys = self.get_event_keys(event, version).await?;
	ruma::signatures::verify_event(&keys, event, version).map_err(Into::into)
}
/// Verify the signatures on an arbitrary canonical-JSON object, resolving
/// the required server keys first (v11 rules when no `room_version` given).
#[implement(super::Service)]
pub async fn verify_json(&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>) -> Result {
	let version = match room_version {
		Some(version) => version,
		None => &RoomVersionId::V11,
	};

	let keys = self.get_event_keys(event, version).await?;
	ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into)
}