From 3c09313f7949d5be4b67e33604490411266de2b5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 30 Mar 2024 00:30:26 -0700 Subject: [PATCH] move and reorganize sending codepaths; no functional changes Signed-off-by: Jason Volk --- src/api/mod.rs | 1 - src/api/server_server.rs | 608 +---------------- src/service/globals/resolver.rs | 2 +- .../sending/appservice.rs} | 0 src/service/sending/mod.rs | 345 +++++----- src/service/sending/send.rs | 617 ++++++++++++++++++ 6 files changed, 791 insertions(+), 782 deletions(-) rename src/{api/appservice_server.rs => service/sending/appservice.rs} (100%) create mode 100644 src/service/sending/send.rs diff --git a/src/api/mod.rs b/src/api/mod.rs index 0d2cd664..5c284757 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,3 @@ -pub mod appservice_server; pub mod client_server; pub mod ruma_wrapper; pub mod server_server; diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 3a51a15d..ca45c40e 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -3,22 +3,15 @@ use std::{ collections::BTreeMap, - fmt::Debug, - mem, - net::{IpAddr, SocketAddr}, sync::Arc, time::{Duration, Instant, SystemTime}, }; use axum::{response::IntoResponse, Json}; -use futures_util::future::TryFutureExt; use get_profile_information::v1::ProfileField; -use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; -use http::header::{HeaderValue, AUTHORIZATION}; -use ipaddress::IPAddress; use ruma::{ api::{ - client::error::{Error as RumaError, ErrorKind}, + client::error::ErrorKind, federation::{ authorization::get_event_authorization, backfill::get_backfill, @@ -35,7 +28,7 @@ use ruma::{ send_transaction_message, }, }, - EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, SendAccessToken, + OutgoingResponse, }, directory::{Filter, RoomNetwork}, events::{ @@ -61,558 +54,6 @@ use crate::{ services, utils, Error, PduEvent, Result, Ruma, }; -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A `FedDest::Named` might contain an IP address in string form if there -/// was no port specified to construct a `SocketAddr` with. -/// -/// # Examples: -/// ```rust -/// # use conduit::api::server_server::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), String::new()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), String::new()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => host + &port, - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -pub(crate) async fn send_request(destination: &ServerName, request: T) -> Result -where - T: OutgoingRequest + Debug, -{ - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - if destination == services().globals.server_name() { - return Err(Error::bad_config("Won't send federation request to ourselves")); - } - - if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { - info!( - "Destination {} is an IP literal, checking against IP range denylist.", - destination - ); - let ip = IPAddress::parse(destination.host()).map_err(|e| { - warn!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; - - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } - - info!("IP literal {} is allowed.", destination); - } - - debug!("Preparing to send request to {destination}"); - - let mut write_destination_to_cache = false; - - let cached_result = services() - .globals - .actual_destinations() - .read() - .await - .get(destination) - .cloned(); - - let (actual_destination, host) = if let Some(result) = cached_result { - result - } else { - write_destination_to_cache = true; - - let result = find_actual_destination(destination).await; - - (result.0, result.1.into_uri_string()) - }; - - let actual_destination_str = actual_destination.clone().into_https_string(); - - let mut http_request = request - .try_into_http_request::>( - &actual_destination_str, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_5], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", actual_destination_str, e); - Error::BadServerResponse("Invalid destination") - })?; - - let mut request_map = serde_json::Map::new(); - - if !http_request.body().is_empty() { - request_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), - ); - }; - - request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - request_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); - request_map.insert("destination".to_owned(), destination.as_str().into()); - - let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); - - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut request_json, - ) - .expect("our request json is what ruma expects"); - - let request_json: serde_json::Map = - serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); - - let signatures = request_json["signatures"] - .as_object() - .unwrap() - .values() - .map(|v| { - v.as_object() - .unwrap() - .iter() - .map(|(k, v)| (k, v.as_str().unwrap())) - }); - - for signature_server in signatures { - for s in signature_server { - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - services().globals.server_name(), - s.0, - s.1 - )) - .unwrap(), - ); - } - } - - let reqwest_request = reqwest::Request::try_from(http_request)?; - - let url = reqwest_request.url().clone(); - - if let Some(url_host) = url.host_str() { - debug!("Checking request URL for IP"); - if let Ok(ip) = IPAddress::parse(url_host) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } - } - } - - debug!("Sending request to {destination} at {url}"); - let response = services() - .globals - .client - .federation - .execute(reqwest_request) - .await; - debug!("Received response from {destination} at {url}"); - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - - debug!("Checking response destination's IP"); - if let Some(remote_addr) = response.remote_addr() { - if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } - } - } - - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - debug!("Getting response bytes from {destination}"); - let body = response.bytes().await.unwrap_or_else(|e| { - info!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - debug!("Got response bytes from {destination}"); - - if !status.is_success() { - debug!( - "Response not successful\n{} {}: {}", - url, - status, - String::from_utf8_lossy(&body) - .lines() - .collect::>() - .join(" ") - ); - } - - let http_response = http_response_builder - .body(body) - .expect("reqwest body is valid http body"); - - if status.is_success() { - debug!("Parsing response bytes from {destination}"); - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && write_destination_to_cache { - services() - .globals - .actual_destinations() - .write() - .await - .insert(OwnedServerName::from(destination), (actual_destination, host)); - } - - response.map_err(|e| { - warn!("Invalid 200 response from {} on: {} {}", &destination, url, e); - Error::BadServerResponse("Server returned bad 200 response.") - }) - } else { - debug!("Returning error from {destination}"); - Err(Error::FederationError( - destination.to_owned(), - RumaError::from_http_response(http_response), - )) - } - }, - Err(e) => { - // we do not need to log that servers in a room are dead, this is normal in - // public rooms and just spams the logs. - if e.is_timeout() { - debug!( - "Timed out sending request to {} at {}: {}", - destination, actual_destination_str, e - ); - } else if e.is_connect() { - debug!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e); - } else if e.is_redirect() { - debug!( - "Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", - destination, - actual_destination_str, - e, - e.url() - ); - } else { - info!("Could not send request to {} at {}: {}", destination, actual_destination_str, e); - } - - Err(e.into()) - }, - } -} - -fn get_ip_with_port(destination_str: &str) -> Option { - if let Ok(destination) = destination_str.parse::() { - Some(FedDest::Literal(destination)) - } else if let Ok(ip_addr) = destination_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(destination_str: &str) -> FedDest { - let (host, port) = match destination_str.find(':') { - None => (destination_str, ":8448"), - Some(pos) => destination_str.split_at(pos), - }; - FedDest::Named(host.to_owned(), port.to_owned()) -} - -/// Returns: `actual_destination`, host header -/// Implemented according to the specification at -/// Numbers in comments below refer to bullet points in linked section of -/// specification -async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { - debug!("Finding actual destination for {destination}"); - let destination_str = destination.as_str().to_owned(); - let mut hostname = destination_str.clone(); - let actual_destination = match get_ip_with_port(&destination_str) { - Some(host_port) => { - debug!("1: IP literal with provided or default port"); - host_port - }, - None => { - if let Some(pos) = destination_str.find(':') { - debug!("2: Hostname with included port"); - - let (host, port) = destination_str.split_at(pos); - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; - - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!("Requesting well known for {destination}"); - if let Some(delegated_hostname) = request_well_known(destination.as_str()).await { - debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - match get_ip_with_port(&delegated_hostname) { - Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file - None => { - if let Some(pos) = delegated_hostname.find(':') { - debug!("3.2: Hostname with port in .well-known file"); - - let (host, port) = delegated_hostname.split_at(pos); - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; - - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = query_srv_record(&delegated_hostname).await { - debug!("3.3: SRV lookup successful"); - - let force_port = hostname_override.port(); - query_and_cache_override( - &delegated_hostname, - &hostname_override.hostname(), - force_port.unwrap_or(8448), - ) - .await; - - if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{port}")) - } else { - add_port_to_hostname(&delegated_hostname) - } - } else { - debug!("3.4: No SRV records, just use the hostname from .well-known"); - query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await; - add_port_to_hostname(&delegated_hostname) - } - } - }, - } - } else { - debug!("4: No .well-known or an error occured"); - if let Some(hostname_override) = query_srv_record(&destination_str).await { - debug!("4: SRV record found"); - - let force_port = hostname_override.port(); - query_and_cache_override(&hostname, &hostname_override.hostname(), force_port.unwrap_or(8448)) - .await; - - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{port}")) - } else { - add_port_to_hostname(&hostname) - } - } else { - debug!("5: No SRV record found"); - query_and_cache_override(&destination_str, &destination_str, 8448).await; - add_port_to_hostname(&destination_str) - } - } - } - }, - }; - - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) - }; - - debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); - (actual_destination, hostname) -} - -async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { - match services() - .globals - .dns_resolver() - .lookup_ip(hostname.to_owned()) - .await - { - Ok(override_ip) => { - debug!("Caching result of {:?} overriding {:?}", hostname, overname); - - services() - .globals - .resolver - .overrides - .write() - .unwrap() - .insert(overname.to_owned(), (override_ip.iter().collect(), port)); - }, - Err(e) => { - debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname); - }, - } -} - -async fn query_srv_record(hostname: &'_ str) -> Option { - fn handle_successful_srv(srv: &SrvLookup) -> Option { - srv.iter().next().map(|result| { - FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), - ) - }) - } - - async fn lookup_srv(hostname: &str) -> Result { - debug!("querying SRV for {:?}", hostname); - let hostname = hostname.trim_end_matches('.'); - services() - .globals - .dns_resolver() - .srv_lookup(hostname.to_owned()) - .await - } - - let first_hostname = format!("_matrix-fed._tcp.{hostname}."); - let second_hostname = format!("_matrix._tcp.{hostname}."); - - lookup_srv(&first_hostname) - .or_else(|_| { - debug!("Querying deprecated _matrix SRV record for host {:?}", hostname); - lookup_srv(&second_hostname) - }) - .and_then(|srv_lookup| async move { Ok(handle_successful_srv(&srv_lookup)) }) - .await - .ok() - .flatten() -} - -async fn request_well_known(destination: &str) -> Option { - if !services() - .globals - .resolver - .overrides - .read() - .unwrap() - .contains_key(destination) - { - query_and_cache_override(destination, destination, 8448).await; - } - - let response = services() - .globals - .client - .well_known - .get(&format!("https://{destination}/.well-known/matrix/server")) - .send() - .await; - debug!("Got well known response"); - debug!("Well known response: {:?}", response); - - if let Err(e) = &response { - debug!("Well known error: {e:?}"); - return None; - } - - let text = response.ok()?.text().await; - - debug!("Got well known response text"); - debug!("Well known response text: {:?}", text); - - if text.as_ref().ok()?.len() > 10000 { - debug!( - "Well known response for destination '{destination}' exceeded past 10000 characters, assuming no \ - well-known." - ); - return None; - } - - let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; - debug!("serde_json body of well known text: {}", body); - - Some(body.get("m.server")?.as_str()?.to_owned()) -} - /// # `GET /_matrix/federation/v1/version` /// /// Get version information on this server. @@ -2074,48 +1515,3 @@ pub async fn get_hierarchy_route(body: Ruma) -> Resu Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) } } - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } -} diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index 3e9fb00a..55ad35d0 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -13,7 +13,7 @@ use ruma::OwnedServerName; use tokio::sync::RwLock; use tracing::error; -use crate::{api::server_server::FedDest, Config, Error}; +use crate::{service::sending::FedDest, Config, Error}; pub type WellKnownMap = HashMap; pub type TlsNameMap = HashMap, u16)>; diff --git a/src/api/appservice_server.rs b/src/service/sending/appservice.rs similarity index 100% rename from src/api/appservice_server.rs rename to src/service/sending/appservice.rs diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 20884950..e4482f5b 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,5 +1,3 @@ -mod data; - use std::{ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, @@ -11,10 +9,9 @@ use base64::{engine::general_purpose, Engine as _}; pub use data::Data; use federation::transactions::send_transaction_message; use futures_util::{stream::FuturesUnordered, StreamExt}; -use ipaddress::IPAddress; use ruma::{ api::{ - appservice::{self, Registration}, + appservice::Registration, federation::{ self, transactions::edu::{ @@ -31,14 +28,25 @@ use tokio::{ select, sync::{mpsc, Mutex, Semaphore}, }; -use tracing::{debug, error, info, warn}; +use tracing::{error, warn}; -use crate::{ - api::{appservice_server, server_server}, - services, - utils::calculate_hash, - Config, Error, PduEvent, Result, -}; +use crate::{services, utils::calculate_hash, Config, Error, PduEvent, Result}; + +pub mod appservice; +pub mod data; +pub mod send; +pub use send::FedDest; + +pub struct Service { + db: &'static dyn Data, + + /// The state for a given state hash. + pub(super) maximum_requests: Arc, + pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, + receiver: Mutex)>>, + startup_netburst: bool, + timeout: u64, +} #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum OutgoingKind { @@ -83,16 +91,6 @@ pub enum SendingEventType { Flush, // none } -pub struct Service { - db: &'static dyn Data, - - /// The state for a given state hash. - pub(super) maximum_requests: Arc, - pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, - receiver: Mutex)>>, - startup_netburst: bool, -} - enum TransactionStatus { Running, Failed(u32, Instant), // number of times failed, time of last failure @@ -111,6 +109,148 @@ impl Service { }) } + #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); + let event = SendingEventType::Pdu(pdu_id.to_owned()); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + let outgoing_kind = OutgoingKind::Appservice(appservice_id); + let event = SendingEventType::Pdu(pdu_id); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id, pdu_id))] + pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(Result::ok) + .filter(|server| &**server != services().globals.server_name()); + + self.send_pdu_servers(servers, pdu_id) + } + + #[tracing::instrument(skip(self, servers, pdu_id))] + pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) + .collect::>(); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + )?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender + .send((outgoing_kind.clone(), event, key)) + .unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, server, serialized))] + pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { + let outgoing_kind = OutgoingKind::Normal(server.to_owned()); + let event = SendingEventType::Edu(serialized); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id, serialized))] + pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(Result::ok) + .filter(|server| &**server != services().globals.server_name()); + + self.send_edu_servers(servers, serialized) + } + + #[tracing::instrument(skip(self, servers, serialized))] + pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| (OutgoingKind::Normal(server), SendingEventType::Edu(serialized.clone()))) + .collect::>(); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + )?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender + .send((outgoing_kind.clone(), event, key)) + .unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(Result::ok) + .filter(|server| &**server != services().globals.server_name()); + + self.flush_servers(servers) + } + + #[tracing::instrument(skip(self, servers))] + pub fn flush_servers>(&self, servers: I) -> Result<()> { + let requests = servers.into_iter().map(OutgoingKind::Normal); + + for outgoing_kind in requests { + self.sender + .send((outgoing_kind, SendingEventType::Flush, Vec::::new())) + .unwrap(); + } + + Ok(()) + } + + /// Cleanup event data + /// Used for instance after we remove an appservice registration + #[tracing::instrument(skip(self))] + pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + self.db + .delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; + + Ok(()) + } + pub fn start_handler(self: &Arc) { let self2 = Arc::clone(self); tokio::spawn(async move { @@ -407,148 +547,6 @@ impl Service { Ok((events, max_edu_count)) } - #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { - let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); - let event = SendingEventType::Pdu(pdu_id.to_owned()); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { - let outgoing_kind = OutgoingKind::Appservice(appservice_id); - let event = SendingEventType::Pdu(pdu_id); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, pdu_id))] - pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); - - self.send_pdu_servers(servers, pdu_id) - } - - #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) - .collect::>(); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; - for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { - self.sender - .send((outgoing_kind.clone(), event, key)) - .unwrap(); - } - - Ok(()) - } - - #[tracing::instrument(skip(self, server, serialized))] - pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { - let outgoing_kind = OutgoingKind::Normal(server.to_owned()); - let event = SendingEventType::Edu(serialized); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, serialized))] - pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); - - self.send_edu_servers(servers, serialized) - } - - #[tracing::instrument(skip(self, servers, serialized))] - pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (OutgoingKind::Normal(server), SendingEventType::Edu(serialized.clone()))) - .collect::>(); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; - for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { - self.sender - .send((outgoing_kind.clone(), event, key)) - .unwrap(); - } - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id))] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); - - self.flush_servers(servers) - } - - #[tracing::instrument(skip(self, servers))] - pub fn flush_servers>(&self, servers: I) -> Result<()> { - let requests = servers.into_iter().map(OutgoingKind::Normal); - - for outgoing_kind in requests { - self.sender - .send((outgoing_kind, SendingEventType::Flush, Vec::::new())) - .unwrap(); - } - - Ok(()) - } - - /// Cleanup event data - /// Used for instance after we remove an appservice registration - #[tracing::instrument(skip(self))] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { - self.db - .delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; - - Ok(()) - } - #[tracing::instrument(skip(events, kind))] async fn handle_events( kind: OutgoingKind, events: Vec, @@ -586,7 +584,7 @@ impl Service { let permit = services().sending.maximum_requests.acquire().await; - let response = match appservice_server::send_request( + let response = match appservice::send_request( services() .appservice .get_registration(id) @@ -597,7 +595,7 @@ impl Service { Error::bad_database("[Appservice] Could not load registration from db."), ) })?, - appservice::event::push_events::v1::Request { + ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( &events @@ -737,7 +735,7 @@ impl Service { let permit = services().sending.maximum_requests.acquire().await; - let response = server_server::send_request( + let response = send::send_request( server, send_transaction_message::v1::Request { origin: services().globals.server_name().to_owned(), @@ -814,13 +812,12 @@ impl Service { debug!("Waiting for permit"); let permit = self.maximum_requests.acquire().await; debug!("Got permit"); - let response = - tokio::time::timeout(Duration::from_secs(5 * 60), server_server::send_request(destination, request)) - .await - .map_err(|_| { - warn!("Timeout after 300 seconds waiting for server response of {destination}"); - Error::BadServerResponse("Timeout after 300 seconds waiting for server response") - })?; + let response = tokio::time::timeout(Duration::from_secs(5 * 60), send::send_request(destination, request)) + .await + .map_err(|_| { + warn!("Timeout after 300 seconds waiting for server response of {destination}"); + Error::BadServerResponse("Timeout after 300 seconds waiting for server response") + })?; drop(permit); response @@ -837,7 +834,7 @@ impl Service { T: OutgoingRequest + Debug, { let permit = self.maximum_requests.acquire().await; - let response = appservice_server::send_request(registration, request).await; + let response = appservice::send_request(registration, request).await; drop(permit); response diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs new file mode 100644 index 00000000..68b6b2b2 --- /dev/null +++ b/src/service/sending/send.rs @@ -0,0 +1,617 @@ +use std::{ + fmt::Debug, + mem, + net::{IpAddr, SocketAddr}, +}; + +use futures_util::TryFutureExt; +use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; +use http::{header::AUTHORIZATION, HeaderValue}; +use ipaddress::IPAddress; +use ruma::{ + api::{ + client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, + SendAccessToken, + }, + OwnedServerName, ServerName, +}; +use tracing::{debug, info, warn}; + +use crate::{services, Error, Result}; + +/// Wraps either an literal IP address plus port, or a hostname plus complement +/// (colon-plus-port if it was specified). +/// +/// Note: A `FedDest::Named` might contain an IP address in string form if there +/// was no port specified to construct a `SocketAddr` with. +/// +/// # Examples: +/// ```rust +/// # use conduit::api::server_server::FedDest; +/// # fn main() -> Result<(), std::net::AddrParseError> { +/// FedDest::Literal("198.51.100.3:8448".parse()?); +/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); +/// FedDest::Named("matrix.example.org".to_owned(), String::new()); +/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); +/// FedDest::Named("198.51.100.5".to_owned(), String::new()); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +pub(crate) async fn send_request(destination: &ServerName, request: T) -> Result +where + T: OutgoingRequest + Debug, +{ + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + if destination == services().globals.server_name() { + return Err(Error::bad_config("Won't send federation request to ourselves")); + } + + if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { + info!( + "Destination {} is an IP literal, checking against IP range denylist.", + destination + ); + let ip = IPAddress::parse(destination.host()).map_err(|e| { + warn!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); + + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } + + debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); + + for cidr in cidr_ranges { + if cidr.includes(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } + + info!("IP literal {} is allowed.", destination); + } + + debug!("Preparing to send request to {destination}"); + + let mut write_destination_to_cache = false; + + let cached_result = services() + .globals + .actual_destinations() + .read() + .await + .get(destination) + .cloned(); + + let (actual_destination, host) = if let Some(result) = cached_result { + result + } else { + write_destination_to_cache = true; + + let result = find_actual_destination(destination).await; + + (result.0, result.1.into_uri_string()) + }; + + let actual_destination_str = actual_destination.clone().into_https_string(); + + let mut http_request = request + .try_into_http_request::>( + &actual_destination_str, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_5], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", actual_destination_str, e); + Error::BadServerResponse("Invalid destination") + })?; + + let mut request_map = serde_json::Map::new(); + + if !http_request.body().is_empty() { + request_map.insert( + "content".to_owned(), + serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), + ); + }; + + request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); + request_map.insert( + "uri".to_owned(), + http_request + .uri() + .path_and_query() + .expect("all requests have a path") + .to_string() + .into(), + ); + request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); + request_map.insert("destination".to_owned(), destination.as_str().into()); + + let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); + + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut request_json, + ) + .expect("our request json is what ruma expects"); + + let request_json: serde_json::Map = + serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); + + let signatures = request_json["signatures"] + .as_object() + .unwrap() + .values() + .map(|v| { + v.as_object() + .unwrap() + .iter() + .map(|(k, v)| (k, v.as_str().unwrap())) + }); + + for signature_server in signatures { + for s in signature_server { + http_request.headers_mut().insert( + AUTHORIZATION, + HeaderValue::from_str(&format!( + "X-Matrix origin={},key=\"{}\",sig=\"{}\"", + services().globals.server_name(), + s.0, + s.1 + )) + .unwrap(), + ); + } + } + + let reqwest_request = reqwest::Request::try_from(http_request)?; + + let url = reqwest_request.url().clone(); + + if let Some(url_host) = url.host_str() { + debug!("Checking request URL for IP"); + if let Ok(ip) = IPAddress::parse(url_host) { + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); + + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } + + for cidr in cidr_ranges { + if cidr.includes(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } + } + } + + debug!("Sending request to {destination} at {url}"); + let response = services() + .globals + .client + .federation + .execute(reqwest_request) + .await; + debug!("Received response from {destination} at {url}"); + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + + debug!("Checking response destination's IP"); + if let Some(remote_addr) = response.remote_addr() { + if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); + + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } + + for cidr in cidr_ranges { + if cidr.includes(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } + } + } + + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + debug!("Getting response bytes from {destination}"); + let body = response.bytes().await.unwrap_or_else(|e| { + info!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + debug!("Got response bytes from {destination}"); + + if !status.is_success() { + debug!( + "Response not successful\n{} {}: {}", + url, + status, + String::from_utf8_lossy(&body) + .lines() + .collect::>() + .join(" ") + ); + } + + let http_response = http_response_builder + .body(body) + .expect("reqwest body is valid http body"); + + if status.is_success() { + debug!("Parsing response bytes from {destination}"); + let response = T::IncomingResponse::try_from_http_response(http_response); + if response.is_ok() && write_destination_to_cache { + services() + .globals + .actual_destinations() + .write() + .await + .insert(OwnedServerName::from(destination), (actual_destination, host)); + } + + response.map_err(|e| { + warn!("Invalid 200 response from {} on: {} {}", &destination, url, e); + Error::BadServerResponse("Server returned bad 200 response.") + }) + } else { + debug!("Returning error from {destination}"); + Err(Error::FederationError( + destination.to_owned(), + RumaError::from_http_response(http_response), + )) + } + }, + Err(e) => { + // we do not need to log that servers in a room are dead, this is normal in + // public rooms and just spams the logs. + if e.is_timeout() { + debug!( + "Timed out sending request to {} at {}: {}", + destination, actual_destination_str, e + ); + } else if e.is_connect() { + debug!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e); + } else if e.is_redirect() { + debug!( + "Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", + destination, + actual_destination_str, + e, + e.url() + ); + } else { + info!("Could not send request to {} at {}: {}", destination, actual_destination_str, e); + } + + Err(e.into()) + }, + } +} + +fn get_ip_with_port(destination_str: &str) -> Option { + if let Ok(destination) = destination_str.parse::() { + Some(FedDest::Literal(destination)) + } else if let Ok(ip_addr) = destination_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +fn add_port_to_hostname(destination_str: &str) -> FedDest { + let (host, port) = match destination_str.find(':') { + None => (destination_str, ":8448"), + Some(pos) => destination_str.split_at(pos), + }; + FedDest::Named(host.to_owned(), port.to_owned()) +} + +/// Returns: `actual_destination`, host header +/// Implemented according to the specification at +/// Numbers in comments below refer to bullet points in linked section of +/// specification +async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { + debug!("Finding actual destination for {destination}"); + let destination_str = destination.as_str().to_owned(); + let mut hostname = destination_str.clone(); + let actual_destination = match get_ip_with_port(&destination_str) { + Some(host_port) => { + debug!("1: IP literal with provided or default port"); + host_port + }, + None => { + if let Some(pos) = destination_str.find(':') { + debug!("2: Hostname with included port"); + + let (host, port) = destination_str.split_at(pos); + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + debug!("Requesting well known for {destination}"); + if let Some(delegated_hostname) = request_well_known(destination.as_str()).await { + debug!("3: A .well-known file is available"); + hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); + match get_ip_with_port(&delegated_hostname) { + Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file + None => { + if let Some(pos) = delegated_hostname.find(':') { + debug!("3.2: Hostname with port in .well-known file"); + + let (host, port) = delegated_hostname.split_at(pos); + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + debug!("Delegated hostname has no port in this branch"); + if let Some(hostname_override) = query_srv_record(&delegated_hostname).await { + debug!("3.3: SRV lookup successful"); + + let force_port = hostname_override.port(); + query_and_cache_override( + &delegated_hostname, + &hostname_override.hostname(), + force_port.unwrap_or(8448), + ) + .await; + + if let Some(port) = force_port { + FedDest::Named(delegated_hostname, format!(":{port}")) + } else { + add_port_to_hostname(&delegated_hostname) + } + } else { + debug!("3.4: No SRV records, just use the hostname from .well-known"); + query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await; + add_port_to_hostname(&delegated_hostname) + } + } + }, + } + } else { + debug!("4: No .well-known or an error occured"); + if let Some(hostname_override) = query_srv_record(&destination_str).await { + debug!("4: SRV record found"); + + let force_port = hostname_override.port(); + query_and_cache_override(&hostname, &hostname_override.hostname(), force_port.unwrap_or(8448)) + .await; + + if let Some(port) = force_port { + FedDest::Named(hostname.clone(), format!(":{port}")) + } else { + add_port_to_hostname(&hostname) + } + } else { + debug!("5: No SRV record found"); + query_and_cache_override(&destination_str, &destination_str, 8448).await; + add_port_to_hostname(&destination_str) + } + } + } + }, + }; + + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let hostname = if let Ok(addr) = hostname.parse::() { + FedDest::Literal(addr) + } else if let Ok(addr) = hostname.parse::() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = hostname.find(':') { + let (host, port) = hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + FedDest::Named(hostname, ":8448".to_owned()) + }; + + debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); + (actual_destination, hostname) +} + +async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { + match services() + .globals + .dns_resolver() + .lookup_ip(hostname.to_owned()) + .await + { + Ok(override_ip) => { + debug!("Caching result of {:?} overriding {:?}", hostname, overname); + + services() + .globals + .resolver + .overrides + .write() + .unwrap() + .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + }, + Err(e) => { + debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname); + }, + } +} + +async fn query_srv_record(hostname: &'_ str) -> Option { + fn handle_successful_srv(srv: &SrvLookup) -> Option { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + } + + async fn lookup_srv(hostname: &str) -> Result { + debug!("querying SRV for {:?}", hostname); + let hostname = hostname.trim_end_matches('.'); + services() + .globals + .dns_resolver() + .srv_lookup(hostname.to_owned()) + .await + } + + let first_hostname = format!("_matrix-fed._tcp.{hostname}."); + let second_hostname = format!("_matrix._tcp.{hostname}."); + + lookup_srv(&first_hostname) + .or_else(|_| { + debug!("Querying deprecated _matrix SRV record for host {:?}", hostname); + lookup_srv(&second_hostname) + }) + .and_then(|srv_lookup| async move { Ok(handle_successful_srv(&srv_lookup)) }) + .await + .ok() + .flatten() +} + +async fn request_well_known(destination: &str) -> Option { + if !services() + .globals + .resolver + .overrides + .read() + .unwrap() + .contains_key(destination) + { + query_and_cache_override(destination, destination, 8448).await; + } + + let response = services() + .globals + .client + .well_known + .get(&format!("https://{destination}/.well-known/matrix/server")) + .send() + .await; + debug!("Got well known response"); + debug!("Well known response: {:?}", response); + + if let Err(e) = &response { + debug!("Well known error: {e:?}"); + return None; + } + + let text = response.ok()?.text().await; + + debug!("Got well known response text"); + debug!("Well known response text: {:?}", text); + + if text.as_ref().ok()?.len() > 10000 { + debug!( + "Well known response for destination '{destination}' exceeded past 10000 characters, assuming no \ + well-known." + ); + return None; + } + + let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; + debug!("serde_json body of well known text: {}", body); + + Some(body.get("m.server")?.as_str()?.to_owned()) +} + +impl FedDest { + fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => host + &port, + } + } + + fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + fn port(&self) -> Option { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ); + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); + } +}