diff --git a/src/api/client_server/unversioned.rs b/src/api/client_server/unversioned.rs index 69a477a6..56ab9a90 100644 --- a/src/api/client_server/unversioned.rs +++ b/src/api/client_server/unversioned.rs @@ -10,7 +10,7 @@ use ruma::api::client::{ error::ErrorKind, }; -use crate::{services, utils::conduwuit_version, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; /// # `GET /_matrix/client/versions` /// @@ -145,7 +145,7 @@ pub(crate) async fn syncv3_client_server_json() -> Result { Ok(Json(serde_json::json!({ "server": server_url, - "version": conduwuit_version(), + "version": conduit::version::conduwuit(), }))) } @@ -156,7 +156,7 @@ pub(crate) async fn syncv3_client_server_json() -> Result { pub(crate) async fn conduwuit_server_version() -> Result { Ok(Json(serde_json::json!({ "name": "conduwuit", - "version": conduwuit_version(), + "version": conduit::version::conduwuit(), }))) } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index f0d2ef01..72429752 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -74,7 +74,7 @@ pub(crate) async fn get_server_version_route( Ok(get_server_version::v1::Response { server: Some(get_server_version::v1::Server { name: Some("Conduwuit".to_owned()), - version: Some(utils::conduwuit_version()), + version: Some(conduit::version::conduwuit()), }), }) } diff --git a/src/core/debug.rs b/src/core/debug.rs index 6db9a4dd..207f08fa 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -2,6 +2,9 @@ use std::{panic, panic::PanicInfo}; +/// Export all of the ancillary tools from here as well. +pub use crate::utils::debug::*; + /// Log event at given level in debug-mode (when debug-assertions are enabled). /// In release-mode it becomes DEBUG level, and possibly subject to elision. /// diff --git a/src/core/mod.rs b/src/core/mod.rs index 5911e027..baebbb8f 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -7,12 +7,12 @@ pub mod mods; pub mod pducount; pub mod server; pub mod utils; +pub mod version; pub use config::Config; pub use error::{Error, Result, RumaResponse}; pub use pducount::PduCount; pub use server::Server; -pub use utils::conduwuit_version; #[cfg(not(conduit_mods))] pub mod mods { diff --git a/src/core/utils/debug.rs b/src/core/utils/debug.rs new file mode 100644 index 00000000..e4151f39 --- /dev/null +++ b/src/core/utils/debug.rs @@ -0,0 +1,40 @@ +use std::fmt; + +/// Debug-formats the given slice, but only up to the first `max_len` elements. +/// Any further elements are replaced by an ellipsis. +/// +/// See also [`slice_truncated()`], +pub struct TruncatedSlice<'a, T> { + inner: &'a [T], + max_len: usize, +} + +impl fmt::Debug for TruncatedSlice<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.inner.len() <= self.max_len { + write!(f, "{:?}", self.inner) + } else { + f.debug_list() + .entries(&self.inner[..self.max_len]) + .entry(&"...") + .finish() + } + } +} + +/// See [`TruncatedSlice`]. Useful for `#[instrument]`: +/// +/// ``` +/// use conduit_core::utils::debug::slice_truncated; +/// +/// #[tracing::instrument(fields(foos = slice_truncated(foos, 42)))] +/// fn bar(foos: &[&str]); +/// ``` +pub fn slice_truncated( + slice: &[T], max_len: usize, +) -> tracing::field::DebugValue> { + tracing::field::debug(TruncatedSlice { + inner: slice, + max_len, + }) +} diff --git a/src/core/utils/html.rs b/src/core/utils/html.rs new file mode 100644 index 00000000..3b44a31b --- /dev/null +++ b/src/core/utils/html.rs @@ -0,0 +1,37 @@ +use std::fmt; + +/// Wrapper struct which will emit the HTML-escaped version of the contained +/// string when passed to a format string. +pub struct Escape<'a>(pub &'a str); + +/// Copied from librustdoc: +/// * +impl fmt::Display for Escape<'_> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + // Because the internet is always right, turns out there's not that many + // characters to escape: http://stackoverflow.com/questions/7381974 + let Escape(s) = *self; + let pile_o_bits = s; + let mut last = 0; + for (i, ch) in s.char_indices() { + let s = match ch { + '>' => ">", + '<' => "<", + '&' => "&", + '\'' => "'", + '"' => """, + _ => continue, + }; + fmt.write_str(&pile_o_bits[last..i])?; + fmt.write_str(s)?; + // NOTE: we only expect single byte characters here - which is fine as long as + // we only match single byte characters + last = i + 1; + } + + if last < s.len() { + fmt.write_str(&pile_o_bits[last..])?; + } + Ok(()) + } +} diff --git a/src/core/utils/json.rs b/src/core/utils/json.rs new file mode 100644 index 00000000..a9adad54 --- /dev/null +++ b/src/core/utils/json.rs @@ -0,0 +1,39 @@ +use std::{fmt, str::FromStr}; + +use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; + +use crate::Result; + +/// Fallible conversion from any value that implements `Serialize` to a +/// `CanonicalJsonObject`. +/// +/// `value` must serialize to an `serde_json::Value::Object`. +pub fn to_canonical_object(value: T) -> Result { + use serde::ser::Error; + + match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? { + serde_json::Value::Object(map) => try_from_json_map(map), + _ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom("Value must be an object"))), + } +} + +pub fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr, E: fmt::Display>( + deserializer: D, +) -> Result { + struct Visitor, E>(std::marker::PhantomData); + impl, Err: fmt::Display> serde::de::Visitor<'_> for Visitor { + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a parsable string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(serde::de::Error::custom) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) +} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 582263a9..e85080f1 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,21 +1,26 @@ +pub mod content_disposition; +pub mod debug; +pub mod defer; +pub mod html; +pub mod json; +pub mod sys; + use std::{ cmp, cmp::Ordering, - fmt, - str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; +pub use debug::slice_truncated as debug_slice_truncated; +pub use html::Escape as HtmlEscape; +pub use json::{deserialize_from_str, to_canonical_object}; use rand::prelude::*; use ring::digest; -use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId}; -use tracing::debug; +use ruma::OwnedUserId; +pub use sys::available_parallelism; use crate::{Error, Result}; -pub mod content_disposition; -pub mod defer; - pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } #[must_use] @@ -108,178 +113,6 @@ pub fn common_elements( })) } -/// Fallible conversion from any value that implements `Serialize` to a -/// `CanonicalJsonObject`. -/// -/// `value` must serialize to an `serde_json::Value::Object`. -pub fn to_canonical_object(value: T) -> Result { - use serde::ser::Error; - - match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? { - serde_json::Value::Object(map) => try_from_json_map(map), - _ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom("Value must be an object"))), - } -} - -pub fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr, E: fmt::Display>( - deserializer: D, -) -> Result { - struct Visitor, E>(std::marker::PhantomData); - impl, Err: fmt::Display> serde::de::Visitor<'_> for Visitor { - type Value = T; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a parsable string") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - v.parse().map_err(serde::de::Error::custom) - } - } - deserializer.deserialize_str(Visitor(std::marker::PhantomData)) -} - -// Copied from librustdoc: -// https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs - -/// Wrapper struct which will emit the HTML-escaped version of the contained -/// string when passed to a format string. -pub struct HtmlEscape<'a>(pub &'a str); - -impl fmt::Display for HtmlEscape<'_> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - // Because the internet is always right, turns out there's not that many - // characters to escape: http://stackoverflow.com/questions/7381974 - let HtmlEscape(s) = *self; - let pile_o_bits = s; - let mut last = 0; - for (i, ch) in s.char_indices() { - let s = match ch { - '>' => ">", - '<' => "<", - '&' => "&", - '\'' => "'", - '"' => """, - _ => continue, - }; - fmt.write_str(&pile_o_bits[last..i])?; - fmt.write_str(s)?; - // NOTE: we only expect single byte characters here - which is fine as long as - // we only match single byte characters - last = i + 1; - } - - if last < s.len() { - fmt.write_str(&pile_o_bits[last..])?; - } - Ok(()) - } -} - -/// one true function for returning the conduwuit version with the necessary -/// CONDUWUIT_VERSION_EXTRA env variables used if specified -/// -/// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string -/// to include it in parenthesis after the SemVer version. A common value are -/// git commit hashes. -#[must_use] -pub fn conduwuit_version() -> String { - match option_env!("CONDUWUIT_VERSION_EXTRA") { - Some(extra) => { - if extra.is_empty() { - env!("CARGO_PKG_VERSION").to_owned() - } else { - format!("{} ({})", env!("CARGO_PKG_VERSION"), extra) - } - }, - None => match option_env!("CONDUIT_VERSION_EXTRA") { - Some(extra) => { - if extra.is_empty() { - env!("CARGO_PKG_VERSION").to_owned() - } else { - format!("{} ({})", env!("CARGO_PKG_VERSION"), extra) - } - }, - None => env!("CARGO_PKG_VERSION").to_owned(), - }, - } -} - -/// Debug-formats the given slice, but only up to the first `max_len` elements. -/// Any further elements are replaced by an ellipsis. -/// -/// See also [`debug_slice_truncated()`], -pub struct TruncatedDebugSlice<'a, T> { - inner: &'a [T], - max_len: usize, -} - -impl fmt::Debug for TruncatedDebugSlice<'_, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.inner.len() <= self.max_len { - write!(f, "{:?}", self.inner) - } else { - f.debug_list() - .entries(&self.inner[..self.max_len]) - .entry(&"...") - .finish() - } - } -} - -/// See [`TruncatedDebugSlice`]. Useful for `#[instrument]`: -/// -/// ``` -/// use conduit_core::utils::debug_slice_truncated; -/// -/// #[tracing::instrument(fields(foos = debug_slice_truncated(foos, 42)))] -/// fn bar(foos: &[&str]); -/// ``` -pub fn debug_slice_truncated( - slice: &[T], max_len: usize, -) -> tracing::field::DebugValue> { - tracing::field::debug(TruncatedDebugSlice { - inner: slice, - max_len, - }) -} - -/// This is needed for opening lots of file descriptors, which tends to -/// happen more often when using RocksDB and making lots of federation -/// connections at startup. The soft limit is usually 1024, and the hard -/// limit is usually 512000; I've personally seen it hit >2000. -/// -/// * -/// * -#[cfg(unix)] -pub fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { - use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; - - let (soft_limit, hard_limit) = getrlimit(NOFILE)?; - if soft_limit < hard_limit { - setrlimit(NOFILE, hard_limit, hard_limit)?; - assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); - debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); - } - - Ok(()) -} - -/// Get the number of threads which could execute in parallel based on the -/// hardware and administrative constraints of this system. This value should be -/// used to hint the size of thread-pools and divide-and-conquer algorithms. -/// -/// * -#[must_use] -pub fn available_parallelism() -> usize { - std::thread::available_parallelism() - .expect("Unable to query for available parallelism.") - .get() -} - /// Boilerplate for wraps which are typed to never error. /// /// * diff --git a/src/core/utils/sys.rs b/src/core/utils/sys.rs new file mode 100644 index 00000000..825ec903 --- /dev/null +++ b/src/core/utils/sys.rs @@ -0,0 +1,36 @@ +use tracing::debug; + +use crate::Result; + +/// This is needed for opening lots of file descriptors, which tends to +/// happen more often when using RocksDB and making lots of federation +/// connections at startup. The soft limit is usually 1024, and the hard +/// limit is usually 512000; I've personally seen it hit >2000. +/// +/// * +/// * +#[cfg(unix)] +pub fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { + use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; + + let (soft_limit, hard_limit) = getrlimit(NOFILE)?; + if soft_limit < hard_limit { + setrlimit(NOFILE, hard_limit, hard_limit)?; + assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); + debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); + } + + Ok(()) +} + +/// Get the number of threads which could execute in parallel based on the +/// hardware and administrative constraints of this system. This value should be +/// used to hint the size of thread-pools and divide-and-conquer algorithms. +/// +/// * +#[must_use] +pub fn available_parallelism() -> usize { + std::thread::available_parallelism() + .expect("Unable to query for available parallelism.") + .get() +} diff --git a/src/core/version.rs b/src/core/version.rs new file mode 100644 index 00000000..f65bac99 --- /dev/null +++ b/src/core/version.rs @@ -0,0 +1,28 @@ +/// one true function for returning the conduwuit version with the necessary +/// CONDUWUIT_VERSION_EXTRA env variables used if specified +/// +/// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string +/// to include it in parenthesis after the SemVer version. A common value are +/// git commit hashes. +#[must_use] +pub fn conduwuit() -> String { + match option_env!("CONDUWUIT_VERSION_EXTRA") { + Some(extra) => { + if extra.is_empty() { + env!("CARGO_PKG_VERSION").to_owned() + } else { + format!("{} ({})", env!("CARGO_PKG_VERSION"), extra) + } + }, + None => match option_env!("CONDUIT_VERSION_EXTRA") { + Some(extra) => { + if extra.is_empty() { + env!("CARGO_PKG_VERSION").to_owned() + } else { + format!("{} ({})", env!("CARGO_PKG_VERSION"), extra) + } + }, + None => env!("CARGO_PKG_VERSION").to_owned(), + }, + } +} diff --git a/src/main/clap.rs b/src/main/clap.rs index 81a6da72..a2fb588e 100644 --- a/src/main/clap.rs +++ b/src/main/clap.rs @@ -3,11 +3,10 @@ use std::path::PathBuf; use clap::Parser; -use conduit_core::utils::conduwuit_version; /// Commandline arguments #[derive(Parser, Debug)] -#[clap(version = conduwuit_version(), about, long_about = None)] +#[clap(version = conduit::version::conduwuit(), about, long_about = None)] pub(crate) struct Args { #[arg(short, long)] /// Optional argument to the path of a conduwuit config TOML file diff --git a/src/main/server.rs b/src/main/server.rs index c3f6a928..2395469b 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -1,11 +1,10 @@ use std::sync::Arc; use conduit::{ - conduwuit_version, config::Config, info, log::{LogLevelReloadHandles, ReloadHandle}, - utils::maximize_fd_limit, + utils::sys::maximize_fd_limit, Error, Result, }; use tokio::runtime; @@ -43,7 +42,7 @@ impl Server { database_path = ?config.database_path, log_levels = %config.log, "{}", - conduwuit_version(), + conduit::version::conduwuit(), ); Ok(Arc::new(Server { diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs index 82747ae7..33f6d85f 100644 --- a/src/service/globals/client.rs +++ b/src/service/globals/client.rs @@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration}; use reqwest::redirect; -use crate::{service::globals::resolver, utils::conduwuit_version, Config, Result}; +use crate::{service::globals::resolver, Config, Result}; pub struct Client { pub default: reqwest::Client, @@ -87,7 +87,7 @@ impl Client { } fn base(config: &Config) -> Result { - let version = conduwuit_version(); + let version = conduit::version::conduwuit(); let user_agent = format!("Conduwuit/{version}");