Hot-Reloading Refactor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-09 15:59:08 -07:00 committed by June 🍓🦴
parent ae1a4fd283
commit 6c1434c165
212 changed files with 5679 additions and 4206 deletions

133
src/core/Cargo.toml Normal file
View file

@ -0,0 +1,133 @@
[package]
name = "conduit_core"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"gzip_compression",
"zstd_compression",
"brotli_compression",
"sentry_telemetry",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sqlite = [
"dep:rusqlite",
"dep:parking_lot",
"dep:thread_local",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
jemalloc_prof = [
"tikv-jemalloc-sys/profiling",
]
hardened_malloc = [
"dep:hardened_malloc-rs"
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
gzip_compression = [
"reqwest/gzip",
]
brotli_compression = [
"reqwest/brotli",
]
perf_measurements = []
sentry_telemetry = []
mods = [
"dep:libloading"
]
panic_trap = []
[dependencies]
async-trait.workspace = true
axum-server.workspace = true
axum.workspace = true
base64.workspace = true
bytes.workspace = true
clap.workspace = true
cyborgtime.workspace = true
either.workspace = true
figment.workspace = true
futures-util.workspace = true
http-body-util.workspace = true
http.workspace = true
image.workspace = true
infer.workspace = true
ipaddress.workspace = true
itertools.workspace = true
libloading.workspace = true
libloading.optional = true
log.workspace = true
lru-cache.workspace = true
parking_lot.optional = true
parking_lot.workspace = true
rand.workspace = true
regex.workspace = true
reqwest.workspace = true
ring.workspace = true
ruma.workspace = true
rusqlite.optional = true
rusqlite.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
sanitize-filename.workspace = true
serde_json.workspace = true
serde_regex.workspace = true
serde.workspace = true
serde_yaml.workspace = true
sha-1.workspace = true
thiserror.workspace = true
thread_local.optional = true
thread_local.workspace = true
tikv-jemallocator.optional = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl.optional = true
tikv-jemalloc-ctl.workspace = true
tikv-jemalloc-sys.optional = true
tikv-jemalloc-sys.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
zstd.optional = true
zstd.workspace = true
[target.'cfg(unix)'.dependencies]
nix.workspace = true
[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies]
hardened_malloc-rs.workspace = true
hardened_malloc-rs.optional = true
[lints]
workspace = true

View file

@ -0,0 +1,9 @@
//! Default allocator with no special features
/// Extended statistics are not provided for the default allocator; the result
/// is always an empty string.
#[must_use]
pub fn memory_stats() -> String {
    String::new()
}
/// Usage figures are not provided for the default allocator; the result is
/// always an empty string.
#[must_use]
pub fn memory_usage() -> String {
    String::new()
}

View file

@ -0,0 +1,10 @@
// Registers hardened_malloc as the process-wide allocator whenever this module
// is compiled in.
#[global_allocator]
static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc;
/// Placeholder for allocator usage figures; currently always an empty string.
#[must_use]
pub fn memory_usage() -> String {
    // TODO: get usage
    String::new()
}
/// Returns a fixed notice explaining that extended statistics are not
/// available with this allocator.
#[must_use]
pub fn memory_stats() -> String {
    String::from("Extended statistics are not available from hardened_malloc.")
}

52
src/core/alloc/je.rs Normal file
View file

@ -0,0 +1,52 @@
use std::ffi::{c_char, c_void};
use tikv_jemalloc_ctl as mallctl;
use tikv_jemalloc_sys as ffi;
use tikv_jemallocator as jemalloc;
// Registers jemalloc as the process-wide allocator whenever this module is
// compiled in.
#[global_allocator]
static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
/// Formats jemalloc's summary counters (allocated, active, mapped, metadata,
/// resident, retained) as one line each, in MiB with two decimals.
///
/// A counter that cannot be read is reported as zero.
#[must_use]
pub fn memory_usage() -> String {
    use mallctl::stats;

    // Bytes to MiB; two successive divisions, matching the reported rounding.
    let to_mib = |bytes: usize| bytes as f64 / 1024.0 / 1024.0;

    let allocated = to_mib(stats::allocated::read().unwrap_or_default());
    let active = to_mib(stats::active::read().unwrap_or_default());
    let mapped = to_mib(stats::mapped::read().unwrap_or_default());
    let metadata = to_mib(stats::metadata::read().unwrap_or_default());
    let resident = to_mib(stats::resident::read().unwrap_or_default());
    let retained = to_mib(stats::retained::read().unwrap_or_default());

    format!(
        " allocated: {allocated:.2} MiB\n active: {active:.2} MiB\n mapped: {mapped:.2} MiB\n metadata: {metadata:.2} \
         MiB\n resident: {resident:.2} MiB\n retained: {retained:.2} MiB\n "
    )
}
/// Collects the text emitted by jemalloc's `malloc_stats_print()` and wraps it
/// in `<pre><code>` markup.
///
/// The collected text is capped at `MAX_LENGTH` bytes (rounded down to a
/// character boundary) before being wrapped.
#[must_use]
pub fn memory_stats() -> String {
    const MAX_LENGTH: usize = 65536 - 4096;

    let opts_s = "d";
    let mut str = String::new();
    let opaque = std::ptr::from_mut(&mut str).cast::<c_void>();

    // Keep the CString owned by this frame. `into_raw()` would give ownership
    // away without ever reclaiming it, leaking the allocation on every call.
    let opts = std::ffi::CString::new(opts_s).expect("cstring");
    let opts_p: *const c_char = opts.as_ptr();

    // SAFETY: calls malloc_stats_print() with our string instance which must remain
    // in this frame; `opts` also outlives the call so `opts_p` stays valid.
    // https://docs.rs/tikv-jemalloc-sys/latest/tikv_jemalloc_sys/fn.malloc_stats_print.html
    unsafe { ffi::malloc_stats_print(Some(malloc_stats_cb), opaque, opts_p) };

    // `String::truncate` panics when the cut is not on a char boundary, and the
    // callback's lossy UTF-8 conversion can insert multi-byte replacement
    // characters. Back off to the nearest boundary at or below the cap.
    let mut end = MAX_LENGTH.min(str.len());
    while !str.is_char_boundary(end) {
        end -= 1;
    }
    str.truncate(end);

    format!("<pre><code>{str}</code></pre>")
}
/// Write callback handed to `malloc_stats_print()`: appends `msg` to the
/// `String` that `opaque` points to.
///
/// Null pointers are ignored rather than dereferenced, because a panic must not
/// unwind out of an `extern "C"` function.
extern "C" fn malloc_stats_cb(opaque: *mut c_void, msg: *const c_char) {
    // SAFETY: we have to trust the opaque points to our String; `as_mut()`
    // filters out the null case for us.
    let Some(res) = (unsafe { opaque.cast::<String>().as_mut() }) else {
        return;
    };

    if msg.is_null() {
        return;
    }

    // SAFETY: checked non-null above; we have to trust the string is null
    // terminated.
    let msg = unsafe { std::ffi::CStr::from_ptr(msg) };

    // Invalid UTF-8 is replaced rather than rejected.
    res.push_str(&String::from_utf8_lossy(msg.to_bytes()));
}

25
src/core/alloc/mod.rs Normal file
View file

@ -0,0 +1,25 @@
//! Integration with allocators
// jemalloc
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))]
pub mod je;
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))]
pub use je::{memory_stats, memory_usage};

// hardened_malloc
#[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))]
pub mod hardened;
#[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))]
pub use hardened::{memory_stats, memory_usage};

// default, enabled whenever neither module above is compiled in. This is the
// exact complement of the two conditions above, so it covers "no allocator
// feature", "both allocator features", and an allocator feature that is enabled
// on a target where its module is excluded (e.g. MSVC, or hardened_malloc off
// Linux). `memory_stats`/`memory_usage` are therefore always exported.
#[cfg(not(any(
    all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")),
    all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")),
)))]
pub mod default;
#[cfg(not(any(
    all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")),
    all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")),
)))]
pub use default::{memory_stats, memory_usage};

170
src/core/config/check.rs Normal file
View file

@ -0,0 +1,170 @@
#[cfg(unix)]
use std::path::Path; // not unix specific, just only for UNIX sockets stuff and *nix container checks
use tracing::{debug, error, info, warn};
use crate::{error::Error, Config};
/// Validates a loaded configuration, logging advisories and rejecting settings
/// that cannot be used.
///
/// Deprecated and unknown keys are reported first. Every later check either
/// logs a message and continues, or returns early with the first fatal problem
/// found.
///
/// # Errors
///
/// Returns the error built by `Error::bad_config` for the first fatal check
/// that fails: Sentry enabled without an endpoint, a UNIX socket path on a
/// non-unix build, `rocksdb_max_log_files` of 0 with RocksDB, the placeholder
/// server name in a non-debug build, an empty registration token, a too-small
/// `max_request_size`, an unparsable entry in `ip_range_denylist`, open
/// registration without the explicit opt-in, or outgoing presence without local
/// presence.
pub fn check(config: &Config) -> Result<(), Error> {
    config.warn_deprecated();
    config.warn_unknown_key();

    if config.sentry && config.sentry_endpoint.is_none() {
        return Err(Error::bad_config("Sentry cannot be enabled without an endpoint set"));
    }

    if cfg!(feature = "hardened_malloc") && cfg!(feature = "jemalloc") {
        warn!(
            "hardened_malloc and jemalloc were built together, this causes neither to be used. Conduwuit will still \
             function, but consider rebuilding and pick one as this is now no-op."
        );
    }

    if config.unix_socket_path.is_some() && !cfg!(unix) {
        return Err(Error::bad_config(
            "UNIX socket support is only available on *nix platforms. Please remove \"unix_socket_path\" from your \
             config.",
        ));
    }

    if config.address.is_loopback() && cfg!(unix) {
        debug!(
            "Found loopback listening address {}, running checks if we're in a container.",
            config.address
        );

        // OpenVZ exposes /proc/vz on both host and guest but /proc/bc only on
        // the host, so "vz present, bc absent" identifies a guest. The marker
        // file is `/proc/bc`; checking a non-existent `/proc/bz` would also
        // flag OpenVZ hosts.
        #[cfg(unix)]
        if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bc").exists()
        /* Host */
        {
            error!(
                "You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \
                 OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \
                 will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.",
                config.address
            );
        }

        #[cfg(unix)]
        if Path::new("/.dockerenv").exists() {
            error!(
                "You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \
                 reverse proxy on the host and require communication to conduwuit in the Docker container via \
                 NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
                 you can ignore.",
                config.address
            );
        }

        #[cfg(unix)]
        if Path::new("/run/.containerenv").exists() {
            error!(
                "You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \
                 reverse proxy on the host and require communication to conduwuit in the Podman container via \
                 NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
                 you can ignore.",
                config.address
            );
        }
    }

    // rocksdb does not allow max_log_files to be 0
    if config.rocksdb_max_log_files == 0 && cfg!(feature = "rocksdb") {
        return Err(Error::bad_config(
            "When using RocksDB, rocksdb_max_log_files cannot be 0. Please set a value at least 1.",
        ));
    }

    // yeah, unless the user built a debug build hopefully for local testing only
    if config.server_name == "your.server.name" && !cfg!(debug_assertions) {
        return Err(Error::bad_config(
            "You must specify a valid server name for production usage of conduwuit.",
        ));
    }

    if cfg!(debug_assertions) {
        info!("Note: conduwuit was built without optimisations (i.e. debug build)");
    }

    // check if the user specified a registration token as `""`; compared by
    // borrowing so no temporary String is allocated
    if config.registration_token.as_deref() == Some("") {
        return Err(Error::bad_config("Registration token was specified but is empty (\"\")"));
    }

    if config.max_request_size < 5_120_000 {
        return Err(Error::bad_config("Max request size is less than 5MB. Please increase it."));
    }

    // check if user specified valid IP CIDR ranges on startup
    for cidr in &config.ip_range_denylist {
        if let Err(e) = ipaddress::IPAddress::parse(cidr) {
            error!("Error parsing specified IP CIDR range from string: {e}");
            return Err(Error::bad_config("Error parsing specified IP CIDR ranges from strings"));
        }
    }

    // NOTE: the two multi-line literals below contain raw line breaks, so their
    // continuation lines must stay at column 0 to keep the emitted text
    // unchanged.
    if config.allow_registration
        && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
        && config.registration_token.is_none()
    {
        return Err(Error::bad_config(
            "!! You have `allow_registration` enabled without a token configured in your config which means you are \
allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n
If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n
For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \
want, please set the following config option to true:
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`",
        ));
    }

    if config.allow_registration
        && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
        && config.registration_token.is_none()
    {
        warn!(
            "Open registration is enabled via setting \
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \
true without a registration token configured. You are expected to be aware of the risks now.\n
If this is not the desired behaviour, please set a registration token."
        );
    }

    if config.allow_outgoing_presence && !config.allow_local_presence {
        return Err(Error::bad_config(
            "Outgoing presence requires allowing local presence. Please enable \"allow_local_presence\".",
        ));
    }

    if config
        .url_preview_domain_contains_allowlist
        .contains(&"*".to_owned())
    {
        warn!(
            "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \
             This opens up significant attack surface to your server. You are expected to be aware of the risks by \
             doing this."
        );
    }

    if config
        .url_preview_domain_explicit_allowlist
        .contains(&"*".to_owned())
    {
        warn!(
            "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \
             This opens up significant attack surface to your server. You are expected to be aware of the risks by \
             doing this."
        );
    }

    if config
        .url_preview_url_contains_allowlist
        .contains(&"*".to_owned())
    {
        warn!(
            "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \
             opens up significant attack surface to your server. You are expected to be aware of the risks by doing \
             this."
        );
    }

    Ok(())
}

1072
src/core/config/mod.rs Normal file

File diff suppressed because it is too large Load diff

148
src/core/config/proxy.rs Normal file
View file

@ -0,0 +1,148 @@
use reqwest::{Proxy, Url};
use serde::Deserialize;
use crate::Result;
/// Outbound proxy selection, deserialized from the configuration.
///
/// ## Examples:
/// - No proxy (default):
/// ```toml
/// proxy = "none"
/// ```
/// - Global proxy
/// ```toml
/// [global.proxy]
/// global = { url = "socks5h://localhost:9050" }
/// ```
/// - Proxy some domains
/// ```toml
/// [global.proxy]
/// [[global.proxy.by_domain]]
/// url = "socks5h://localhost:9050"
/// include = ["*.onion", "matrix.myspecial.onion"]
/// exclude = ["*.myspecial.onion"]
/// ```
/// ## Include vs. Exclude
/// If include is an empty list, it is assumed to be `["*"]`.
///
/// If a domain matches both the exclude and include list, the proxy will only
/// be used if it was included because of a more specific rule than it was
/// excluded. In the above example, the proxy would be used for
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProxyConfig {
/// No proxy is used.
#[default]
None,
/// A single proxy used for every request.
Global {
/// Proxy URL, parsed from its string form.
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url,
},
/// Per-domain rules; the first rule that applies to a URL supplies the proxy.
ByDomain(Vec<PartialProxyConfig>),
}
impl ProxyConfig {
    /// Builds the `reqwest` proxy described by this configuration.
    ///
    /// Yields `Ok(None)` when no proxy is configured. A global proxy applies to
    /// all requests; per-domain rules are consulted in order and the first one
    /// that applies to the request URL wins.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Proxy::all` for the global variant.
    pub fn to_proxy(&self) -> Result<Option<Proxy>> {
        let proxy = match self {
            Self::None => return Ok(None),
            Self::Global {
                url,
            } => Proxy::all(url.clone())?,
            Self::ByDomain(rules) => {
                // The closure must own its rule list.
                let rules = rules.clone();
                Proxy::custom(move |target| {
                    // first matching proxy
                    rules.iter().find_map(|rule| rule.for_url(target)).cloned()
                })
            },
        };

        Ok(Some(proxy))
    }
}
/// One per-domain proxy rule: a proxy URL plus the domain patterns that opt
/// requests in to or out of it.
#[derive(Clone, Debug, Deserialize)]
pub struct PartialProxyConfig {
/// Proxy URL, parsed from its string form.
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url,
/// Domain patterns the proxy applies to; defaults to an empty list.
#[serde(default)]
include: Vec<WildCardedDomain>,
/// Domain patterns the proxy must not be used for; defaults to an empty list.
#[serde(default)]
exclude: Vec<WildCardedDomain>,
}
impl PartialProxyConfig {
#[must_use]
pub fn for_url(&self, url: &Url) -> Option<&Url> {
let domain = url.domain()?;
let mut included_because = None; // most specific reason it was included
let mut excluded_because = None; // most specific reason it was excluded
if self.include.is_empty() {
// treat empty include list as `*`
included_because = Some(&WildCardedDomain::WildCard);
}
for wc_domain in &self.include {
if wc_domain.matches(domain) {
match included_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => included_because = Some(wc_domain),
}
}
}
for wc_domain in &self.exclude {
if wc_domain.matches(domain) {
match excluded_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => excluded_because = Some(wc_domain),
}
}
}
match (included_because, excluded_because) {
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */
// than excluded
(Some(_), None) => Some(&self.url),
_ => None,
}
}
}
/// A domain name, that optionally allows a * as its first subdomain.
#[derive(Clone, Debug)]
enum WildCardedDomain {
/// Matches every domain.
WildCard,
/// Matches any domain ending with the stored suffix.
WildCarded(String),
/// Matches only the stored domain.
Exact(String),
}
impl WildCardedDomain {
fn matches(&self, domain: &str) -> bool {
match self {
WildCardedDomain::WildCard => true,
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
WildCardedDomain::Exact(d) => domain == d,
}
}
fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b),
_ => false,
}
}
}
impl std::str::FromStr for WildCardedDomain {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// maybe do some domain validation?
Ok(if s.starts_with("*.") {
WildCardedDomain::WildCarded(s[1..].to_owned())
} else if s == "*" {
WildCardedDomain::WildCarded(String::new())
} else {
WildCardedDomain::Exact(s.to_owned())
})
}
}
impl<'de> Deserialize<'de> for WildCardedDomain {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
crate::utils::deserialize_from_str(deserializer)
}
}

78
src/core/debug.rs Normal file
View file

@ -0,0 +1,78 @@
#![allow(dead_code)] // this is a developer's toolbox
use std::{panic, panic::PanicInfo};
/// Emits an event at the requested level in debug-mode (debug-assertions
/// enabled); in release-mode the event is emitted at DEBUG level instead, where
/// it may be elided.
///
/// Enabling the feature 'dev_release_log_level' makes a debug-mode build behave
/// like release-mode here. Both conditions are `cfg!` checks, which the
/// compiler evaluates where the macro is expanded.
#[macro_export]
macro_rules! debug_event {
    ( $level:expr, $($args:tt)+ ) => {
        if cfg!(all(debug_assertions, not(feature = "dev_release_log_level"))) {
            tracing::event!( $level, $($args)+ );
        } else {
            tracing::debug!( $($args)+ );
        }
    }
}
/// ERROR-level message in debug-mode (debug-assertions enabled); demoted to
/// DEBUG level in release-mode, where it may be elided.
#[macro_export]
macro_rules! debug_error {
    ( $($args:tt)+ ) => {
        $crate::debug_event!(tracing::Level::ERROR, $($args)+ );
    }
}
/// WARN-level message in debug-mode (debug-assertions enabled); demoted to
/// DEBUG level in release-mode, where it may be elided.
#[macro_export]
macro_rules! debug_warn {
    ( $($args:tt)+ ) => {
        $crate::debug_event!(tracing::Level::WARN, $($args)+ );
    }
}
/// INFO-level message in debug-mode (debug-assertions enabled); demoted to
/// DEBUG level in release-mode, where it may be elided.
#[macro_export]
macro_rules! debug_info {
    ( $($args:tt)+ ) => {
        $crate::debug_event!(tracing::Level::INFO, $($args)+ );
    }
}
/// Installs a panic hook that triggers the breakpoint trap and then hands the
/// panic to whichever hook was registered before this call.
pub fn set_panic_trap() {
    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(move |info| panic_handler(info, &previous_hook)));
}
/// Panic hook body: trap first, then delegate to the `next` hook in the chain.
#[inline(always)]
fn panic_handler(info: &PanicInfo<'_>, next: &dyn Fn(&PanicInfo<'_>)) {
    // Break before the next hook gets to run.
    trap();

    next(info);
}
/// Triggers a debugger breakpoint where an implementation is available.
///
/// With the `core_intrinsics` cfg set, the breakpoint intrinsic is used;
/// otherwise, on x86_64, an `int3` instruction is emitted. When neither
/// condition holds the function body is empty.
#[inline(always)]
#[allow(unexpected_cfgs)]
pub fn trap() {
#[cfg(core_intrinsics)]
//SAFETY: embeds llvm intrinsic for hardware breakpoint
unsafe {
std::intrinsics::breakpoint();
}
#[cfg(all(not(core_intrinsics), target_arch = "x86_64"))]
//SAFETY: embeds instruction for hardware breakpoint
unsafe {
std::arch::asm!("int3");
}
}

220
src/core/error.rs Normal file
View file

@ -0,0 +1,220 @@
use std::{convert::Infallible, fmt};
use axum::response::{IntoResponse, Response};
use bytes::BytesMut;
use http::StatusCode;
use http_body_util::Full;
use ruma::{
api::{
client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::{UiaaInfo, UiaaResponse},
},
OutgoingResponse,
},
OwnedServerName,
};
use thiserror::Error;
use tracing::error;
use ErrorKind::{
Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, ThreepidAuthFailed, ThreepidDenied,
TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, WrongRoomKeysVersion,
};
/// Result alias whose error type defaults to this crate's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Error)]
pub enum Error {
#[cfg(feature = "sqlite")]
#[error("There was a problem with the connection to the sqlite database: {source}")]
Sqlite {
#[from]
source: rusqlite::Error,
},
#[cfg(feature = "rocksdb")]
#[error("There was a problem with the connection to the rocksdb database: {source}")]
RocksDb {
#[from]
source: rust_rocksdb::Error,
},
#[error("Could not generate an image.")]
Image {
#[from]
source: image::error::ImageError,
},
#[error("Could not connect to server: {source}")]
Reqwest {
#[from]
source: reqwest::Error,
},
#[error("Could build regular expression: {source}")]
Regex {
#[from]
source: regex::Error,
},
#[error("{0}")]
Federation(OwnedServerName, RumaError),
#[error("Could not do this io: {source}")]
Io {
#[from]
source: std::io::Error,
},
#[error("There was a problem with your configuration file: {0}")]
BadConfig(String),
#[error("{0}")]
BadServerResponse(&'static str),
#[error("{0}")]
/// Don't create this directly. Use Error::bad_database instead.
BadDatabase(&'static str),
#[error("uiaa")]
Uiaa(UiaaInfo),
#[error("{0}: {1}")]
BadRequest(ErrorKind, &'static str),
#[error("{0}")]
Conflict(&'static str), // This is only needed for when a room alias already exists
#[error("{0}")]
Extension(#[from] axum::extract::rejection::ExtensionRejection),
#[error("{0}")]
Path(#[from] axum::extract::rejection::PathRejection),
#[error("from {0}: {1}")]
Redaction(OwnedServerName, ruma::canonical_json::RedactionError),
#[error("{0} in {1}")]
InconsistentRoomState(&'static str, ruma::OwnedRoomId),
#[error("{0}")]
AdminCommand(&'static str),
#[error("{0}")]
Err(String),
}
impl Error {
pub fn bad_database(message: &'static str) -> Self {
error!("BadDatabase: {}", message);
Self::BadDatabase(message)
}
pub fn bad_config(message: &str) -> Self {
error!("BadConfig: {}", message);
Self::BadConfig(message.to_owned())
}
/// Returns the Matrix error code / error kind
pub fn error_code(&self) -> ErrorKind {
if let Self::Federation(_, error) = self {
return error.error_kind().unwrap_or_else(|| &Unknown).clone();
}
match self {
Self::BadRequest(kind, _) => kind.clone(),
_ => Unknown,
}
}
/// Sanitizes public-facing errors that can leak sensitive information.
pub fn sanitized_error(&self) -> String {
let db_error = String::from("Database or I/O error occurred.");
match self {
#[cfg(feature = "sqlite")]
Self::Sqlite {
..
} => db_error,
#[cfg(feature = "rocksdb")]
Self::RocksDb {
..
} => db_error,
Self::Io {
..
} => db_error,
_ => self.to_string(),
}
}
}
impl From<Infallible> for Error {
    /// `Infallible` has no values, so the empty match covers every case.
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
impl fmt::Debug for Error {
    /// Debug output is the same text as the `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}
/// Newtype around a response value; any value converts into it by wrapping.
#[derive(Clone)]
pub struct RumaResponse<T>(pub T);

impl<T> From<T> for RumaResponse<T> {
    fn from(inner: T) -> Self {
        RumaResponse(inner)
    }
}
impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { t.to_response() }
}
impl Error {
    /// Converts this error into the response sent to the client.
    ///
    /// UIAA errors become an auth response. Federation errors are forwarded
    /// with a message naming the origin server. Everything else becomes a
    /// standard Matrix error whose HTTP status is derived from the error kind
    /// (bad requests), is 409 for conflicts, or is 500 otherwise.
    pub fn to_response(&self) -> RumaResponse<UiaaResponse> {
        match self {
            Self::Uiaa(uiaainfo) => {
                return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone()));
            },
            Self::Federation(origin, error) => {
                let mut error = error.clone();
                // The replacement body is fully built from the old error
                // before it is assigned.
                error.body = ErrorBody::Standard {
                    kind: error.error_kind().unwrap_or_else(|| &Unknown).clone(),
                    message: format!("Answer from {origin}: {error}"),
                };
                return RumaResponse(UiaaResponse::MatrixError(error));
            },
            _ => {},
        }

        let message = self.to_string();
        let (kind, status_code) = match self {
            Self::BadRequest(kind, _) => {
                let status = match kind {
                    WrongRoomKeysVersion {
                        ..
                    }
                    | Forbidden {
                        ..
                    }
                    | GuestAccessForbidden
                    | ThreepidAuthFailed
                    | UserDeactivated
                    | ThreepidDenied => StatusCode::FORBIDDEN,
                    Unauthorized
                    | UnknownToken {
                        ..
                    }
                    | MissingToken => StatusCode::UNAUTHORIZED,
                    NotFound | Unrecognized => StatusCode::NOT_FOUND,
                    LimitExceeded {
                        ..
                    } => StatusCode::TOO_MANY_REQUESTS,
                    TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
                    _ => StatusCode::BAD_REQUEST,
                };
                (kind.clone(), status)
            },
            Self::Conflict(_) => (Unknown, StatusCode::CONFLICT),
            _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR),
        };

        let body = ErrorBody::Standard {
            kind,
            message,
        };
        RumaResponse(UiaaResponse::MatrixError(RumaError {
            body,
            status_code,
        }))
    }
}
impl ::axum::response::IntoResponse for Error {
    /// Renders the error by first converting it with [`Error::to_response`].
    fn into_response(self) -> ::axum::response::Response {
        let response = self.to_response();
        response.into_response()
    }
}
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
    /// Serializes the wrapped response into an HTTP response; a serialization
    /// failure yields a bare 500 response.
    fn into_response(self) -> Response {
        let Ok(res) = self.0.try_into_http_response::<BytesMut>() else {
            return StatusCode::INTERNAL_SERVER_ERROR.into_response();
        };

        res.map(BytesMut::freeze).map(Full::new).into_response()
    }
}

79
src/core/log.rs Normal file
View file

@ -0,0 +1,79 @@
use std::sync::Arc;
use tracing_subscriber::{reload, EnvFilter};
/// We need to store a reload::Handle value, but can't name its type explicitly
/// because the S type parameter depends on the subscriber's previous layers. In
/// our case, this includes unnameable 'impl Trait' types.
///
/// This is fixed[1] in the unreleased tracing-subscriber from the master
/// branch, which removes the S parameter. Unfortunately can't use it without
/// pulling in a version of tracing that's incompatible with the rest of our
/// deps.
///
/// To work around this, we define a trait without the S parameter that forwards
/// to the reload::Handle::reload method, and then store the handle as a trait
/// object.
///
/// [1]: <https://github.com/tokio-rs/tracing/pull/1035/commits/8a87ea52425098d3ef8f56d92358c2f6c144a28f>
pub trait ReloadHandle<L> {
/// Replaces the value behind the handle with `new_value`.
fn reload(&self, new_value: L) -> Result<(), reload::Error>;
}
// Forwards to the inherent `reload::Handle::reload`. The call is spelled with
// the full path so it is not confused with this trait method of the same name.
impl<L, S> ReloadHandle<L> for reload::Handle<L, S> {
fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) }
}
/// Shared state behind [`LogLevelReloadHandles`].
struct LogLevelReloadHandlesInner {
/// Type-erased reload handles, all reloaded together.
handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>,
}
/// Wrapper to allow reloading the filter on several
/// [`tracing_subscriber::reload::Handle`]s at once, with the same value.
///
/// Clones share the same set of handles through an `Arc`.
#[derive(Clone)]
pub struct LogLevelReloadHandles {
inner: Arc<LogLevelReloadHandlesInner>,
}
impl LogLevelReloadHandles {
    /// Wraps the given handles so they can be reloaded as a group.
    #[must_use]
    pub fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> LogLevelReloadHandles {
        let inner = Arc::new(LogLevelReloadHandlesInner {
            handles,
        });

        Self {
            inner,
        }
    }

    /// Reloads every handle, in order, with a clone of `new_value`.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first reload error; handles after the failing
    /// one are left untouched.
    pub fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> {
        self.inner
            .handles
            .iter()
            .try_for_each(|handle| handle.reload(new_value.clone()))
    }
}
/// Forwards its arguments to `tracing::error!`.
#[macro_export]
macro_rules! error {
    ( $($args:tt)+ ) => { tracing::error!( $($args)+ ); }
}
/// Forwards its arguments to `tracing::warn!`.
#[macro_export]
macro_rules! warn {
    ( $($args:tt)+ ) => { tracing::warn!( $($args)+ ); }
}
/// Forwards its arguments to `tracing::info!`.
#[macro_export]
macro_rules! info {
    ( $($args:tt)+ ) => { tracing::info!( $($args)+ ); }
}
/// Forwards its arguments to `tracing::debug!`.
#[macro_export]
macro_rules! debug {
    ( $($args:tt)+ ) => { tracing::debug!( $($args)+ ); }
}
/// Forwards its arguments to `tracing::trace!`.
#[macro_export]
macro_rules! trace {
    ( $($args:tt)+ ) => { tracing::trace!( $($args)+ ); }
}

27
src/core/mod.rs Normal file
View file

@ -0,0 +1,27 @@
pub mod alloc;
pub mod config;
pub mod debug;
pub mod error;
pub mod log;
pub mod mods;
pub mod pducount;
pub mod server;
pub mod utils;
pub use config::Config;
pub use error::{Error, Result, RumaResponse};
pub use pducount::PduCount;
pub use server::Server;
pub use utils::conduwuit_version;
/// Stand-ins used when the `mods` feature is disabled: the module constructor
/// and destructor macros expand to nothing.
///
/// Both accept the same optional body block as the real macros and discard it,
/// so a call site that passes a body compiles the same way whether or not the
/// feature is enabled.
#[cfg(not(feature = "mods"))]
mod mods {
    #[macro_export]
    macro_rules! mod_ctor {
        ( $($body:block)? ) => {};
    }
    #[macro_export]
    macro_rules! mod_dtor {
        ( $($body:block)? ) => {};
    }
}

28
src/core/mods/canary.rs Normal file
View file

@ -0,0 +1,28 @@
use std::sync::atomic::{AtomicI32, Ordering};
// Memory ordering used for every access to the counter below.
const ORDERING: Ordering = Ordering::Relaxed;
// Balance counter: decremented by prepare(), incremented by report(); zero
// means every prepared unload has been matched by a reported destructor.
static STATIC_DTORS: AtomicI32 = AtomicI32::new(0);
/// Called by Module::unload() to announce that a module is about to be
/// unloaded and that its static destructor is expected to run. Decrements the
/// counter so the destructor's report() can bring it back to zero.
pub(crate) fn prepare() {
    let previous = STATIC_DTORS.fetch_sub(1, ORDERING);
    debug_assert!(previous <= 0, "STATIC_DTORS should not be greater than zero.");
}
/// Called by the static destructor of a module; increments the counter. This
/// call should only be found inside a mod_fini! macro. Do not call from
/// anywhere else.
#[inline(always)]
pub fn report() {
    STATIC_DTORS.fetch_add(1, ORDERING);
}
/// Called by Module::unload() after closing the module. Returns whether the
/// counter was balanced (zero), and resets it to zero either way so that one
/// stuck module does not skew the diagnosis of later, independent modules.
pub(crate) fn check_and_reset() -> bool {
    let previous = STATIC_DTORS.swap(0, ORDERING);
    previous == 0
}
/// Called by Module::unload() after unload to verify static destruction took
/// place; true when the counter is zero. A call to prepare() must be made prior
/// to Module::unload() and making this call.
#[allow(dead_code)]
pub(crate) fn check() -> bool {
    let balance = STATIC_DTORS.load(ORDERING);
    balance == 0
}

44
src/core/mods/macros.rs Normal file
View file

@ -0,0 +1,44 @@
/// Declares the module constructor: logs "Module loaded" via `debug_info!` and
/// then runs the optional body block.
#[macro_export]
macro_rules! mod_ctor {
    ( $($setup:block)? ) => {
        $crate::mod_init! {{
            $crate::debug_info!("Module loaded");
            $($setup)?
        }}
    }
}
/// Declares the module destructor: logs "Module unloading" via `debug_info!`,
/// runs the optional body block, and finally reports to the unload canary.
#[macro_export]
macro_rules! mod_dtor {
    ( $($teardown:block)? ) => {
        $crate::mod_fini! {{
            $crate::debug_info!("Module unloading");
            $($teardown)?
            $crate::mods::canary::report();
        }}
    }
}
/// Defines `_mod_init` with the given body and a `MOD_INIT` static pointing at
/// it. On unix-family targets the static is placed in the `.init_array` link
/// section and the function in `.text.startup`.
///
/// The generated item names are fixed, so the macro can be expanded at most
/// once per scope.
#[macro_export]
macro_rules! mod_init {
($body:block) => {
// `#[used]` keeps the otherwise-unreferenced static in the object file.
#[used]
#[cfg_attr(target_family = "unix", link_section = ".init_array")]
static MOD_INIT: extern "C" fn() = { _mod_init };
#[cfg_attr(target_family = "unix", link_section = ".text.startup")]
extern "C" fn _mod_init() -> () $body
};
}
/// Defines `_mod_fini` with the given body and a `MOD_FINI` static pointing at
/// it. On unix-family targets the static is placed in the `.fini_array` link
/// section and the function in `.text.startup`.
///
/// The generated item names are fixed, so the macro can be expanded at most
/// once per scope.
#[macro_export]
macro_rules! mod_fini {
($body:block) => {
// `#[used]` keeps the otherwise-unreferenced static in the object file.
#[used]
#[cfg_attr(target_family = "unix", link_section = ".fini_array")]
static MOD_FINI: extern "C" fn() = { _mod_fini };
#[cfg_attr(target_family = "unix", link_section = ".text.startup")]
extern "C" fn _mod_fini() -> () $body
};
}

11
src/core/mods/mod.rs Normal file
View file

@ -0,0 +1,11 @@
#![cfg(feature = "mods")]
pub(crate) use libloading::os::unix::{Library, Symbol};
pub mod canary;
pub mod macros;
pub mod module;
pub mod new;
pub mod path;
pub use module::Module;

74
src/core/mods/module.rs Normal file
View file

@ -0,0 +1,74 @@
use std::{
ffi::{CString, OsString},
time::SystemTime,
};
use super::{canary, new, path, Library, Symbol};
use crate::{error, Result};
/// A dynamically loaded module together with where and when it was loaded.
pub struct Module {
/// Library handle; `None` once the module has been closed.
handle: Option<Library>,
/// Time at which this instance was constructed.
loaded: SystemTime,
/// Filesystem path the library was loaded from.
path: OsString,
}
impl Module {
pub fn from_name(name: &str) -> Result<Self> { Self::from_path(path::from_name(name)?) }
pub fn from_path(path: OsString) -> Result<Self> {
Ok(Self {
handle: Some(new::from_path(&path)?),
loaded: SystemTime::now(),
path,
})
}
pub fn unload(&mut self) {
canary::prepare();
self.close();
if !canary::check_and_reset() {
let name = self.name().expect("Module is named");
error!("Module {name:?} is stuck and failed to unload.");
}
}
pub(crate) fn close(&mut self) {
if let Some(handle) = self.handle.take() {
handle.close().expect("Module handle closed");
}
}
pub fn get<Prototype>(&self, name: &str) -> Result<Symbol<Prototype>> {
let cname = CString::new(name.to_owned()).expect("terminated string from provided name");
let handle = self
.handle
.as_ref()
.expect("backing library loaded by this instance");
// SAFETY: Calls dlsym(3) on unix platforms. This might not have to be unsafe
// if wrapped in libloading with_dlerror().
let sym = unsafe { handle.get::<Prototype>(cname.as_bytes()) };
let sym = sym.expect("symbol found; binding successful");
Ok(sym)
}
pub fn deleted(&self) -> Result<bool> {
let mtime = path::mtime(self.path())?;
let res = mtime > self.loaded;
Ok(res)
}
pub fn name(&self) -> Result<String> { path::to_name(self.path()) }
#[must_use]
pub fn path(&self) -> &OsString { &self.path }
}
impl Drop for Module {
    /// Unloads the module on drop unless it has already been closed.
    fn drop(&mut self) {
        if self.handle.is_none() {
            return;
        }

        self.unload();
    }
}

23
src/core/mods/new.rs Normal file
View file

@ -0,0 +1,23 @@
use std::ffi::OsStr;
use super::{path, Library};
use crate::{Error, Result};
// Flags passed to `Library::open`: RTLD_LAZY combined with RTLD_GLOBAL.
const OPEN_FLAGS: i32 = libloading::os::unix::RTLD_LAZY | libloading::os::unix::RTLD_GLOBAL;
/// Opens the library for the module called `name`, resolving its path with
/// `path::from_name`.
pub fn from_name(name: &str) -> Result<Library> {
    from_path(&path::from_name(name)?)
}
pub fn from_path(path: &OsStr) -> Result<Library> {
//SAFETY: Calls dlopen(3) on unix platforms. This might not have to be unsafe
// if wrapped in with_dlerror.
let lib = unsafe { Library::open(Some(path), OPEN_FLAGS) };
if let Err(e) = lib {
let name = path::to_name(path)?;
return Err(Error::Err(format!("Loading module {name:?} failed: {e}")));
}
Ok(lib.expect("module loaded"))
}

40
src/core/mods/path.rs Normal file
View file

@ -0,0 +1,40 @@
use std::{
env::current_exe,
ffi::{OsStr, OsString},
path::{Path, PathBuf},
time::SystemTime,
};
use libloading::library_filename;
use crate::Result;
/// Builds the path of the library for module `name`: the platform library
/// file name (from `library_filename`) placed in the directory containing the
/// current executable. If the executable path has no parent, the bare file
/// name is returned.
pub fn from_name(name: &str) -> Result<OsString> {
    let exe_path = current_exe()?;
    let mod_file = library_filename(name);
    let mod_path: PathBuf = match exe_path.parent() {
        Some(exe_dir) => exe_dir.join(mod_file),
        None => PathBuf::from(mod_file),
    };

    Ok(mod_path.into_os_string())
}
/// Derives a module name from a library path: the file stem with a leading
/// `lib` removed when present.
///
/// # Panics
///
/// Panics if the path has no file stem or the stem is not valid UTF-8.
pub fn to_name(path: &OsStr) -> Result<String> {
    let stem = Path::new(path).file_stem().expect("path file stem");
    let stem = stem.to_str().expect("name string");
    let name = match stem.strip_prefix("lib") {
        Some(stripped) => stripped,
        None => stem,
    };

    Ok(name.to_owned())
}
/// Returns the modification time reported by the filesystem for `path`.
pub fn mtime(path: &OsStr) -> Result<SystemTime> {
    let modified = std::fs::metadata(path)?.modified()?;
    Ok(modified)
}

51
src/core/pducount.rs Normal file
View file

@ -0,0 +1,51 @@
use std::cmp::Ordering;
use ruma::api::client::error::ErrorKind;
use crate::{Error, Result};
/// A PDU counter that is either in the backfilled range or the normal range.
///
/// Ordering (see the `Ord` impl): every `Backfilled` value sorts before every
/// `Normal` value, and `Backfilled` values compare in reverse numeric order.
/// The string form produced by `stringify` is `-N` for `Backfilled(N)` and
/// `N` for `Normal(N)`.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
pub enum PduCount {
// Rendered with a leading `-` by `stringify`.
Backfilled(u64),
// Rendered as the plain number by `stringify`.
Normal(u64),
}
impl PduCount {
#[must_use]
pub fn min() -> Self { Self::Backfilled(u64::MAX) }
#[must_use]
pub fn max() -> Self { Self::Normal(u64::MAX) }
pub fn try_from_string(token: &str) -> Result<Self> {
if let Some(stripped_token) = token.strip_prefix('-') {
stripped_token.parse().map(PduCount::Backfilled)
} else {
token.parse().map(PduCount::Normal)
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token."))
}
#[must_use]
pub fn stringify(&self) -> String {
match self {
PduCount::Backfilled(x) => format!("-{x}"),
PduCount::Normal(x) => x.to_string(),
}
}
}
impl PartialOrd for PduCount {
    /// Always `Some`; delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ord::cmp(self, other)) }
}
impl Ord for PduCount {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o),
(PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s),
(PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater,
(PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less,
}
}
}

72
src/core/server.rs Normal file
View file

@ -0,0 +1,72 @@
use std::{
sync::{
atomic::{AtomicBool, AtomicU32},
Mutex,
},
time::SystemTime,
};
use tokio::runtime;
use crate::{config::Config, log::LogLevelReloadHandles};
/// Server runtime state; public portion
pub struct Server {
/// Server-wide configuration instance
pub config: Config,
/// Timestamp server was started; used for uptime.
pub started: SystemTime,
/// Reload/shutdown signal channel. Called from the signal handler or admin
/// command to initiate shutdown. Starts out as `None` in [`Server::new`].
pub shutdown: Mutex<Option<axum_server::Handle>>,
/// Reload/shutdown desired indicator; when false, shutdown is desired. This
/// is an observable used on shutdown and modifying is not recommended.
/// Initialised to `false`.
pub reload: AtomicBool,
/// Reload/shutdown pending indicator; server is shutting down. This is an
/// observable used on shutdown and should not be modified. Initialised to
/// `false`.
pub interrupt: AtomicBool,
/// Handle to the runtime. May be `None`; [`Server::runtime`] panics in that
/// case.
pub runtime: Option<runtime::Handle>,
/// Log level reload handles.
pub tracing_reload_handle: LogLevelReloadHandles,
/// TODO: move stats
// Request statistics counters; every one of them starts at zero in
// `Server::new`. Presumably incremented by the request-handling layer --
// confirm against the callers.
pub requests_spawn_active: AtomicU32,
pub requests_spawn_finished: AtomicU32,
pub requests_handle_active: AtomicU32,
pub requests_handle_finished: AtomicU32,
pub requests_panic: AtomicU32,
}
impl Server {
    /// Creates the server state: start time is taken now, no shutdown handle
    /// is installed, both flags are `false` and all request counters are zero.
    #[must_use]
    pub fn new(config: Config, runtime: Option<runtime::Handle>, tracing_reload_handle: LogLevelReloadHandles) -> Self {
        let started = SystemTime::now();
        let zeroed = || AtomicU32::new(0);
        let cleared = || AtomicBool::new(false);

        Self {
            config,
            started,
            shutdown: Mutex::new(None),
            reload: cleared(),
            interrupt: cleared(),
            runtime,
            tracing_reload_handle,
            requests_spawn_active: zeroed(),
            requests_spawn_finished: zeroed(),
            requests_handle_active: zeroed(),
            requests_handle_finished: zeroed(),
            requests_panic: zeroed(),
        }
    }

    /// Borrows the runtime handle.
    ///
    /// # Panics
    ///
    /// Panics if the server was constructed without a runtime handle.
    #[inline]
    pub fn runtime(&self) -> &runtime::Handle {
        let handle = self.runtime.as_ref();
        handle.expect("runtime handle available in Server")
    }
}

20
src/core/utils/clap.rs Normal file
View file

@ -0,0 +1,20 @@
//! Integration with `clap`
use std::path::PathBuf;
pub use clap::Parser;
use super::conduwuit_version;
// NOTE: the `///` doc comments on this struct and its fields are consumed by
// clap's derive macro as user-facing help text, so treat their wording as
// program output rather than ordinary documentation.
/// Commandline arguments
#[derive(Parser, Debug)]
#[clap(version = conduwuit_version(), about, long_about = None)]
pub struct Args {
// Exposed as `-c` / `--config` through `#[arg(short, long)]`; `None` when the
// flag is not given.
#[arg(short, long)]
/// Optional argument to the path of a conduwuit config TOML file
pub config: Option<PathBuf>,
}
/// Parse commandline arguments into structured data
#[must_use]
pub fn parse() -> Args { Args::parse() }

View file

@ -0,0 +1,130 @@
use infer::MatcherType;
use crate::debug_info;
// Content-Disposition type values returned by `content_disposition_type`.
const ATTACHMENT: &str = "attachment";
const INLINE: &str = "inline";
// Content-Type returned by `make_content_type` when detection fails.
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
// Content-Type returned by `make_content_type` for the special-cased SVG.
const IMAGE_SVG_XML: &str = "image/svg+xml";
/// Picks a Content-Disposition type of `attachment` or `inline` from the
/// *detected* contents of the file, using the magic-number matching of the
/// `infer` crate (basically libmagic without needing libmagic).
///
/// The `Content-Type` claimed by the client or remote server is not trusted
/// and does not influence the result; it is only recorded in the tracing span.
/// Image, audio, text and video content is served `inline` unless the
/// detected MIME type contains `xml`; everything else, including content that
/// cannot be identified, is an `attachment`.
///
/// TODO: add a "strict" function for comparing the Content-Type with what we
/// detected: `file_type.mime_type() != content_type`
#[must_use]
#[tracing::instrument(skip(buf))]
pub fn content_disposition_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
    let Some(file_type) = infer::get(buf) else {
        return ATTACHMENT;
    };

    let mime = file_type.mime_type();
    debug_info!("MIME type: {}", mime);

    match file_type.matcher_type() {
        MatcherType::Image | MatcherType::Audio | MatcherType::Text | MatcherType::Video
            if !mime.contains("xml") =>
        {
            INLINE
        },
        _ => ATTACHMENT,
    }
}
/// Replaces the claimed Content-Type with the one detected from the file
/// contents, falling back to `application/octet-stream` when detection fails.
///
/// SVG is special-cased due to the MIME type being classified as `text/xml` but
/// browsers need `image/svg+xml`: when the claimed type mentions `svg` and the
/// detected type mentions `xml`, `image/svg+xml` is returned.
#[must_use]
#[tracing::instrument(skip(buf))]
pub fn make_content_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
    let Some(file_type) = infer::get(buf) else {
        debug_info!("Failed to infer the file's contents");
        return APPLICATION_OCTET_STREAM;
    };

    let detected = file_type.mime_type();
    match content_type {
        Some(claimed_content_type) if claimed_content_type.contains("svg") && detected.contains("xml") => {
            IMAGE_SVG_XML
        },
        _ => detected,
    }
}
/// Sanitises a file name for use in Content-Disposition with the
/// `sanitize_filename` crate, using its default options except that
/// truncation is turned off.
#[tracing::instrument]
pub fn sanitise_filename(filename: String) -> String {
    sanitize_filename::sanitize_with_options(
        filename,
        sanitize_filename::Options {
            truncate: false,
            ..Default::default()
        },
    )
}
/// Builds the final Content-Disposition value, with or without a file name.
///
/// The file name is whatever follows the first `filename=` in the supplied
/// `content_disposition`, passed through [`sanitise_filename`].
///
/// if a non-empty filename results:
/// `Content-Disposition: attachment/inline; filename=filename.ext`
///
/// else: `Content-Disposition: attachment/inline`
#[tracing::instrument(skip(file))]
pub fn make_content_disposition(
    file: &[u8], content_type: &Option<String>, content_disposition: Option<String>,
) -> String {
    let filename = content_disposition
        .as_deref()
        .and_then(|header| header.split_once("filename="))
        .map(|(_, raw_name)| raw_name)
        .filter(|raw_name| !raw_name.is_empty())
        .map_or_else(String::new, |raw_name| sanitise_filename(raw_name.to_owned()));

    // Determined after the file name, as before, so spans are entered in the
    // same order.
    let disposition = content_disposition_type(file, content_type);
    if filename.is_empty() {
        // Content-Disposition: attachment/inline
        String::from(disposition)
    } else {
        // Content-Disposition: attachment/inline; filename=filename.ext
        format!("{}; filename={}", disposition, filename)
    }
}
#[cfg(test)]
mod tests {
    /// Checks what `sanitize_filename` makes of a string containing escapes,
    /// control characters, path traversal and non-ASCII text.
    #[test]
    fn string_sanitisation() {
        const SAMPLE: &str =
            "🏳this\\r\\n įs \r\\n ä \\r\nstrïng 🥴that\n\r ../../../../../../../may be\r\n malicious🏳";
        const SANITISED: &str = "🏳thisrn įs n ä rstrïng 🥴that ..............may be malicious🏳";

        let options = sanitize_filename::Options {
            windows: true,
            truncate: true,
            replacement: "",
        };
        let sanitised = sanitize_filename::sanitize_with_options(SAMPLE, options);

        // cargo test -- --nocapture
        println!("{}", SAMPLE);
        println!("{}", sanitised);
        println!("{:?}", SAMPLE);
        println!("{:?}", sanitised);

        assert_eq!(SANITISED, sanitised);
    }
}

22
src/core/utils/defer.rs Normal file
View file

@ -0,0 +1,22 @@
/// Runs `$body` when the enclosing scope is left.
///
/// The macro binds a guard value in the caller's scope; the guard's `Drop`
/// implementation invokes the block. Guards created by several `defer!`
/// invocations in one scope run in reverse order of creation, following the
/// usual drop order of local variables.
///
/// The helper type is declared inside a block expression rather than directly
/// in the caller's scope. Items are not covered by macro hygiene, so declaring
/// it in the caller's scope would make a second `defer!` in the same block
/// fail to compile with a duplicate definition.
#[macro_export]
macro_rules! defer {
    ($body:block) => {
        let _defer_ = {
            struct _Defer_<F>
            where
                F: FnMut(),
            {
                closure: F,
            }

            impl<F> Drop for _Defer_<F>
            where
                F: FnMut(),
            {
                fn drop(&mut self) { (self.closure)(); }
            }

            _Defer_ {
                closure: || $body,
            }
        };
    };
}

270
src/core/utils/mod.rs Normal file
View file

@ -0,0 +1,270 @@
use std::{
cmp,
cmp::Ordering,
fmt,
str::FromStr,
time::{SystemTime, UNIX_EPOCH},
};
use rand::prelude::*;
use ring::digest;
use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId};
use tracing::debug;
use crate::{Error, Result};
pub mod clap;
pub mod content_disposition;
pub mod defer;
/// Restricts `val` to the range `min..=max`: first raises it to at least
/// `min`, then lowers the result to at most `max`.
pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T {
    let raised = cmp::max(val, min);
    cmp::min(raised, max)
}
/// Milliseconds elapsed between the Unix epoch and the current system time.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
#[must_use]
#[allow(clippy::as_conversions)]
pub fn millis_since_unix_epoch() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).expect("time is valid");
    since_epoch.as_millis() as u64
}
/// Interprets `old` as a big-endian `u64`, adds one and returns the result as
/// eight big-endian bytes.
///
/// When `old` is absent or is not exactly eight bytes long the result is `1`:
/// counting starts at one, since 0 should return the first event in the db.
pub fn increment(old: Option<&[u8]>) -> Vec<u8> {
    let next = old
        .and_then(|bytes| <[u8; 8]>::try_from(bytes).ok())
        .map_or(1, |bytes| u64::from_be_bytes(bytes) + 1);

    next.to_be_bytes().to_vec()
}
/// Generates a new Ed25519 keypair, returned as an 8-character random string,
/// a `0xFF` separator byte and then the keypair bytes produced by ruma.
///
/// # Panics
///
/// Panics if keypair generation fails.
#[must_use]
pub fn generate_keypair() -> Vec<u8> {
    let mut value = random_string(8).into_bytes();
    value.push(0xFF);

    let keypair = ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)");
    value.extend_from_slice(&keypair);

    value
}
/// Reads a big-endian `u64` from `bytes`; fails unless the slice is exactly
/// eight bytes long.
pub fn u64_from_bytes(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> {
    <[u8; 8]>::try_from(bytes).map(u64::from_be_bytes)
}
/// Copies `bytes` into an owned `String`, failing when they are not valid
/// UTF-8.
pub fn string_from_bytes(bytes: &[u8]) -> Result<String, std::string::FromUtf8Error> {
    let owned: Vec<u8> = bytes.to_owned();
    String::from_utf8(owned)
}
/// Parses a `OwnedUserId` from bytes.
pub fn user_id_from_bytes(bytes: &[u8]) -> Result<OwnedUserId> {
OwnedUserId::try_from(
string_from_bytes(bytes).map_err(|_| Error::bad_database("Failed to parse string from bytes"))?,
)
.map_err(|_| Error::bad_database("Failed to parse user id from bytes"))
}
/// Returns a string of `length` random alphanumeric characters drawn from the
/// thread-local random number generator.
pub fn random_string(length: usize) -> String {
    let alphanumerics = thread_rng().sample_iter(&rand::distributions::Alphanumeric);

    alphanumerics.take(length).map(char::from).collect()
}
/// SHA-256 digest of all `keys` joined with a `0xFF` byte between them.
#[tracing::instrument(skip(keys))]
pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
    // We only hash the pdu's event ids, not the whole pdu
    let joined = keys.join(&0xFF);
    let hash = digest::digest(&digest::SHA256, &joined);

    let hash_bytes: &[u8] = hash.as_ref();
    hash_bytes.to_owned()
}
/// Yields the elements of the first iterator that also occur in every other
/// iterator. Returns `None` when `iterators` is empty.
///
/// Each of the other iterators is only ever advanced forwards, past elements
/// that `check_order` reports as `Less` than the current target, so the
/// result is only meaningful when all iterators are sorted consistently with
/// `check_order`.
#[allow(clippy::impl_trait_in_params)]
pub fn common_elements(
    mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,
) -> Option<impl Iterator<Item = Vec<u8>>> {
    let first_iterator = iterators.next()?;
    let mut other_iterators = iterators.map(Iterator::peekable).collect::<Vec<_>>();

    Some(first_iterator.filter(move |target| {
        for it in &mut other_iterators {
            let found = loop {
                // Exhausted without reaching the target
                let Some(candidate) = it.peek() else {
                    break false;
                };

                match check_order(candidate, target) {
                    // Keep searching
                    Ordering::Less => {
                        it.next();
                    },
                    // Element is in both iters
                    Ordering::Equal => break true,
                    // We went too far
                    Ordering::Greater => break false,
                }
            };

            if !found {
                return false;
            }
        }

        true
    }))
}
/// Fallible conversion from any value that implements `Serialize` to a
/// `CanonicalJsonObject`.
///
/// `value` must serialize to an `serde_json::Value::Object`.
pub fn to_canonical_object<T: serde::Serialize>(value: T) -> Result<CanonicalJsonObject, CanonicalJsonError> {
use serde::ser::Error;
match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? {
serde_json::Value::Object(map) => try_from_json_map(map),
_ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom("Value must be an object"))),
}
}
/// Deserializes a string and converts it to `T` through `T`'s `FromStr`
/// implementation.
///
/// A parse failure is reported as a custom deserializer error carrying the
/// `Display` output of the `FromStr` error.
pub fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr<Err = E>, E: fmt::Display>(
deserializer: D,
) -> Result<T, D::Error> {
// Visitor that only handles strings; `PhantomData` carries the target type
// since the visitor holds no value of `T`.
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
impl<T: FromStr<Err = Err>, Err: fmt::Display> serde::de::Visitor<'_> for Visitor<T, Err> {
type Value = T;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
// Delegate to `FromStr`, mapping its error into the deserializer's error type.
v.parse().map_err(serde::de::Error::custom)
}
}
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}
// Copied from librustdoc:
// https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs
/// Display adapter that writes the wrapped string with `>`, `<`, `&`, `'` and
/// `"` replaced by their HTML entities.
pub struct HtmlEscape<'a>(pub &'a str);

impl fmt::Display for HtmlEscape<'_> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Because the internet is always right, turns out there's not that many
        // characters to escape: http://stackoverflow.com/questions/7381974
        let text = self.0;

        // Byte offset just past the last escaped character already written.
        let mut flushed = 0;
        for (idx, ch) in text.char_indices() {
            let entity = match ch {
                '>' => "&gt;",
                '<' => "&lt;",
                '&' => "&amp;",
                '\'' => "&#39;",
                '"' => "&quot;",
                _ => continue,
            };

            // Emit the untouched run before this character, then its entity.
            fmt.write_str(&text[flushed..idx])?;
            fmt.write_str(entity)?;

            // NOTE: we only expect single byte characters here - which is fine as long as
            // we only match single byte characters
            flushed = idx + 1;
        }

        if flushed < text.len() {
            fmt.write_str(&text[flushed..])?;
        }

        Ok(())
    }
}
/// one true function for returning the conduwuit version with the necessary
/// CONDUWUIT_VERSION_EXTRA env variables used if specified
///
/// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string
/// to include it in parenthesis after the SemVer version. A common value are
/// git commit hashes. `CONDUIT_VERSION_EXTRA` is consulted only when
/// `CONDUWUIT_VERSION_EXTRA` is not set at all; an empty value is treated the
/// same as no extra.
#[must_use]
pub fn conduwuit_version() -> String {
    let extra = option_env!("CONDUWUIT_VERSION_EXTRA").or(option_env!("CONDUIT_VERSION_EXTRA"));

    match extra {
        Some(extra) if !extra.is_empty() => format!("{} ({})", env!("CARGO_PKG_VERSION"), extra),
        _ => env!("CARGO_PKG_VERSION").to_owned(),
    }
}
/// Debug-formats the given slice, but only up to the first `max_len` elements.
/// Any further elements are replaced by a single `"..."` entry.
///
/// See also [`debug_slice_truncated()`].
pub struct TruncatedDebugSlice<'a, T> {
    inner: &'a [T],
    max_len: usize,
}

impl<T: fmt::Debug> fmt::Debug for TruncatedDebugSlice<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            max_len,
        } = *self;

        if inner.len() > max_len {
            return f
                .debug_list()
                .entries(&inner[..max_len])
                .entry(&"...")
                .finish();
        }

        write!(f, "{:?}", inner)
    }
}
/// Wraps `slice` in a [`TruncatedDebugSlice`] limited to `max_len` elements
/// and returns it as a `tracing` debug field value. Useful for
/// `#[instrument]`:
///
/// ```
/// use conduit_core::utils::debug_slice_truncated;
///
/// #[tracing::instrument(fields(foos = debug_slice_truncated(foos, 42)))]
/// fn bar(foos: &[&str]) {}
/// ```
pub fn debug_slice_truncated<T: fmt::Debug>(
slice: &[T], max_len: usize,
) -> tracing::field::DebugValue<TruncatedDebugSlice<'_, T>> {
tracing::field::debug(TruncatedDebugSlice {
inner: slice,
max_len,
})
}
/// Raises the soft `RLIMIT_NOFILE` limit to the hard limit when it is lower.
///
/// This is needed for opening lots of file descriptors, which tends to
/// happen more often when using RocksDB and making lots of federation
/// connections at startup. The soft limit is usually 1024, and the hard
/// limit is usually 512000; I've personally seen it hit >2000.
///
/// * <https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6>
/// * <https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741>
///
/// # Panics
///
/// Panics if the limits read back after raising them differ from the values
/// that were just set.
#[cfg(unix)]
pub fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
    use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE};

    let (soft_limit, hard_limit) = getrlimit(NOFILE)?;
    if soft_limit >= hard_limit {
        // Nothing to raise.
        return Ok(());
    }

    setrlimit(NOFILE, hard_limit, hard_limit)?;
    assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit");
    debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",);

    Ok(())
}