diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 660498f7..61eedca5 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -9,9 +9,9 @@ use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; use ruma::ServerName; -use crate::resolver::{ +use super::{ cache::{CachedDest, CachedOverride}, - fed::{add_port_to_hostname, get_ip_with_port, FedDest}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest, PortString}, }; #[derive(Clone, Debug)] @@ -77,12 +77,12 @@ impl super::Service { let host = if let Ok(addr) = host.parse::() { FedDest::Literal(addr) } else if let Ok(addr) = host.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) + FedDest::Named(addr.to_string(), FedDest::default_port()) } else if let Some(pos) = host.find(':') { let (host, port) = host.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named(host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port())) } else { - FedDest::Named(host, ":8448".to_owned()) + FedDest::Named(host, FedDest::default_port()) }; debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); @@ -103,7 +103,10 @@ impl super::Service { let (host, port) = dest.as_str().split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result { @@ -136,7 +139,10 @@ impl super::Service { let (host, port) = delegated.split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result { @@ -145,7 +151,13 @@ impl super::Service { self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(delegated, format!(":{port}"))) + Ok(FedDest::Named( + delegated, + format!(":{port}") + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(&delegated)) } @@ -164,7 +176,11 @@ impl super::Service { self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(host.to_owned(), format!(":{port}"))) + let port = format!(":{port}"); + Ok(FedDest::Named( + host.to_owned(), + PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(host)) } @@ -269,7 +285,10 @@ impl super::Service { srv.iter().next().map(|result| { FedDest::Named( result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), + format!(":{}", result.port()) + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), ) }) } diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs index 79f71f13..9c348b47 100644 --- a/src/service/resolver/fed.rs +++ b/src/service/resolver/fed.rs @@ -4,12 +4,19 @@ use std::{ net::{IpAddr, SocketAddr}, }; +use arrayvec::ArrayString; + #[derive(Clone, Debug, PartialEq, Eq)] pub enum FedDest { Literal(SocketAddr), - Named(String, String), + Named(String, PortString), } +/// numeric or service-name +pub type PortString = ArrayString<16>; + +const DEFAULT_PORT: &str = ":8448"; + pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { if let Ok(dest) = dest_str.parse::() { Some(FedDest::Literal(dest)) @@ -20,13 +27,16 @@ pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { } } -pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), +pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest { + let (host, port) = match dest.find(':') { + None => (dest, DEFAULT_PORT), + Some(pos) => dest.split_at(pos), }; - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named( + host.to_owned(), + PortString::from(port).unwrap_or_else(|_| FedDest::default_port()), + ) } impl FedDest { @@ -60,6 +70,10 @@ impl FedDest { Self::Named(_, port) => port[1..].parse().ok(), } } + + #[inline] + #[must_use] + pub fn default_port() -> PortString { PortString::from(DEFAULT_PORT).expect("default port string") } } impl fmt::Display for FedDest { diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs index 55cf0345..870f5eab 100644 --- a/src/service/resolver/tests.rs +++ b/src/service/resolver/tests.rs @@ -30,7 +30,7 @@ fn ips_keep_custom_ports() { fn hostnames_get_default_ports() { assert_eq!( add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) + FedDest::Named(String::from("example.com"), ":8448".try_into().unwrap()) ); } @@ -38,6 +38,6 @@ fn hostnames_get_default_ports() { fn hostnames_keep_custom_ports() { assert_eq!( add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) + FedDest::Named(String::from("example.com"), ":1337".try_into().unwrap()) ); }