From 6c1434c1657c5a911c98a9cfbe06c984d4079690 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 9 May 2024 15:59:08 -0700 Subject: [PATCH] Hot-Reloading Refactor Signed-off-by: Jason Volk --- Cargo.toml | 808 ++++++++++-------- deps/rust-rocksdb/Cargo.toml | 32 + deps/rust-rocksdb/lib.rs | 59 ++ nix/pkgs/main/default.nix | 2 +- src/admin/Cargo.toml | 63 ++ .../admin/appservice/appservice_command.rs | 2 +- src/{service => }/admin/appservice/mod.rs | 2 +- .../admin/debug/debug_commands.rs | 16 +- src/{service => }/admin/debug/mod.rs | 0 .../admin/federation/federation_commands.rs | 7 +- src/{service => }/admin/federation/mod.rs | 0 src/{service => }/admin/fsck/fsck_commands.rs | 0 src/{service => }/admin/fsck/mod.rs | 0 src/admin/handler.rs | 305 +++++++ .../admin/media/media_commands.rs | 4 +- src/{service => }/admin/media/mod.rs | 4 +- src/admin/mod.rs | 55 ++ src/{service => }/admin/query/account_data.rs | 0 src/{service => }/admin/query/appservice.rs | 0 src/{service => }/admin/query/globals.rs | 0 src/{service => }/admin/query/mod.rs | 0 src/{service => }/admin/query/presence.rs | 0 src/{service => }/admin/query/room_alias.rs | 0 src/{service => }/admin/query/sending.rs | 0 src/{service => }/admin/query/users.rs | 0 src/{service => }/admin/room/mod.rs | 0 .../admin/room/room_alias_commands.rs | 2 +- src/{service => }/admin/room/room_commands.rs | 5 +- .../admin/room/room_directory_commands.rs | 5 +- .../admin/room/room_moderation_commands.rs | 92 +- src/{service => }/admin/server/mod.rs | 0 .../admin/server/server_commands.rs | 19 +- src/{service => }/admin/tester/mod.rs | 2 +- src/{service => }/admin/user/mod.rs | 0 src/{service => }/admin/user/user_commands.rs | 14 +- src/admin/utils.rs | 30 + src/alloc/default.rs | 7 - src/alloc/hardened.rs | 8 - src/api/Cargo.toml | 66 ++ src/api/client_server/account.rs | 7 +- src/api/client_server/alias.rs | 8 +- src/api/client_server/directory.rs | 2 +- src/api/client_server/keys.rs | 3 +- src/api/client_server/media.rs | 6 +- src/api/client_server/membership.rs | 13 +- src/api/client_server/message.rs | 6 +- src/api/client_server/mod.rs | 77 +- src/api/client_server/profile.rs | 5 +- src/api/client_server/read_marker.rs | 3 +- src/api/client_server/room.rs | 4 +- src/api/client_server/state.rs | 6 +- src/api/client_server/sync.rs | 6 +- src/api/client_server/to_device.rs | 2 +- src/api/mod.rs | 16 +- src/{router/routes.rs => api/router.rs} | 23 +- src/api/ruma_wrapper/axum.rs | 83 +- src/api/ruma_wrapper/mod.rs | 20 +- src/api/ruma_wrapper/xmatrix.rs | 61 ++ src/api/server_server.rs | 40 +- src/bin/Cargo.toml | 123 +++ src/bin/main.rs | 96 +++ src/bin/mods.rs | 129 +++ src/bin/server.rs | 186 ++++ src/core/Cargo.toml | 133 +++ src/core/alloc/default.rs | 9 + src/core/alloc/hardened.rs | 10 + src/{ => core}/alloc/je.rs | 6 +- src/{ => core}/alloc/mod.rs | 12 +- src/{ => core}/config/check.rs | 4 +- src/{ => core}/config/mod.rs | 320 +++---- src/{ => core}/config/proxy.rs | 9 +- src/{utils => core}/debug.rs | 33 + src/{utils => core}/error.rs | 126 +-- src/core/log.rs | 79 ++ src/core/mod.rs | 27 + src/core/mods/canary.rs | 28 + src/core/mods/macros.rs | 44 + src/core/mods/mod.rs | 11 + src/core/mods/module.rs | 74 ++ src/core/mods/new.rs | 23 + src/core/mods/path.rs | 40 + src/core/pducount.rs | 51 ++ src/core/server.rs | 72 ++ src/{ => core}/utils/clap.rs | 8 +- src/{ => core}/utils/content_disposition.rs | 10 +- src/core/utils/defer.rs | 22 + src/{ => core}/utils/mod.rs | 89 +- src/database/Cargo.toml | 81 ++ src/database/cork.rs | 4 +- src/database/kvdatabase.rs | 320 +++++++ src/database/kvengine.rs | 2 +- src/database/kvtree.rs | 2 +- src/database/mod.rs | 576 +------------ src/database/rocksdb/kvtree.rs | 5 +- src/database/rocksdb/mod.rs | 40 +- src/database/rocksdb/opts.rs | 12 +- src/database/sqlite/mod.rs | 6 +- src/main.rs | 503 ----------- src/router/Cargo.toml | 85 ++ src/router/layers.rs | 190 ++++ src/router/mod.rs | 267 +----- src/router/request.rs | 102 +++ src/router/router.rs | 20 + src/router/run.rs | 185 ++++ src/router/serve.rs | 137 +++ src/service/Cargo.toml | 115 +++ src/service/account_data/data.rs | 2 +- src/service/account_data/mod.rs | 14 +- src/service/{admin/mod.rs => admin.rs} | 463 ++-------- src/service/admin/test_cmd/mod.rs | 41 - src/service/appservice/data.rs | 2 +- src/service/appservice/mod.rs | 58 +- src/service/globals/client.rs | 18 +- src/service/globals/data.rs | 2 +- src/service/globals/emerg_access.rs | 66 ++ .../globals}/migrations.rs | 6 +- src/service/globals/mod.rs | 253 +++--- src/service/globals/resolver.rs | 18 +- src/service/globals/updates.rs | 76 ++ src/service/key_backups/data.rs | 2 +- src/service/key_backups/mod.rs | 44 +- .../key_value/account_data.rs | 4 +- .../key_value/appservice.rs | 4 +- .../key_value/globals.rs | 11 +- .../key_value/key_backups.rs | 4 +- src/{database => service}/key_value/media.rs | 9 +- src/{database => service}/key_value/mod.rs | 0 .../key_value/presence.rs | 9 +- src/{database => service}/key_value/pusher.rs | 4 +- .../key_value/rooms/alias.rs | 4 +- .../key_value/rooms/auth_chain.rs | 4 +- .../key_value/rooms/directory.rs | 4 +- .../key_value/rooms/lazy_load.rs | 4 +- .../key_value/rooms/metadata.rs | 4 +- .../key_value/rooms/mod.rs | 4 +- .../key_value/rooms/outlier.rs | 4 +- .../key_value/rooms/pdu_metadata.rs | 8 +- .../key_value/rooms/read_receipt.rs | 4 +- .../key_value/rooms/search.rs | 4 +- .../key_value/rooms/short.rs | 4 +- .../key_value/rooms/state.rs | 4 +- .../key_value/rooms/state_accessor.rs | 4 +- .../key_value/rooms/state_cache.rs | 11 +- .../key_value/rooms/state_compressor.rs | 8 +- .../key_value/rooms/threads.rs | 4 +- .../key_value/rooms/timeline.rs | 5 +- .../key_value/rooms/user.rs | 4 +- .../key_value/sending.rs | 10 +- .../key_value/transaction_ids.rs | 4 +- src/{database => service}/key_value/uiaa.rs | 4 +- src/{database => service}/key_value/users.rs | 23 +- src/service/media/data.rs | 2 +- src/service/media/mod.rs | 54 +- src/service/mod.rs | 334 +------- src/service/pdu.rs | 82 +- src/service/presence/data.rs | 2 +- src/service/presence/mod.rs | 91 +- src/service/pusher/data.rs | 2 +- src/service/pusher/mod.rs | 23 +- src/service/rooms/alias/data.rs | 2 +- src/service/rooms/alias/mod.rs | 20 +- src/service/rooms/auth_chain/data.rs | 2 +- src/service/rooms/auth_chain/mod.rs | 16 +- src/service/rooms/directory/data.rs | 2 +- src/service/rooms/directory/mod.rs | 16 +- src/service/rooms/event_handler/mod.rs | 31 +- .../rooms/event_handler/parse_incoming_pdu.rs | 31 + .../rooms/event_handler/signing_keys.rs | 13 +- src/service/rooms/lazy_loading/data.rs | 2 +- src/service/rooms/lazy_loading/mod.rs | 25 +- src/service/rooms/metadata/data.rs | 2 +- src/service/rooms/metadata/mod.rs | 24 +- src/service/rooms/mod.rs | 84 +- src/service/rooms/outlier/data.rs | 2 +- src/service/rooms/outlier/mod.rs | 16 +- src/service/rooms/pdu_metadata/data.rs | 4 +- src/service/rooms/pdu_metadata/mod.rs | 28 +- src/service/rooms/read_receipt/data.rs | 2 +- src/service/rooms/read_receipt/mod.rs | 18 +- src/service/rooms/search/data.rs | 2 +- src/service/rooms/search/mod.rs | 12 +- src/service/rooms/short/data.rs | 2 +- src/service/rooms/short/mod.rs | 24 +- src/service/rooms/spaces/mod.rs | 22 +- src/service/rooms/state/data.rs | 2 +- src/service/rooms/state/mod.rs | 26 +- src/service/rooms/state_accessor/data.rs | 2 +- src/service/rooms/state_accessor/mod.rs | 48 +- src/service/rooms/state_cache/data.rs | 2 +- src/service/rooms/state_cache/mod.rs | 80 +- src/service/rooms/state_compressor/data.rs | 10 +- src/service/rooms/state_compressor/mod.rs | 24 +- src/service/rooms/threads/data.rs | 2 +- src/service/rooms/threads/mod.rs | 12 +- src/service/rooms/timeline/data.rs | 2 +- src/service/rooms/timeline/mod.rs | 117 +-- src/service/rooms/typing/mod.rs | 26 +- src/service/rooms/user/data.rs | 2 +- src/service/rooms/user/mod.rs | 26 +- src/service/sending/data.rs | 2 +- src/service/sending/mod.rs | 71 +- src/service/sending/send.rs | 222 +---- src/service/sending/sender.rs | 38 +- src/service/services.rs | 342 ++++++++ src/service/transaction_ids/data.rs | 2 +- src/service/transaction_ids/mod.rs | 12 +- src/service/uiaa/data.rs | 2 +- src/service/uiaa/mod.rs | 20 +- src/service/users/data.rs | 2 +- src/service/users/mod.rs | 118 ++- src/utils/server_name.rs | 8 - src/utils/user_id.rs | 8 - 212 files changed, 5679 insertions(+), 4206 deletions(-) create mode 100644 deps/rust-rocksdb/Cargo.toml create mode 100644 deps/rust-rocksdb/lib.rs create mode 100644 src/admin/Cargo.toml rename src/{service => }/admin/appservice/appservice_command.rs (97%) rename src/{service => }/admin/appservice/mod.rs (98%) rename src/{service => }/admin/debug/debug_commands.rs (98%) rename src/{service => }/admin/debug/mod.rs (100%) rename src/{service => }/admin/federation/federation_commands.rs (97%) rename src/{service => }/admin/federation/mod.rs (100%) rename src/{service => }/admin/fsck/fsck_commands.rs (100%) rename src/{service => }/admin/fsck/mod.rs (100%) create mode 100644 src/admin/handler.rs rename src/{service => }/admin/media/media_commands.rs (98%) rename src/{service => }/admin/media/mod.rs (96%) create mode 100644 src/admin/mod.rs rename src/{service => }/admin/query/account_data.rs (100%) rename src/{service => }/admin/query/appservice.rs (100%) rename src/{service => }/admin/query/globals.rs (100%) rename src/{service => }/admin/query/mod.rs (100%) rename src/{service => }/admin/query/presence.rs (100%) rename src/{service => }/admin/query/room_alias.rs (100%) rename src/{service => }/admin/query/sending.rs (100%) rename src/{service => }/admin/query/users.rs (100%) rename src/{service => }/admin/room/mod.rs (100%) rename src/{service => }/admin/room/room_alias_commands.rs (98%) rename src/{service => }/admin/room/room_commands.rs (93%) rename src/{service => }/admin/room/room_directory_commands.rs (96%) rename src/{service => }/admin/room/room_moderation_commands.rs (83%) rename src/{service => }/admin/server/mod.rs (100%) rename src/{service => }/admin/server/server_commands.rs (90%) rename src/{service => }/admin/tester/mod.rs (92%) rename src/{service => }/admin/user/mod.rs (100%) rename src/{service => }/admin/user/user_commands.rs (97%) create mode 100644 src/admin/utils.rs delete mode 100644 src/alloc/default.rs delete mode 100644 src/alloc/hardened.rs create mode 100644 src/api/Cargo.toml rename src/{router/routes.rs => api/router.rs} (96%) create mode 100644 src/api/ruma_wrapper/xmatrix.rs create mode 100644 src/bin/Cargo.toml create mode 100644 src/bin/main.rs create mode 100644 src/bin/mods.rs create mode 100644 src/bin/server.rs create mode 100644 src/core/Cargo.toml create mode 100644 src/core/alloc/default.rs create mode 100644 src/core/alloc/hardened.rs rename src/{ => core}/alloc/je.rs (95%) rename src/{ => core}/alloc/mod.rs (80%) rename src/{ => core}/config/check.rs (98%) rename src/{ => core}/config/mod.rs (80%) rename src/{ => core}/config/proxy.rs (95%) rename src/{utils => core}/debug.rs (66%) rename src/{utils => core}/error.rs (78%) create mode 100644 src/core/log.rs create mode 100644 src/core/mod.rs create mode 100644 src/core/mods/canary.rs create mode 100644 src/core/mods/macros.rs create mode 100644 src/core/mods/mod.rs create mode 100644 src/core/mods/module.rs create mode 100644 src/core/mods/new.rs create mode 100644 src/core/mods/path.rs create mode 100644 src/core/pducount.rs create mode 100644 src/core/server.rs rename src/{ => core}/utils/clap.rs (73%) rename src/{ => core}/utils/content_disposition.rs (93%) create mode 100644 src/core/utils/defer.rs rename src/{ => core}/utils/mod.rs (72%) create mode 100644 src/database/Cargo.toml create mode 100644 src/database/kvdatabase.rs delete mode 100644 src/main.rs create mode 100644 src/router/Cargo.toml create mode 100644 src/router/layers.rs create mode 100644 src/router/request.rs create mode 100644 src/router/router.rs create mode 100644 src/router/run.rs create mode 100644 src/router/serve.rs create mode 100644 src/service/Cargo.toml rename src/service/{admin/mod.rs => admin.rs} (51%) delete mode 100644 src/service/admin/test_cmd/mod.rs create mode 100644 src/service/globals/emerg_access.rs rename src/{database => service/globals}/migrations.rs (99%) create mode 100644 src/service/globals/updates.rs rename src/{database => service}/key_value/account_data.rs (96%) rename src/{database => service}/key_value/appservice.rs (92%) rename src/{database => service}/key_value/globals.rs (96%) rename src/{database => service}/key_value/key_backups.rs (98%) rename src/{database => service}/key_value/media.rs (97%) rename src/{database => service}/key_value/mod.rs (100%) rename src/{database => service}/key_value/presence.rs (96%) rename src/{database => service}/key_value/pusher.rs (94%) rename src/{database => service}/key_value/rooms/alias.rs (94%) rename src/{database => service}/key_value/rooms/auth_chain.rs (91%) rename src/{database => service}/key_value/rooms/directory.rs (85%) rename src/{database => service}/key_value/rooms/lazy_load.rs (92%) rename src/{database => service}/key_value/rooms/metadata.rs (93%) rename src/{database => service}/key_value/rooms/mod.rs (71%) rename src/{database => service}/key_value/rooms/outlier.rs (85%) rename src/{database => service}/key_value/rooms/pdu_metadata.rs (92%) rename src/{database => service}/key_value/rooms/read_receipt.rs (96%) rename src/{database => service}/key_value/rooms/search.rs (93%) rename src/{database => service}/key_value/rooms/short.rs (97%) rename src/{database => service}/key_value/rooms/state.rs (94%) rename src/{database => service}/key_value/rooms/state_accessor.rs (96%) rename src/{database => service}/key_value/rooms/state_cache.rs (98%) rename src/{database => service}/key_value/rooms/state_compressor.rs (88%) rename src/{database => service}/key_value/rooms/threads.rs (93%) rename src/{database => service}/key_value/rooms/timeline.rs (97%) rename src/{database => service}/key_value/rooms/user.rs (96%) rename src/{database => service}/key_value/sending.rs (96%) rename src/{database => service}/key_value/transaction_ids.rs (88%) rename src/{database => service}/key_value/uiaa.rs (94%) rename src/{database => service}/key_value/users.rs (98%) create mode 100644 src/service/rooms/event_handler/parse_incoming_pdu.rs create mode 100644 src/service/services.rs delete mode 100644 src/utils/server_name.rs delete mode 100644 src/utils/user_id.rs diff --git a/Cargo.toml b/Cargo.toml index d7b9454b..2e2153a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,115 +1,95 @@ -[package] -# TODO: when can we rename to conduwuit? -name = "conduit" +#cargo-features = ["profile-rustflags"] + +[workspace] +resolver = "2" +members = ["src/*"] +default-members = ["src/*"] + +[workspace.package] description = "a very cool fork of Conduit, a Matrix homeserver written in Rust" license = "Apache-2.0" authors = [ "strawberry ", "timokoesters ", ] +version = "0.3.4" +edition = "2021" +# See also `rust-toolchain.toml` +rust-version = "1.77.0" homepage = "https://conduwuit.puppyirl.gay/" repository = "https://github.com/girlbossceo/conduwuit" readme = "README.md" -version = "0.3.4" -edition = "2021" -# See also `rust-toolchain.toml` -rust-version = "1.77.0" - -[dependencies] # 1.1.17 seems broken on nix from a permission error? -libz-sys = "=1.1.16" +[workspace.dependencies.libz-sys] +version = "=1.1.16" -console-subscriber = { version = "0.2", optional = true } +[workspace.dependencies.sanitize-filename] +version = "0.5.0" -infer = { version = "0.15", default-features = false } +[workspace.dependencies.infer] +version = "0.15" +default-features = false -# for hot lib reload -hot-lib-reloader = { version = "^0.7", optional = true } +[workspace.dependencies.jsonwebtoken] +version = "9.3.0" -# Used for secure identifiers -rand = "0.8.5" - -# Used for conduit::Error type -thiserror = "1.0.60" - -# Used to encode server public key -base64 = "0.22.1" - -# Used when hashing the state -ring = "0.17.8" - -# Used to find matching events for appservices -regex = "1.10.4" - -# Used to load forbidden room/user regex from config -serde_regex = "1.1.0" - -# Used to make working with iterators easier, was already a transitive depdendency -itertools = "0.12.1" - -# jwt jsonwebtokens -jsonwebtoken = "9.3.0" - -# Used for ruma wrapper -serde_html_form = "0.2.6" +[workspace.dependencies.base64] +version = "0.22.1" # used for TURN server authentication -hmac = "0.12.1" -sha-1 = "0.10.1" +[workspace.dependencies.hmac] +version = "0.12.1" + +[workspace.dependencies.sha-1] +version = "0.10.1" # used for checking if an IP is in specific subnets / CIDR ranges easier -ipaddress = "0.1.3" +[workspace.dependencies.ipaddress] +version = "0.1.3" -# to get the client IP address of requests -#axum-client-ip = "0.4.2" +[workspace.dependencies.rand] +version = "0.8.5" -# to parse user-friendly time durations in admin commands -cyborgtime = "2.1.1" - -# all the web/HTTP dependencies # Used for the http request / response body type for Ruma endpoints used with reqwest -bytes = "1.6.0" -http = "1.1.0" -http-body-util = "0.1.1" +[workspace.dependencies.bytes] +version = "1.6.0" -# used to replace the channels of the tokio runtime -loole = "0.3.0" +[workspace.dependencies.http-body-util] +version = "0.1.1" -# Validating urls in config, was already a transitive dependency -url = { version = "2.5.0", features = ["serde"] } +[workspace.dependencies.http] +version = "1.1.0" -async-trait = "0.1.80" +[workspace.dependencies.regex] +version = "1.10.4" -lru-cache = "0.1.2" -sanitize-filename = "0.5.0" - -# standard date and time tools -[dependencies.chrono] -version = "0.4.38" -features = ["alloc"] -default-features = false - -# Web framework -[dependencies.axum] +[workspace.dependencies.axum] version = "0.7.5" default-features = false -features = ["form", "http1", "http2", "json", "matched-path"] +features = [ + "form", + "http1", + "http2", + "json", + "matched-path", + "tokio", +] -[dependencies.axum-extra] +[workspace.dependencies.axum-extra] version = "0.9.3" default-features = false features = ["typed-header"] -[dependencies.axum-server] +[workspace.dependencies.axum-server] version = "0.6.0" features = ["tls-rustls"] -[dependencies.tower] +[workspace.dependencies.tower] version = "0.4.13" features = ["util"] -[dependencies.tower-http] +[workspace.dependencies.tower-http] version = "0.5.2" features = [ "add-extension", @@ -121,153 +101,170 @@ features = [ "catch-panic", ] -[dependencies.hyper] -version = "1.3.1" -features = ["server", "http1", "http2"] - -[dependencies.hyper-util] -version = "0.1.3" - -[dependencies.reqwest] +[workspace.dependencies.reqwest] version = "0.12.4" default-features = false -features = ["rustls-tls-native-roots", "socks", "hickory-dns"] +features = [ + "rustls-tls-native-roots", + "socks", + "hickory-dns", +] -# all the serde stuff -# Used for pdu definition -[dependencies.serde] +[workspace.dependencies.serde] version = "1.0.201" features = ["rc"] -# Used for appservice registration files -[dependencies.serde_yaml] -version = "0.9.34" -# Used for ruma wrapper -[dependencies.serde_json] + +[workspace.dependencies.serde_json] version = "1.0.117" features = ["raw_value"] +# Used for appservice registration files +[workspace.dependencies.serde_yaml] +version = "0.9.34" + +# Used to load forbidden room/user regex from config +[workspace.dependencies.serde_regex] +version = "1.1.0" + +# Used for ruma wrapper +[workspace.dependencies.serde_html_form] +version = "0.2.6" # Used for password hashing -[dependencies.argon2] +[workspace.dependencies.argon2] version = "0.5.3" features = ["alloc", "rand"] default-features = false # Used to generate thumbnails for images -[dependencies.image] +[workspace.dependencies.image] version = "0.25.1" default-features = false -features = ["jpeg", "png", "gif", "webp"] +features = [ + "jpeg", + "png", + "gif", + "webp", +] # logging -[dependencies.log] +[workspace.dependencies.log] version = "0.4.21" default-features = false -[dependencies.tracing] +[workspace.dependencies.tracing] version = "0.1.40" default-features = false -[dependencies.tracing-subscriber] +[workspace.dependencies.tracing-subscriber] version = "0.3.18" features = ["env-filter"] -# optional SHA256 media keys feature -[dependencies.sha2] -version = "0.10.8" -optional = true - -# optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring -[dependencies.opentelemetry] -version = "0.21.0" -optional = true -[dependencies.tracing-flame] -version = "0.2.0" -optional = true -[dependencies.tracing-opentelemetry] -version = "0.22.0" -optional = true -[dependencies.opentelemetry_sdk] -version = "0.21.2" -optional = true -features = ["rt-tokio"] -[dependencies.opentelemetry-jaeger] -version = "0.20.0" -optional = true -features = ["rt-tokio"] - -# optional sentry metrics for crash/panic reporting -[dependencies.sentry] -version = "0.32.3" -optional = true -default-features = false -features = [ - "backtrace", - "contexts", - "debug-images", - "panic", - "rustls", - "tower", - "tower-http", - "tracing", - "reqwest", - "log", -] -[dependencies.sentry-tracing] -version = "0.32.3" -optional = true -[dependencies.sentry-tower] -version = "0.32.3" -optional = true - -# optional jemalloc usage -[dependencies.tikv-jemalloc-sys] -version = "0.5.4" -optional = true -default-features = false -features = ["stats", "unprefixed_malloc_on_supported_platforms"] -[dependencies.tikv-jemallocator] -version = "0.5.4" -optional = true -default-features = false -features = ["stats", "unprefixed_malloc_on_supported_platforms"] -[dependencies.tikv-jemalloc-ctl] -version = "0.5.4" -optional = true -default-features = false -features = ["use_std"] - # for URL previews -[dependencies.webpage] +[workspace.dependencies.webpage] version = "2.0.1" default-features = false -# to support multiple variations of setting a config option -[dependencies.either] -version = "1.11.0" -features = ["serde"] - -# to listen on both HTTP and HTTPS if listening on TLS dierctly from conduwuit for complement or sytest -[dependencies.axum-server-dual-protocol] -version = "0.6" -optional = true - # used for conduit's CLI and admin room command parsing -[dependencies.clap] +[workspace.dependencies.clap] version = "4.5.4" default-features = false -features = ["std", "derive", "help", "usage", "error-context", "string"] +features = [ + "std", + "derive", + "help", + "usage", + "error-context", + "string", +] -[dependencies.futures-util] +[workspace.dependencies.futures-util] version = "0.3.30" default-features = false +[workspace.dependencies.tokio] +version = "1.37.0" +features = [ + "fs", + "net", + "macros", + "sync", + "signal", + "time", + "rt-multi-thread", + "io-util", +] + +[workspace.dependencies.libloading] +version = "0.8.3" + +# Validating urls in config, was already a transitive dependency +[workspace.dependencies.url] +version = "2.5.0" +features = ["serde"] + +# standard date and time tools +[workspace.dependencies.chrono] +version = "0.4.38" +features = ["alloc"] +default-features = false + +[workspace.dependencies.hyper] +version = "1.3.1" +features = [ + "server", + "http1", + "http2", +] + +[workspace.dependencies.hyper-util] +version = "0.1.3" + +# to support multiple variations of setting a config option +[workspace.dependencies.either] +version = "1.11.0" +features = ["serde"] + # Used for reading the configuration from conduwuit.toml & environment variables -[dependencies.figment] +[workspace.dependencies.figment] version = "0.10.18" features = ["env", "toml"] +[workspace.dependencies.hickory-resolver] +version = "0.24.1" +default-features = false + +# Used for conduit::Error type +[workspace.dependencies.thiserror] +version = "1.0.60" + +# Used when hashing the state +[workspace.dependencies.ring] +version = "0.17.8" + +# Used to make working with iterators easier, was already a transitive depdendency +[workspace.dependencies.itertools] +version = "0.12.1" + +# to parse user-friendly time durations in admin commands +#TODO: overlaps chrono? +[workspace.dependencies.cyborgtime] +version = "2.1.1" + +# used to replace the channels of the tokio runtime +[workspace.dependencies.loole] +version = "0.3.0" + +[workspace.dependencies.async-trait] +version = "0.1.80" + +[workspace.dependencies.lru-cache] +version = "0.1.2" + +[workspace.dependencies.num_cpus] +version = "1.16.0" + # Used for matrix spec type definitions and helpers -[dependencies.ruma] -git = "https://github.com/girlbossceo/ruma" +[workspace.dependencies.ruma] +git = "https://github.com/girlbossceo/ruwuma" branch = "conduwuit-changes" features = [ "compat", @@ -292,60 +289,128 @@ features = [ "unstable-extensible-events", ] -[dependencies.ruma-identifiers-validation] -git = "https://github.com/girlbossceo/ruma" +[workspace.dependencies.ruma-identifiers-validation] +git = "https://github.com/girlbossceo/ruwuma" branch = "conduwuit-changes" -[dependencies.hickory-resolver] -version = "0.24.1" +[workspace.dependencies.rust-rocksdb] +path = "deps/rust-rocksdb" +package = "rust-rocksdb-uwu" +features = [ + "multi-threaded-cf", + "mt_static", + "snappy", + "lz4", + "zstd", + "zlib", + "bzip2", +] + +[workspace.dependencies.zstd] +version = "0.13.1" + +# to listen on both HTTP and HTTPS if listening on TLS dierctly from conduwuit for complement or sytest +[workspace.dependencies.axum-server-dual-protocol] +version = "0.6" + +# optional SHA256 media keys feature +[workspace.dependencies.sha2] +version = "0.10.8" + +# optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring +[workspace.dependencies.opentelemetry] +version = "0.21.0" + +[workspace.dependencies.tracing-flame] +version = "0.2.0" + +[workspace.dependencies.tracing-opentelemetry] +version = "0.22.0" + +[workspace.dependencies.opentelemetry_sdk] +version = "0.21.2" +features = ["rt-tokio"] + +[workspace.dependencies.opentelemetry-jaeger] +version = "0.20.0" +features = ["rt-tokio"] + +# optional sentry metrics for crash/panic reporting +[workspace.dependencies.sentry] +version = "0.32.3" default-features = false +features = [ + "backtrace", + "contexts", + "debug-images", + "panic", + "rustls", + "tower", + "tower-http", + "tracing", + "reqwest", + "log", +] -[dependencies.rust-rocksdb] -git = "https://github.com/zaidoon1/rust-rocksdb" -branch = "master" -optional = true -default-features = true -features = ["multi-threaded-cf", "zstd"] +[workspace.dependencies.sentry-tracing] +version = "0.32.3" +[workspace.dependencies.sentry-tower] +version = "0.32.3" -[dependencies.rusqlite] +# optional jemalloc usage +[workspace.dependencies.tikv-jemalloc-sys] +version = "0.5.4" +default-features = false +features = ["unprefixed_malloc_on_supported_platforms"] +[workspace.dependencies.tikv-jemallocator] +version = "0.5.4" +default-features = false +features = ["unprefixed_malloc_on_supported_platforms"] +[workspace.dependencies.tikv-jemalloc-ctl] +version = "0.5.4" +default-features = false +features = ["use_std"] + +[workspace.dependencies.rusqlite] git = "https://github.com/rusqlite/rusqlite" #branch = "master" rev = "e00b626e2b1c67347d789fb7f600281705c89381" -optional = true features = ["bundled"] # used only by rusqlite -[dependencies.parking_lot] +[workspace.dependencies.parking_lot] version = "0.12.2" -optional = true # used only by rusqlite -[dependencies.thread_local] +[workspace.dependencies.thread_local] version = "1.1.8" -optional = true -# used only by rusqlite and rust-rocksdb -[dependencies.num_cpus] -version = "1.16.0" +[workspace.dependencies.tokio-metrics] +version = "0.3.1" +default-features = false -[dependencies.tokio] -version = "1.37.0" -features = ["fs", "macros", "sync", "signal"] +[workspace.dependencies.console-subscriber] +version = "0.2" -# *nix-specific dependencies -[target.'cfg(unix)'.dependencies] -nix = { version = "0.28.0", features = ["resource"] } -sd-notify = { version = "0.4.1", optional = true } # systemd is only available/relevant on *nix platforms +[workspace.dependencies.nix] +version = "0.28.0" +features = ["resource"] +[workspace.dependencies.sd-notify] +version = "0.4.1" -[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies] -hardened_malloc-rs = { version = "0.1.2", optional = true, features = [ - "static", - "gcc", - "light", -], default-features = false } -#hardened_malloc-rs = { optional = true, features = ["static","clang","light"], path = "../hardened_malloc-rs", default-features = false } +[workspace.dependencies.hardened_malloc-rs] +version = "0.1.2" +default-features = false +features = [ + "static", + "gcc", + "light", +] +# +# Patches +# # backport of [https://github.com/tokio-rs/tracing/pull/2956] to the 0.1.x branch of tracing. # we can switch back to upstream if #2956 is merged and backported in the upstream repo. @@ -359,166 +424,213 @@ branch = "tracing-subscriber/env-filter-clone-0.1.x-backport" git = "https://github.com/girlbossceo/tracing" branch = "tracing-subscriber/env-filter-clone-0.1.x-backport" -[features] -default = [ - "backend_rocksdb", - "systemd", - "element_hacks", - "sentry_telemetry", - "gzip_compression", - "brotli_compression", - "zstd_compression", - "release_max_log_level", - "io_uring", -] -backend_sqlite = ["sqlite"] -backend_rocksdb = ["rocksdb"] -rocksdb = ["dep:rust-rocksdb"] -jemalloc = [ - "dep:tikv-jemalloc-sys", - "dep:tikv-jemalloc-ctl", - "dep:tikv-jemallocator", - "rust-rocksdb/jemalloc", -] -jemalloc_prof = ["tikv-jemalloc-sys/profiling"] -sqlite = ["dep:rusqlite", "dep:parking_lot", "dep:thread_local"] -systemd = ["dep:sd-notify"] -sentry_telemetry = ["dep:sentry", "dep:sentry-tracing", "dep:sentry-tower"] - -gzip_compression = ["tower-http/compression-gzip", "reqwest/gzip"] -zstd_compression = ["tower-http/compression-zstd"] -brotli_compression = ["tower-http/compression-br", "reqwest/brotli"] - -sha256_media = ["dep:sha2"] -io_uring = ["rust-rocksdb/io-uring"] -axum_dual_protocol = ["dep:axum-server-dual-protocol"] - -perf_measurements = [ - "dep:opentelemetry", - "dep:tracing-flame", - "dep:tracing-opentelemetry", - "dep:opentelemetry_sdk", - "dep:opentelemetry-jaeger", -] - -# enable the tokio_console server -# incompatible with release_max_log_level -tokio_console = ["dep:console-subscriber", "tokio/tracing"] - -hot_reload = ["dep:hot-lib-reloader"] - -hardened_malloc = ["dep:hardened_malloc-rs"] - -# increases performance, reduces build times, and reduces binary size by not compiling or -# genreating code for log level filters that users will generally not use (debug and trace) only in release builds # -# the expense is obviously losing those log level filters for usage at runtime. debug builds will still have all log levels -release_max_log_level = [ - "tracing/max_level_trace", - "tracing/release_max_level_info", - "log/max_level_trace", - "log/release_max_level_info", -] - -# developer feature useful only in debug builds. -dev_release_log_level = [] - -# client/server interopability hacks +# Our crates # -## element has various non-spec compliant behaviour -element_hacks = [] +[workspace.dependencies.conduit-router] +package = "conduit_router" +path = "src/router" +default-features = false -[package.metadata.deb] -name = "conduwuit" -maintainer = "strawberry " -copyright = "2024, strawberry " -license-file = ["LICENSE", "3"] -depends = "$auto, ca-certificates" -extended-description = """\ -a cool hard fork of Conduit, a Matrix homeserver written in Rust""" -section = "net" -priority = "optional" -assets = [ - [ - "debian/README.md", - "usr/share/doc/conduwuit/README.Debian", - "644", - ], - [ - "README.md", - "usr/share/doc/conduwuit/", - "644", - ], - [ - "target/release/conduwuit", - "usr/sbin/conduwuit", - "755", - ], - [ - "conduwuit-example.toml", - "etc/conduwuit/conduwuit.toml", - "640", - ], -] -conf-files = ["/etc/conduwuit/conduwuit.toml"] -maintainer-scripts = "debian/" -systemd-units = { unit-name = "conduwuit", start = false } +[workspace.dependencies.conduit-admin] +package = "conduit_admin" +path = "src/admin" +default-features = false +[workspace.dependencies.conduit-api] +package = "conduit_api" +path = "src/api" +default-features = false -[profile.dev] -#debug = 0 -lto = 'off' -codegen-units = 512 -incremental = true -overflow-checks = true -#panic = "abort" +[workspace.dependencies.conduit-service] +package = "conduit_service" +path = "src/service" +default-features = false -# seems to speed up continuous debug compilations -[profile.dev.build-override] -opt-level = 3 -[profile.dev.package."*"] # external dependencies -opt-level = 1 -[profile.dev.package."tokio"] -opt-level = 3 +[workspace.dependencies.conduit-database] +package = "conduit_database" +path = "src/database" +default-features = false + +[workspace.dependencies.conduit-core] +package = "conduit_core" +path = "src/core" +default-features = false + +############################################################################### +# +# Release profiles +# -# default release profile [profile.release] -lto = 'thin' -incremental = false -opt-level = 3 strip = "symbols" -control-flow-guard = true # Windows only -debug = 0 +lto = "thin" # release profile with debug symbols [profile.release-debuginfo] inherits = "release" -strip = "none" debug = "full" +strip = "none" - -# high performance release profile which uses fat LTO across all crates, 1 codegen unit, max opt-level, and optimises across all crates [profile.release-high-perf] inherits = "release" -lto = 'fat' +lto = "fat" codegen-units = 1 panic = "abort" -# For releases also try to max optimizations for dependencies: -[profile.release-high-perf.build-override] -debug = 0 -opt-level = 3 +# do not use without profile-rustflags enabled +[profile.release-max-perf] +inherits = "release" +strip = "symbols" +lto = "fat" +#rustflags = [ +# '-Ctarget-cpu=native', +# '-Ztune-cpu=native', +# '-Ctarget-feature=+crt-static', +# '-Crelocation-model=static', +# '-Ztls-model=local-exec', +# '-Zinline-in-all-cgus=true', +# '-Zinline-mir=true', +# '-Zmir-opt-level=3', +# '-Clink-arg=-fuse-ld=mold', +# '-Clink-arg=-Wl,--threads', +# '-Clink-arg=-Wl,--gc-sections', +# '-Ztime-passes', +# '-Ztime-llvm-passes', +#] + +[profile.release-max-perf.build-override] +inherits = "release-max-perf" +opt-level = 0 +#rustflags = [ +# '-Ctarget-feature=-crt-static', +#] + +[profile.bench] +inherits = "release" +#rustflags = [ +# "-Cremark=all", +# '-Ztime-passes', +# '-Ztime-llvm-passes', +#] + +############################################################################### +# +# Developer profile +# + +# To enable hot-reloading: +# 1. Uncomment all of the rustflags here. +# 2. Uncomment crate-type=dylib in src/*/Cargo.toml and deps/rust-rocksdb/Cargo.toml +# 2. Build with the 'mods' feature. +# +# opt-level, mir-opt-level, validate-mir are not known to interfere with reloading +# and can be raised if build times are tolerable. + +[profile.dev] +debug = 1 +opt-level = 0 +panic = "unwind" +debug-assertions = true +incremental = true +codegen-units = 64 +rpath = true +#rustflags = [ +# '-Ztime-passes', +# '-Zmir-opt-level=0', +# '-Zvalidate-mir=false', +# '-Ztls-model=global-dynamic', +# '-Cprefer-dynamic=true', +# '-Zstaticlib-prefer-dynamic=true', +# '-Zstaticlib-allow-rdylib-deps=true', +# '-Zpacked-bundled-libs=false', +# '-Zplt=true', +# '-Crpath=true', +# '-Clink-arg=-Wl,--as-needed', +# '-Clink-arg=-Wl,--allow-shlib-undefined', +# '-Clink-arg=-Wl,-z,keep-text-section-prefix', +# '-Clink-arg=-Wl,-z,lazy', +#] + +[profile.dev.package.conduit_core] +inherits = "dev" +incremental = false +#rustflags = [ +# '-Ztime-passes', +# '-Zmir-opt-level=0', +# '-Ztls-model=initial-exec', +# '-Cprefer-dynamic=true', +# '-Zstaticlib-prefer-dynamic=true', +# '-Zstaticlib-allow-rdylib-deps=true', +# '-Zpacked-bundled-libs=false', +# '-Zplt=true', +# '-Clink-arg=-Wl,--as-needed', +# '-Clink-arg=-Wl,--allow-shlib-undefined', +# '-Clink-arg=-Wl,-z,lazy', +# '-Clink-arg=-Wl,-z,unique', +# '-Clink-arg=-Wl,-z,nodlopen', +# '-Clink-arg=-Wl,-z,nodelete', +#] + +[profile.dev.package.conduit] +inherits = "dev" +incremental = false +#rustflags = [ +# '-Ztime-passes', +# '-Zmir-opt-level=0', +# '-Zvalidate-mir=false', +# '-Ztls-model=global-dynamic', +# '-Cprefer-dynamic=true', +# '-Zexport-executable-symbols=true', +# '-Zplt=true', +# '-Crpath=true', +# '-Clink-arg=-Wl,--as-needed', +# '-Clink-arg=-Wl,--allow-shlib-undefined', +# '-Clink-arg=-Wl,--export-dynamic', +# '-Clink-arg=-Wl,-z,lazy', +#] + +[profile.dev.package.rust-rocksdb-uwu] +inherits = "dev" +debug = 'limited' +incremental = false codegen-units = 1 +opt-level = 'z' +#rustflags = [ +# '-Ztls-model=initial-exec', +# '-Cprefer-dynamic=true', +# '-Zstaticlib-prefer-dynamic=true', +# '-Zstaticlib-allow-rdylib-deps=true', +# '-Zpacked-bundled-libs=true', +# '-Zplt=true', +# '-Clink-arg=-Wl,--no-as-needed', +# '-Clink-arg=-Wl,--allow-shlib-undefined', +# '-Clink-arg=-Wl,-z,lazy', +# '-Clink-arg=-Wl,-z,nodlopen', +# '-Clink-arg=-Wl,-z,nodelete', +#] -[profile.release-high-perf.package."*"] -debug = 0 -opt-level = 3 +[profile.dev.package.'*'] +inherits = "dev" +debug = 'limited' +incremental = false codegen-units = 1 +opt-level = 'z' +#rustflags = [ +# '-Ztls-model=global-dynamic', +# '-Cprefer-dynamic=true', +# '-Zstaticlib-prefer-dynamic=true', +# '-Zstaticlib-allow-rdylib-deps=true', +# '-Zpacked-bundled-libs=true', +# '-Zplt=true', +# '-Clink-arg=-Wl,--as-needed', +# '-Clink-arg=-Wl,-z,lazy', +# '-Clink-arg=-Wl,-z,nodelete', +#] - -[lints] -workspace = true +[profile.test] +incremental = false [workspace.lints.rust] missing_abi = "warn" @@ -543,7 +655,6 @@ unused_braces = "allow" # some sadness missing_docs = "allow" - [workspace.lints.clippy] # pedantic = "warn" @@ -615,7 +726,6 @@ unnecessary_box_returns = "warn" map_unwrap_or = "warn" implicit_clone = "warn" match_wildcard_for_single_variants = "warn" -unnecessary_wraps = "warn" match_same_arms = "warn" ignored_unit_patterns = "warn" redundant_else = "warn" @@ -650,6 +760,7 @@ unwrap_used = "allow" expect_used = "allow" if_then_some_else_none = "allow" let_underscore_must_use = "allow" +let_underscore_future = "allow" map_err_ignore = "allow" missing_docs_in_private_items = "allow" multiple_inherent_impl = "allow" @@ -657,3 +768,4 @@ error_impl_error = "allow" string_add = "allow" string_slice = "allow" ref_patterns = "allow" +unnecessary_wraps = "allow" diff --git a/deps/rust-rocksdb/Cargo.toml b/deps/rust-rocksdb/Cargo.toml new file mode 100644 index 00000000..f5e3211e --- /dev/null +++ b/deps/rust-rocksdb/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "rust-rocksdb-uwu" +version = "0.0.1" +edition = "2021" + +[features] +default = ["snappy", "lz4", "zstd", "zlib", "bzip2"] +jemalloc = ["rust-rocksdb/jemalloc"] +io-uring = ["rust-rocksdb/io-uring"] +valgrind = ["rust-rocksdb/valgrind"] +snappy = ["rust-rocksdb/snappy"] +lz4 = ["rust-rocksdb/lz4"] +zstd = ["rust-rocksdb/zstd"] +zlib = ["rust-rocksdb/zlib"] +bzip2 = ["rust-rocksdb/bzip2"] +rtti = ["rust-rocksdb/rtti"] +mt_static = ["rust-rocksdb/mt_static"] +multi-threaded-cf = ["rust-rocksdb/multi-threaded-cf"] +serde1 = ["rust-rocksdb/serde1"] +malloc-usable-size = ["rust-rocksdb/malloc-usable-size"] + +[dependencies.rust-rocksdb] +git = "https://github.com/zaidoon1/rust-rocksdb" +branch = "master" +default-features = false + +[lib] +path = "lib.rs" +crate-type = [ + "rlib", +# "dylib" +] diff --git a/deps/rust-rocksdb/lib.rs b/deps/rust-rocksdb/lib.rs new file mode 100644 index 00000000..ae11159e --- /dev/null +++ b/deps/rust-rocksdb/lib.rs @@ -0,0 +1,59 @@ +pub use rust_rocksdb::*; + +#[link(name = "rocksdb", kind = "static")] +extern "C" { + pub fn rocksdb_list_column_families(); + pub fn rocksdb_logger_create_stderr_logger(); + pub fn rocksdb_options_set_info_log(); + pub fn rocksdb_get_options_from_string(); + pub fn rocksdb_writebatch_create(); + pub fn rocksdb_writebatch_put_cf(); + pub fn rocksdb_writebatch_delete_cf(); + pub fn rocksdb_iter_value(); + pub fn rocksdb_iter_seek_to_last(); + pub fn rocksdb_iter_seek_for_prev(); + pub fn rocksdb_iter_seek_to_first(); + pub fn rocksdb_iter_next(); + pub fn rocksdb_iter_prev(); + pub fn rocksdb_iter_seek(); + pub fn rocksdb_iter_valid(); + pub fn rocksdb_iter_get_error(); + pub fn rocksdb_iter_key(); + pub fn rocksdb_iter_destroy(); + pub fn rocksdb_livefiles(); + pub fn rocksdb_livefiles_count(); + pub fn rocksdb_livefiles_destroy(); + pub fn rocksdb_livefiles_column_family_name(); + pub fn rocksdb_livefiles_name(); + pub fn rocksdb_livefiles_size(); + pub fn rocksdb_livefiles_level(); + pub fn rocksdb_livefiles_smallestkey(); + pub fn rocksdb_livefiles_largestkey(); + pub fn rocksdb_livefiles_entries(); + pub fn rocksdb_livefiles_deletions(); + pub fn rocksdb_put_cf(); + pub fn rocksdb_delete_cf(); + pub fn rocksdb_get_pinned_cf(); + pub fn rocksdb_create_column_family(); + pub fn rocksdb_get_latest_sequence_number(); + pub fn rocksdb_batched_multi_get_cf(); + pub fn rocksdb_cancel_all_background_work(); + pub fn rocksdb_repair_db(); + pub fn rocksdb_list_column_families_destroy(); + pub fn rocksdb_flush(); + pub fn rocksdb_flush_wal(); + pub fn rocksdb_open_column_families(); + pub fn rocksdb_open_for_read_only_column_families(); + pub fn rocksdb_open_as_secondary_column_families(); + pub fn rocksdb_open_column_families_with_ttl(); + pub fn rocksdb_open(); + pub fn rocksdb_open_for_read_only(); + pub fn rocksdb_open_with_ttl(); + pub fn rocksdb_open_as_secondary(); + pub fn rocksdb_write(); + pub fn rocksdb_create_iterator_cf(); + pub fn rocksdb_backup_engine_create_new_backup_flush(); + pub fn rocksdb_backup_engine_options_create(); + pub fn rocksdb_write_buffer_manager_destroy(); + pub fn rocksdb_options_set_ttl(); +} diff --git a/nix/pkgs/main/default.nix b/nix/pkgs/main/default.nix index f546a328..ec2aebc2 100644 --- a/nix/pkgs/main/default.nix +++ b/nix/pkgs/main/default.nix @@ -55,7 +55,7 @@ commonAttrs = { include = [ "Cargo.lock" "Cargo.toml" - "hot_lib" + "deps" "src" ]; }; diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml new file mode 100644 index 00000000..49ec4267 --- /dev/null +++ b/src/admin/Cargo.toml @@ -0,0 +1,63 @@ +[package] +name = "conduit_admin" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "rocksdb", + "io_uring", + "jemalloc", + "zstd_compression", + "release_max_log_level", +] + +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +rocksdb = [ + "dep:rust-rocksdb", +] +jemalloc = [ + "rust-rocksdb/jemalloc", +] +io_uring = [ + "rust-rocksdb/io-uring", +] +zstd_compression = [ + "rust-rocksdb/zstd", +] + +[dependencies] +clap.workspace = true +conduit-api.workspace = true +conduit-core.workspace = true +conduit-database.workspace = true +conduit-service.workspace = true +futures-util.workspace = true +log.workspace = true +loole.workspace = true +regex.workspace = true +ruma.workspace = true +rust-rocksdb.optional = true +rust-rocksdb.workspace = true +serde_json.workspace = true +serde.workspace = true +serde_yaml.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true + +[lints] +workspace = true diff --git a/src/service/admin/appservice/appservice_command.rs b/src/admin/appservice/appservice_command.rs similarity index 97% rename from src/service/admin/appservice/appservice_command.rs rename to src/admin/appservice/appservice_command.rs index 4e99c78b..409ef83b 100644 --- a/src/service/admin/appservice/appservice_command.rs +++ b/src/admin/appservice/appservice_command.rs @@ -1,6 +1,6 @@ use ruma::{api::appservice::Registration, events::room::message::RoomMessageEventContent}; -use crate::{service::admin::escape_html, services, Result}; +use crate::{escape_html, services, Result}; pub(crate) async fn register(body: Vec<&str>) -> Result { if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { diff --git a/src/service/admin/appservice/mod.rs b/src/admin/appservice/mod.rs similarity index 98% rename from src/service/admin/appservice/mod.rs rename to src/admin/appservice/mod.rs index b0d225aa..8cf246b9 100644 --- a/src/service/admin/appservice/mod.rs +++ b/src/admin/appservice/mod.rs @@ -1,8 +1,8 @@ use clap::Subcommand; +use conduit::Result; use ruma::events::room::message::RoomMessageEventContent; use self::appservice_command::{list, register, show, unregister}; -use crate::Result; pub(crate) mod appservice_command; diff --git a/src/service/admin/debug/debug_commands.rs b/src/admin/debug/debug_commands.rs similarity index 98% rename from src/service/admin/debug/debug_commands.rs rename to src/admin/debug/debug_commands.rs index 3ce7fb14..838c4b22 100644 --- a/src/service/admin/debug/debug_commands.rs +++ b/src/admin/debug/debug_commands.rs @@ -1,18 +1,15 @@ use std::{collections::BTreeMap, sync::Arc, time::Instant}; +use conduit::{utils::HtmlEscape, Error, Result}; use ruma::{ api::client::error::ErrorKind, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, RoomId, RoomVersionId, ServerName, }; +use service::{rooms::event_handler::parse_incoming_pdu, sending::send::resolve_actual_dest, services, PduEvent}; use tokio::sync::RwLock; use tracing::{debug, info, warn}; use tracing_subscriber::EnvFilter; -use crate::{ - api::server_server::parse_incoming_pdu, service::sending::send::resolve_actual_dest, services, utils::HtmlEscape, - Error, PduEvent, Result, -}; - pub(crate) async fn get_auth_chain(_body: Vec<&str>, event_id: Box) -> Result { let event_id = Arc::::from(event_id); if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { @@ -332,7 +329,7 @@ pub(crate) async fn change_log_level( }; match services() - .globals + .server .tracing_reload_handle .reload(&old_filter_layer) { @@ -361,7 +358,7 @@ pub(crate) async fn change_log_level( }; match services() - .globals + .server .tracing_reload_handle .reload(&new_filter_layer) { @@ -447,15 +444,16 @@ pub(crate) async fn resolve_true_destination( )); } - let (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, no_cache, true).await?; + let (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, no_cache).await?; Ok(RoomMessageEventContent::text_plain(format!( "Actual destination: {actual_dest:?} | Hostname URI: {hostname_uri}" ))) } +#[must_use] pub(crate) fn memory_stats() -> RoomMessageEventContent { - let html_body = crate::alloc::memory_stats(); + let html_body = conduit::alloc::memory_stats(); if html_body.is_empty() { return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc."); diff --git a/src/service/admin/debug/mod.rs b/src/admin/debug/mod.rs similarity index 100% rename from src/service/admin/debug/mod.rs rename to src/admin/debug/mod.rs diff --git a/src/service/admin/federation/federation_commands.rs b/src/admin/federation/federation_commands.rs similarity index 97% rename from src/service/admin/federation/federation_commands.rs rename to src/admin/federation/federation_commands.rs index b69fe73c..f461c237 100644 --- a/src/service/admin/federation/federation_commands.rs +++ b/src/admin/federation/federation_commands.rs @@ -2,12 +2,7 @@ use std::fmt::Write as _; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; -use crate::{ - service::admin::{escape_html, get_room_info}, - services, - utils::HtmlEscape, - Result, -}; +use crate::{escape_html, get_room_info, services, utils::HtmlEscape, Result}; pub(crate) async fn disable_room(_body: Vec<&str>, room_id: Box) -> Result { services().rooms.metadata.disable_room(&room_id, true)?; diff --git a/src/service/admin/federation/mod.rs b/src/admin/federation/mod.rs similarity index 100% rename from src/service/admin/federation/mod.rs rename to src/admin/federation/mod.rs diff --git a/src/service/admin/fsck/fsck_commands.rs b/src/admin/fsck/fsck_commands.rs similarity index 100% rename from src/service/admin/fsck/fsck_commands.rs rename to src/admin/fsck/fsck_commands.rs diff --git a/src/service/admin/fsck/mod.rs b/src/admin/fsck/mod.rs similarity index 100% rename from src/service/admin/fsck/mod.rs rename to src/admin/fsck/mod.rs diff --git a/src/admin/handler.rs b/src/admin/handler.rs new file mode 100644 index 00000000..7b7396ef --- /dev/null +++ b/src/admin/handler.rs @@ -0,0 +1,305 @@ +use std::sync::Arc; + +use clap::Parser; +use regex::Regex; +use ruma::{ + events::{ + relation::InReplyTo, + room::message::{Relation::Reply, RoomMessageEventContent}, + TimelineEventType, + }, + OwnedRoomId, OwnedUserId, ServerName, UserId, +}; +use serde_json::value::to_raw_value; +use tokio::sync::MutexGuard; +use tracing::error; + +extern crate conduit_service as service; + +use conduit::{Error, Result}; +pub(crate) use service::admin::{AdminRoomEvent, Service}; +use service::{admin::HandlerResult, pdu::PduBuilder}; + +use self::{fsck::FsckCommand, tester::TesterCommands}; +use crate::{ + appservice, appservice::AppserviceCommand, debug, debug::DebugCommand, escape_html, federation, + federation::FederationCommand, fsck, media, media::MediaCommand, query, query::QueryCommand, room, + room::RoomCommand, server, server::ServerCommand, services, tester, user, user::UserCommand, +}; +pub(crate) const PAGE_SIZE: usize = 100; + +#[cfg_attr(test, derive(Debug))] +#[derive(Parser)] +#[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] +pub(crate) enum AdminCommand { + #[command(subcommand)] + /// - Commands for managing appservices + Appservices(AppserviceCommand), + + #[command(subcommand)] + /// - Commands for managing local users + Users(UserCommand), + + #[command(subcommand)] + /// - Commands for managing rooms + Rooms(RoomCommand), + + #[command(subcommand)] + /// - Commands for managing federation + Federation(FederationCommand), + + #[command(subcommand)] + /// - Commands for managing the server + Server(ServerCommand), + + #[command(subcommand)] + /// - Commands for managing media + Media(MediaCommand), + + #[command(subcommand)] + /// - Commands for debugging things + Debug(DebugCommand), + + #[command(subcommand)] + /// - Query all the database getters and iterators + Query(QueryCommand), + + #[command(subcommand)] + /// - Query all the database getters and iterators + Fsck(FsckCommand), + + #[command(subcommand)] + Tester(TesterCommands), +} + +#[must_use] +pub fn handle(event: AdminRoomEvent, room: OwnedRoomId, user: OwnedUserId) -> HandlerResult { + Box::pin(handle_event(event, room, user)) +} + +async fn handle_event(event: AdminRoomEvent, admin_room: OwnedRoomId, server_user: OwnedUserId) -> Result<()> { + let (mut message_content, reply) = match event { + AdminRoomEvent::SendMessage(content) => (content, None), + AdminRoomEvent::ProcessMessage(room_message, reply_id) => { + (process_admin_message(room_message).await, Some(reply_id)) + }, + }; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(admin_room.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + if let Some(reply) = reply { + message_content.relates_to = Some(Reply { + in_reply_to: InReplyTo { + event_id: reply.into(), + }, + }); + } + + let response_pdu = PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&message_content).expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }; + + if let Err(e) = services() + .rooms + .timeline + .build_and_append_pdu(response_pdu, &server_user, &admin_room, &state_lock) + .await + { + handle_response_error(&e, &admin_room, &server_user, &state_lock).await?; + } + + Ok(()) +} + +async fn handle_response_error( + e: &Error, admin_room: &OwnedRoomId, server_user: &UserId, state_lock: &MutexGuard<'_, ()>, +) -> Result<()> { + error!("Failed to build and append admin room response PDU: \"{e}\""); + let error_room_message = RoomMessageEventContent::text_plain(format!( + "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ + successfully, but we could not return the output." + )); + + let response_pdu = PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }; + + services() + .rooms + .timeline + .build_and_append_pdu(response_pdu, server_user, admin_room, state_lock) + .await?; + + Ok(()) +} + +// Parse and process a message from the admin room +async fn process_admin_message(room_message: String) -> RoomMessageEventContent { + let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); + let command_line = lines.next().expect("each string has at least one line"); + let body = lines.collect::>(); + + let admin_command = match parse_admin_command(command_line) { + Ok(command) => command, + Err(error) => { + let server_name = services().globals.server_name(); + let message = error.replace("server.name", server_name.as_str()); + let html_message = usage_to_html(&message, server_name); + + return RoomMessageEventContent::text_html(message, html_message); + }, + }; + + match process_admin_command(admin_command, body).await { + Ok(reply_message) => reply_message, + Err(error) => { + let markdown_message = format!("Encountered an error while handling the command:\n```\n{error}\n```",); + let html_message = format!("Encountered an error while handling the command:\n
\n{error}\n
",); + + RoomMessageEventContent::text_html(markdown_message, html_message) + }, + } +} + +// Parse chat messages from the admin room into an AdminCommand object +fn parse_admin_command(command_line: &str) -> Result { + // Note: argv[0] is `@conduit:servername:`, which is treated as the main command + let mut argv = command_line.split_whitespace().collect::>(); + + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help"); + } + + // Backwards compatibility with `register_appservice`-style commands + let command_with_dashes_argv1; + if argv.len() > 1 && argv[1].contains('_') { + command_with_dashes_argv1 = argv[1].replace('_', "-"); + argv[1] = &command_with_dashes_argv1; + } + + // Backwards compatibility with `register_appservice`-style commands + let command_with_dashes_argv2; + if argv.len() > 2 && argv[2].contains('_') { + command_with_dashes_argv2 = argv[2].replace('_', "-"); + argv[2] = &command_with_dashes_argv2; + } + + // if the user is using the `query` command (argv[1]), replace the database + // function/table calls with underscores to match the codebase + let command_with_dashes_argv3; + if argv.len() > 3 && argv[1].eq("query") { + command_with_dashes_argv3 = argv[3].replace('_', "-"); + argv[3] = &command_with_dashes_argv3; + } + + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) +} + +async fn process_admin_command(command: AdminCommand, body: Vec<&str>) -> Result { + let reply_message_content = match command { + AdminCommand::Appservices(command) => appservice::process(command, body).await?, + AdminCommand::Media(command) => media::process(command, body).await?, + AdminCommand::Users(command) => user::process(command, body).await?, + AdminCommand::Rooms(command) => room::process(command, body).await?, + AdminCommand::Federation(command) => federation::process(command, body).await?, + AdminCommand::Server(command) => server::process(command, body).await?, + AdminCommand::Debug(command) => debug::process(command, body).await?, + AdminCommand::Query(command) => query::process(command, body).await?, + AdminCommand::Fsck(command) => fsck::process(command, body).await?, + AdminCommand::Tester(command) => tester::process(command, body).await?, + }; + + Ok(reply_message_content) +} + +// Utility to turn clap's `--help` text to HTML. +fn usage_to_html(text: &str, server_name: &ServerName) -> String { + // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: + // subcmdname` + let text = text.replace(&format!("@conduit:{server_name}:-"), &format!("@conduit:{server_name}: ")); + + // For the conduit admin room, subcommands become main commands + let text = text.replace("SUBCOMMAND", "COMMAND"); + let text = text.replace("subcommand", "command"); + + // Escape option names (e.g. ``) since they look like HTML tags + let text = escape_html(&text); + + // Italicize the first line (command name and version text) + let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1\n"); + + // Unmerge wrapped lines + let text = text.replace("\n ", " "); + + // Wrap option names in backticks. The lines look like: + // -V, --version Prints version information + // And are converted to: + // -V, --version: Prints version information + // (?m) enables multi-line mode for ^ and $ + let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1: $4"); + + // Look for a `[commandbody]` tag. If it exists, use all lines below it that + // start with a `#` in the USAGE section. + let mut text_lines = text.lines().collect::>(); + let mut command_body = String::new(); + + if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { + text_lines.remove(line_index); + + while text_lines + .get(line_index) + .is_some_and(|line| line.starts_with('#')) + { + command_body += if text_lines[line_index].starts_with("# ") { + &text_lines[line_index][2..] + } else { + &text_lines[line_index][1..] + }; + command_body += "[nobr]\n"; + text_lines.remove(line_index); + } + } + + let text = text_lines.join("\n"); + + // Improve the usage section + let text = if command_body.is_empty() { + // Wrap the usage line in code tags + let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$").expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n$1").to_string() + } else { + // Wrap the usage line in a code block, and add a yaml block example + // This makes the usage of e.g. `register-appservice` more accurate + let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n").expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n
$1[nobr]\n[commandbodyblock]
") + .replace("[commandbodyblock]", &command_body) + }; + + // Add HTML line-breaks + + text.replace("\n\n\n", "\n\n") + .replace('\n', "
\n") + .replace("[nobr]
", "") +} diff --git a/src/service/admin/media/media_commands.rs b/src/admin/media/media_commands.rs similarity index 98% rename from src/service/admin/media/media_commands.rs rename to src/admin/media/media_commands.rs index 3f1fc8bf..8e87b736 100644 --- a/src/service/admin/media/media_commands.rs +++ b/src/admin/media/media_commands.rs @@ -1,7 +1,7 @@ -use ruma::{events::room::message::RoomMessageEventContent, EventId}; +use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri}; use tracing::{debug, info}; -use crate::{service::admin::MxcUri, services, Result}; +use crate::{services, Result}; pub(crate) async fn delete( _body: Vec<&str>, mxc: Option>, event_id: Option>, diff --git a/src/service/admin/media/mod.rs b/src/admin/media/mod.rs similarity index 96% rename from src/service/admin/media/mod.rs rename to src/admin/media/mod.rs index d091f94a..4e21b750 100644 --- a/src/service/admin/media/mod.rs +++ b/src/admin/media/mod.rs @@ -1,8 +1,8 @@ use clap::Subcommand; -use ruma::{events::room::message::RoomMessageEventContent, EventId}; +use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri}; use self::media_commands::{delete, delete_list, delete_past_remote_media}; -use crate::{service::admin::MxcUri, Result}; +use crate::Result; pub(crate) mod media_commands; diff --git a/src/admin/mod.rs b/src/admin/mod.rs new file mode 100644 index 00000000..1832cc9b --- /dev/null +++ b/src/admin/mod.rs @@ -0,0 +1,55 @@ +pub(crate) mod appservice; +pub(crate) mod debug; +pub(crate) mod federation; +pub(crate) mod fsck; +pub(crate) mod handler; +pub(crate) mod media; +pub(crate) mod query; +pub(crate) mod room; +pub(crate) mod server; +pub(crate) mod tester; +pub(crate) mod user; +pub(crate) mod utils; + +extern crate conduit_api as api; +extern crate conduit_core as conduit; +extern crate conduit_service as service; + +pub(crate) use conduit::{mod_ctor, mod_dtor, Result}; +pub use handler::handle; +pub(crate) use service::{services, user_is_local}; + +pub(crate) use crate::{ + handler::Service, + utils::{escape_html, get_room_info}, +}; + +mod_ctor! {} +mod_dtor! {} + +#[cfg(test)] +mod test { + use clap::Parser; + + use crate::handler::AdminCommand; + + #[test] + fn get_help_short() { get_help_inner("-h"); } + + #[test] + fn get_help_long() { get_help_inner("--help"); } + + #[test] + fn get_help_subcommand() { get_help_inner("help"); } + + fn get_help_inner(input: &str) { + let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) + .unwrap_err() + .to_string(); + + // Search for a handful of keywords that suggest the help printed properly + assert!(error.contains("Usage:")); + assert!(error.contains("Commands:")); + assert!(error.contains("Options:")); + } +} diff --git a/src/service/admin/query/account_data.rs b/src/admin/query/account_data.rs similarity index 100% rename from src/service/admin/query/account_data.rs rename to src/admin/query/account_data.rs diff --git a/src/service/admin/query/appservice.rs b/src/admin/query/appservice.rs similarity index 100% rename from src/service/admin/query/appservice.rs rename to src/admin/query/appservice.rs diff --git a/src/service/admin/query/globals.rs b/src/admin/query/globals.rs similarity index 100% rename from src/service/admin/query/globals.rs rename to src/admin/query/globals.rs diff --git a/src/service/admin/query/mod.rs b/src/admin/query/mod.rs similarity index 100% rename from src/service/admin/query/mod.rs rename to src/admin/query/mod.rs diff --git a/src/service/admin/query/presence.rs b/src/admin/query/presence.rs similarity index 100% rename from src/service/admin/query/presence.rs rename to src/admin/query/presence.rs diff --git a/src/service/admin/query/room_alias.rs b/src/admin/query/room_alias.rs similarity index 100% rename from src/service/admin/query/room_alias.rs rename to src/admin/query/room_alias.rs diff --git a/src/service/admin/query/sending.rs b/src/admin/query/sending.rs similarity index 100% rename from src/service/admin/query/sending.rs rename to src/admin/query/sending.rs diff --git a/src/service/admin/query/users.rs b/src/admin/query/users.rs similarity index 100% rename from src/service/admin/query/users.rs rename to src/admin/query/users.rs diff --git a/src/service/admin/room/mod.rs b/src/admin/room/mod.rs similarity index 100% rename from src/service/admin/room/mod.rs rename to src/admin/room/mod.rs diff --git a/src/service/admin/room/room_alias_commands.rs b/src/admin/room/room_alias_commands.rs similarity index 98% rename from src/service/admin/room/room_alias_commands.rs rename to src/admin/room/room_alias_commands.rs index 516df071..f2b5c7eb 100644 --- a/src/service/admin/room/room_alias_commands.rs +++ b/src/admin/room/room_alias_commands.rs @@ -3,7 +3,7 @@ use std::fmt::Write as _; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId}; use super::RoomAliasCommand; -use crate::{service::admin::escape_html, services, Result}; +use crate::{escape_html, services, Result}; pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Result { match command { diff --git a/src/service/admin/room/room_commands.rs b/src/admin/room/room_commands.rs similarity index 93% rename from src/service/admin/room/room_commands.rs rename to src/admin/room/room_commands.rs index 4e4e60e1..701cfb54 100644 --- a/src/service/admin/room/room_commands.rs +++ b/src/admin/room/room_commands.rs @@ -2,10 +2,7 @@ use std::fmt::Write as _; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; -use crate::{ - service::admin::{escape_html, get_room_info, PAGE_SIZE}, - services, Result, -}; +use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result}; pub(crate) async fn list(_body: Vec<&str>, page: Option) -> Result { // TODO: i know there's a way to do this with clap, but i can't seem to find it diff --git a/src/service/admin/room/room_directory_commands.rs b/src/admin/room/room_directory_commands.rs similarity index 96% rename from src/service/admin/room/room_directory_commands.rs rename to src/admin/room/room_directory_commands.rs index ccce2164..f6429dee 100644 --- a/src/service/admin/room/room_directory_commands.rs +++ b/src/admin/room/room_directory_commands.rs @@ -3,10 +3,7 @@ use std::fmt::Write as _; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use super::RoomDirectoryCommand; -use crate::{ - service::admin::{escape_html, get_room_info, PAGE_SIZE}, - services, Result, -}; +use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result}; pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result { match command { diff --git a/src/service/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs similarity index 83% rename from src/service/admin/room/room_moderation_commands.rs rename to src/admin/room/room_moderation_commands.rs index 5c62e360..03de4cde 100644 --- a/src/service/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -1,18 +1,16 @@ -use std::fmt::Write as _; +use std::fmt::Write; +use api::client_server::{get_alias_helper, leave_room}; use ruma::{ events::room::message::RoomMessageEventContent, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, }; use tracing::{debug, error, info, warn}; -use super::RoomModerationCommand; -use crate::{ - api::client_server::{get_alias_helper, leave_room}, - service::admin::{escape_html, Service}, - services, - utils::user_id::user_is_local, - Result, +use super::{ + super::{escape_html, Service}, + RoomModerationCommand, }; +use crate::{services, user_is_local, Result}; pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result { match command { @@ -105,16 +103,16 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .filter_map(|user| { user.ok().filter(|local_user| { user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (user_is_local(local_user) - && services() - .users - .is_admin(local_user) - .unwrap_or(true)) // since this is a force - // operation, assume user - // is an admin if somehow - // this fails + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (user_is_local(local_user) + && services() + .users + .is_admin(local_user) + .unwrap_or(true)) // since this is a force + // operation, assume user + // is an admin if somehow + // this fails }) }) .collect::>() @@ -134,14 +132,14 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .filter_map(|user| { user.ok().filter(|local_user| { local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && !services() - .users - .is_admin(local_user) - .unwrap_or(false)) + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && !services() + .users + .is_admin(local_user) + .unwrap_or(false)) }) }) .collect::>() @@ -309,19 +307,19 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .filter_map(|user| { user.ok().filter(|local_user| { local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && services() - .users - .is_admin(local_user) - .unwrap_or(true)) // since this is a - // force operation, - // assume user is - // an admin if - // somehow this - // fails + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && services() + .users + .is_admin(local_user) + .unwrap_or(true)) // since this is a + // force operation, + // assume user is + // an admin if + // somehow this + // fails }) }) .collect::>() @@ -341,14 +339,14 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .filter_map(|user| { user.ok().filter(|local_user| { local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && !services() - .users - .is_admin(local_user) - .unwrap_or(false)) + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && !services() + .users + .is_admin(local_user) + .unwrap_or(false)) }) }) .collect::>() diff --git a/src/service/admin/server/mod.rs b/src/admin/server/mod.rs similarity index 100% rename from src/service/admin/server/mod.rs rename to src/admin/server/mod.rs diff --git a/src/service/admin/server/server_commands.rs b/src/admin/server/server_commands.rs similarity index 90% rename from src/service/admin/server/server_commands.rs rename to src/admin/server/server_commands.rs index df82fd86..d80cc3d7 100644 --- a/src/service/admin/server/server_commands.rs +++ b/src/admin/server/server_commands.rs @@ -4,7 +4,7 @@ use crate::{services, Result}; pub(crate) async fn uptime(_body: Vec<&str>) -> Result { let seconds = services() - .globals + .server .started .elapsed() .expect("standard duration") @@ -28,7 +28,7 @@ pub(crate) async fn show_config(_body: Vec<&str>) -> Result) -> Result { let response0 = services().memory_usage().await; let response1 = services().globals.db.memory_usage(); - let response2 = crate::alloc::memory_usage(); + let response2 = conduit::alloc::memory_usage(); Ok(RoomMessageEventContent::text_plain(format!( "Services:\n{response0}\n\nDatabase:\n{response1}\n{}", @@ -69,12 +69,15 @@ pub(crate) async fn backup_database(_body: Vec<&str>) -> Result String::new(), - Err(e) => (*e).to_string(), - }) - .await - .unwrap(); + let mut result = services() + .server + .runtime() + .spawn_blocking(move || match services().globals.db.backup() { + Ok(()) => String::new(), + Err(e) => (*e).to_string(), + }) + .await + .unwrap(); if result.is_empty() { result = services().globals.db.backup_list()?; diff --git a/src/service/admin/tester/mod.rs b/src/admin/tester/mod.rs similarity index 92% rename from src/service/admin/tester/mod.rs rename to src/admin/tester/mod.rs index c0f3df15..f7b4ecea 100644 --- a/src/service/admin/tester/mod.rs +++ b/src/admin/tester/mod.rs @@ -9,6 +9,6 @@ pub(crate) enum TesterCommands { } pub(crate) async fn process(command: TesterCommands, _body: Vec<&str>) -> Result { Ok(match command { - TesterCommands::Tester => RoomMessageEventContent::notice_plain(String::from("complete")), + TesterCommands::Tester => RoomMessageEventContent::notice_plain(String::from("completed")), }) } diff --git a/src/service/admin/user/mod.rs b/src/admin/user/mod.rs similarity index 100% rename from src/service/admin/user/mod.rs rename to src/admin/user/mod.rs diff --git a/src/service/admin/user/user_commands.rs b/src/admin/user/user_commands.rs similarity index 97% rename from src/service/admin/user/user_commands.rs rename to src/admin/user/user_commands.rs index 1aa8d4ea..2ec23f0d 100644 --- a/src/service/admin/user/user_commands.rs +++ b/src/admin/user/user_commands.rs @@ -1,15 +1,13 @@ use std::{fmt::Write as _, sync::Arc}; +use api::client_server::{join_room_by_id_helper, leave_all_rooms}; +use conduit::utils; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, UserId}; use tracing::{error, info, warn}; -use crate::{ - api::client_server::{join_room_by_id_helper, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, - service::admin::{escape_html, get_room_info}, - services, - utils::{self, user_id::user_is_local}, - Result, -}; +use crate::{escape_html, get_room_info, services, user_is_local, Result}; + +const AUTO_GEN_PASSWORD_LENGTH: usize = 25; pub(crate) async fn list(_body: Vec<&str>) -> Result { match services().users.list_local_users() { @@ -111,7 +109,7 @@ pub(crate) async fn create( ) .await { - Ok(_) => { + Ok(_response) => { info!("Automatically joined room {room} for user {user_id}"); }, Err(e) => { diff --git a/src/admin/utils.rs b/src/admin/utils.rs new file mode 100644 index 00000000..7031b848 --- /dev/null +++ b/src/admin/utils.rs @@ -0,0 +1,30 @@ +pub(crate) use conduit::utils::HtmlEscape; +use ruma::OwnedRoomId; + +use crate::services; + +pub(crate) fn escape_html(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") +} + +pub(crate) fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) { + ( + id.clone(), + services() + .rooms + .state_cache + .room_joined_count(id) + .ok() + .flatten() + .unwrap_or(0), + services() + .rooms + .state_accessor + .get_name(id) + .ok() + .flatten() + .unwrap_or_else(|| id.to_string()), + ) +} diff --git a/src/alloc/default.rs b/src/alloc/default.rs deleted file mode 100644 index 1d61682e..00000000 --- a/src/alloc/default.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Default allocator with no special features - -/// Always returns the empty string -pub(crate) fn memory_stats() -> String { Default::default() } - -/// Always returns the empty string -pub(crate) fn memory_usage() -> String { Default::default() } diff --git a/src/alloc/hardened.rs b/src/alloc/hardened.rs deleted file mode 100644 index 9ac84f9a..00000000 --- a/src/alloc/hardened.rs +++ /dev/null @@ -1,8 +0,0 @@ -#[global_allocator] -static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc; - -pub(crate) fn memory_usage() -> String { - String::default() //TODO: get usage -} - -pub(crate) fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() } diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml new file mode 100644 index 00000000..890d66af --- /dev/null +++ b/src/api/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "conduit_api" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "element_hacks", + "gzip_compression", + "brotli_compression", + "release_max_log_level", +] + +element_hacks = [] +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +gzip_compression = [ + "reqwest/gzip", +] +brotli_compression = [ + "reqwest/brotli", +] + +[dependencies] +argon2.workspace = true +axum-extra.workspace = true +axum.workspace = true +base64.workspace = true +bytes.workspace = true +conduit-core.workspace = true +conduit-database.workspace = true +conduit-service.workspace = true +futures-util.workspace = true +hmac.workspace = true +http.workspace = true +hyper.workspace = true +image.workspace = true +ipaddress.workspace = true +jsonwebtoken.workspace = true +log.workspace = true +rand.workspace = true +reqwest.workspace = true +ruma.workspace = true +serde_html_form.workspace = true +serde_json.workspace = true +serde.workspace = true +sha-1.workspace = true +thiserror.workspace = true +tokio.workspace = true +tracing.workspace = true +webpage.workspace = true + +[lints] +workspace = true diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 41343183..14c8fead 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -19,9 +19,10 @@ use tracing::{error, info, warn}; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ - api::client_server::{self, join_room_by_id_helper}, - service, services, - utils::{self, user_id::user_is_local}, + client_server::{self, join_room_by_id_helper}, + service::user_is_local, + services, + utils::{self}, Error, Result, Ruma, }; diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 821fa0c2..8c3c0e4c 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -13,8 +13,9 @@ use ruma::{ use tracing::debug; use crate::{ - debug_info, debug_warn, service::appservice::RegistrationInfo, services, utils::server_name::server_is_ours, Error, - Result, Ruma, + debug_info, debug_warn, + service::{appservice::RegistrationInfo, server_is_ours}, + services, Error, Result, Ruma, }; /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` @@ -65,7 +66,6 @@ pub(crate) async fn create_alias_route(body: Ruma) -> /// - TODO: Update canonical alias event pub(crate) async fn delete_alias_route(body: Ruma) -> Result { alias_checks(&body.room_alias, &body.appservice_info).await?; - if services() .rooms .alias @@ -99,7 +99,7 @@ pub(crate) async fn get_alias_route(body: Ruma) -> Resul get_alias_helper(body.body.room_alias, None).await } -pub(crate) async fn get_alias_helper( +pub async fn get_alias_helper( room_alias: OwnedRoomAliasId, servers: Option>, ) -> Result { debug!("get_alias_helper servers: {servers:?}"); diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 094e007a..7cfb3392 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -24,7 +24,7 @@ use ruma::{ }; use tracing::{error, info, warn}; -use crate::{services, utils::server_name::server_is_ours, Error, Result, Ruma}; +use crate::{service::server_is_ours, services, Error, Result, Ruma}; /// # `POST /_matrix/client/v3/publicRooms` /// diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 040916b0..b021bea1 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -22,8 +22,9 @@ use tracing::debug; use super::SESSION_ID_LENGTH; use crate::{ + service::user_is_local, services, - utils::{self, user_id::user_is_local}, + utils::{self}, Error, Result, Ruma, }; diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index b294fec2..0e70c1dc 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -15,14 +15,16 @@ use webpage::HTML; use crate::{ debug_warn, - service::media::{FileMeta, UrlPreviewData}, + service::{ + media::{FileMeta, UrlPreviewData}, + server_is_ours, + }, services, utils::{ self, content_disposition::{ content_disposition_type, make_content_disposition, make_content_type, sanitise_filename, }, - server_name::server_is_ours, }, Error, Result, Ruma, RumaResponse, }; diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 4fad8cf2..1cee7720 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -35,9 +35,12 @@ use tracing::{debug, error, info, trace, warn}; use super::get_alias_helper; use crate::{ - service::pdu::{gen_event_id_canonical_json, PduBuilder}, + service::{ + pdu::{gen_event_id_canonical_json, PduBuilder}, + server_is_ours, user_is_local, + }, services, - utils::{self, server_name::server_is_ours, user_id::user_is_local}, + utils::{self}, Error, PduEvent, Result, Ruma, }; @@ -607,7 +610,7 @@ pub(crate) async fn joined_members_route( }) } -pub(crate) async fn join_room_by_id_helper( +pub async fn join_room_by_id_helper( sender_user: Option<&UserId>, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, ) -> Result { @@ -1525,7 +1528,7 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors -pub(crate) async fn leave_all_rooms(user_id: &UserId) { +pub async fn leave_all_rooms(user_id: &UserId) { let all_rooms = services() .rooms .state_cache @@ -1550,7 +1553,7 @@ pub(crate) async fn leave_all_rooms(user_id: &UserId) { } } -pub(crate) async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { +pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { // Ask a remote server if we don't have this room if !services() .rooms diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 0aa7792d..5bb683f2 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -3,6 +3,7 @@ use std::{ sync::Arc, }; +use conduit::PduCount; use ruma::{ api::client::{ error::ErrorKind, @@ -14,10 +15,7 @@ use ruma::{ }; use serde_json::{from_str, Value}; -use crate::{ - service::{pdu::PduBuilder, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, Ruma, -}; +use crate::{service::pdu::PduBuilder, services, utils, Error, PduEvent, Result, Ruma}; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index 59d851bf..171e9bbe 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -1,40 +1,41 @@ -mod account; -mod alias; -mod backup; -mod capabilities; -mod config; -mod context; -mod device; -mod directory; -mod filter; -mod keys; -mod media; -mod membership; -mod message; -mod presence; -mod profile; -mod push; -mod read_marker; -mod redact; -mod relations; -mod report; -mod room; -mod search; -mod session; -mod space; -mod state; -mod sync; -mod tag; -mod thirdparty; -mod threads; -mod to_device; -mod typing; -mod unstable; -mod unversioned; -mod user_directory; -mod voip; +pub(crate) mod account; +pub(crate) mod alias; +pub(crate) mod backup; +pub(crate) mod capabilities; +pub(crate) mod config; +pub(crate) mod context; +pub(crate) mod device; +pub(crate) mod directory; +pub(crate) mod filter; +pub(crate) mod keys; +pub(crate) mod media; +pub(crate) mod membership; +pub(crate) mod message; +pub(crate) mod presence; +pub(crate) mod profile; +pub(crate) mod push; +pub(crate) mod read_marker; +pub(crate) mod redact; +pub(crate) mod relations; +pub(crate) mod report; +pub(crate) mod room; +pub(crate) mod search; +pub(crate) mod session; +pub(crate) mod space; +pub(crate) mod state; +pub(crate) mod sync; +pub(crate) mod tag; +pub(crate) mod thirdparty; +pub(crate) mod threads; +pub(crate) mod to_device; +pub(crate) mod typing; +pub(crate) mod unstable; +pub(crate) mod unversioned; +pub(crate) mod user_directory; +pub(crate) mod voip; pub(crate) use account::*; +pub use alias::get_alias_helper; pub(crate) use alias::*; pub(crate) use backup::*; pub(crate) use capabilities::*; @@ -46,6 +47,7 @@ pub(crate) use filter::*; pub(crate) use keys::*; pub(crate) use media::*; pub(crate) use membership::*; +pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room}; pub(crate) use message::*; pub(crate) use presence::*; pub(crate) use profile::*; @@ -77,7 +79,4 @@ const DEVICE_ID_LENGTH: usize = 10; const TOKEN_LENGTH: usize = 32; /// generated user session ID length -pub(crate) const SESSION_ID_LENGTH: usize = 32; - -/// auto-generated password length -pub(crate) const AUTO_GEN_PASSWORD_LENGTH: usize = 25; +const SESSION_ID_LENGTH: usize = service::uiaa::SESSION_ID_LENGTH; diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index a8cf9af2..b6e4598d 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -13,7 +13,10 @@ use ruma::{ }; use serde_json::value::to_raw_value; -use crate::{service::pdu::PduBuilder, services, utils::user_id::user_is_local, Error, Result, Ruma}; +use crate::{ + service::{pdu::PduBuilder, user_is_local}, + services, Error, Result, Ruma, +}; /// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index f3d6c362..0f43eeef 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use conduit::PduCount; use ruma::{ api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, events::{ @@ -9,7 +10,7 @@ use ruma::{ MilliSecondsSinceUnixEpoch, }; -use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 46ad4454..db0d4e4a 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,5 +1,6 @@ use std::{cmp::max, collections::BTreeMap, sync::Arc}; +use conduit::{debug_info, debug_warn}; use ruma::{ api::client::{ error::ErrorKind, @@ -28,8 +29,7 @@ use serde_json::{json, value::to_raw_value}; use tracing::{error, info, warn}; use crate::{ - api::client_server::invite_helper, - debug_info, debug_warn, + client_server::invite_helper, service::{appservice::RegistrationInfo, pdu::PduBuilder}, services, Error, Result, Ruma, }; diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 03152e3e..7445cc40 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -19,10 +19,8 @@ use ruma::{ use tracing::{error, log::warn}; use crate::{ - service::{self, pdu::PduBuilder}, - services, - utils::server_name::server_is_ours, - Error, Result, Ruma, RumaResponse, + service::{pdu::PduBuilder, server_is_ours}, + services, Error, Result, Ruma, RumaResponse, }; /// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}` diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 8e7282fe..4967ca86 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -5,6 +5,7 @@ use std::{ time::Duration, }; +use conduit::PduCount; use ruma::{ api::client::{ filter::{FilterDefinition, LazyLoadOptions}, @@ -29,10 +30,7 @@ use ruma::{ }; use tracing::{error, Instrument as _, Span}; -use crate::{ - service::{pdu::EventHash, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, Ruma, RumaResponse, -}; +use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/r0/sync` /// diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index e85b991f..011e08f7 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -8,7 +8,7 @@ use ruma::{ to_device::DeviceIdOrAllDevices, }; -use crate::{services, utils::user_id::user_is_local, Error, Result, Ruma}; +use crate::{services, user_is_local, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// diff --git a/src/api/mod.rs b/src/api/mod.rs index 285b9f51..7fe02cfe 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,15 @@ -pub(crate) mod client_server; +pub mod client_server; +pub mod router; pub(crate) mod ruma_wrapper; -pub(crate) mod server_server; +pub mod server_server; + +extern crate conduit_core as conduit; +extern crate conduit_service as service; + +pub use client_server::membership::{join_room_by_id_helper, leave_all_rooms}; +pub(crate) use conduit::{debug_error, debug_info, debug_warn, error::RumaResponse, utils, Error, Result}; +pub(crate) use ruma_wrapper::Ruma; +pub(crate) use service::{pdu::PduEvent, services, user_is_local}; + +conduit::mod_ctor! {} +conduit::mod_dtor! {} diff --git a/src/router/routes.rs b/src/api/router.rs similarity index 96% rename from src/router/routes.rs rename to src/api/router.rs index d8b73b4f..068dc375 100644 --- a/src/router/routes.rs +++ b/src/api/router.rs @@ -6,16 +6,15 @@ use axum::{ routing::{any, get, on, post, MethodFilter}, Router, }; +use conduit::{Error, Result, Server}; use http::{Method, Uri}; use ruma::api::{client::error::ErrorKind, IncomingRequest}; -use crate::{ - api::{client_server, server_server}, - Config, Error, Result, Ruma, RumaResponse, -}; +use crate::{client_server, server_server, Ruma, RumaResponse}; -pub(crate) fn routes(config: &Config) -> Router { - let router = Router::new() +pub fn build(router: Router, server: &Server) -> Router { + let config = &server.config; + let router = router .ruma_route(client_server::get_supported_versions_route) .ruma_route(client_server::get_register_available_route) .ruma_route(client_server::register_route) @@ -187,9 +186,7 @@ pub(crate) fn routes(config: &Config) -> Router { .route("/_conduwuit/server_version", get(client_server::conduwuit_server_version)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) - .route("/client/server.json", get(client_server::syncv3_client_server_json)) - .route("/", get(it_works)) - .fallback(not_found); + .route("/client/server.json", get(client_server::syncv3_client_server_json)); if config.allow_federation { router @@ -230,16 +227,10 @@ pub(crate) fn routes(config: &Config) -> Router { } } -async fn not_found(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") -} - async fn initial_sync(_uri: Uri) -> impl IntoResponse { Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") } -async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } - async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") } trait RouterExt { @@ -259,7 +250,7 @@ impl RouterExt for Router { } } -pub(crate) trait RumaHandler { +trait RumaHandler { // Can't transform to a handler without boxing or relying on the nightly-only // impl-trait-in-traits feature. Moving a small amount of extra logic into the // trait allows bypassing both. diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 5443cc5f..93252b7c 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -3,30 +3,26 @@ use std::{collections::BTreeMap, str}; use axum::{ async_trait, extract::{FromRequest, Path}, - response::{IntoResponse, Response}, RequestExt, RequestPartsExt, }; use axum_extra::{ - headers::{ - authorization::{Bearer, Credentials}, - Authorization, - }, + headers::{authorization::Bearer, Authorization}, typed_header::TypedHeaderRejectionReason, TypedHeader, }; use bytes::{BufMut, BytesMut}; -use http::{uri::PathAndQuery, StatusCode}; -use http_body_util::Full; +use conduit::debug_warn; +use http::uri::PathAndQuery; use hyper::Request; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, - CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, + api::{client::error::ErrorKind, AuthScheme, IncomingRequest}, + CanonicalJsonValue, OwnedDeviceId, OwnedUserId, UserId, }; use serde::Deserialize; use tracing::{debug, error, trace, warn}; -use super::{Ruma, RumaResponse}; -use crate::{debug_warn, service::appservice::RegistrationInfo, services, Error, Result}; +use super::{xmatrix::XMatrix, Ruma}; +use crate::{service::appservice::RegistrationInfo, services, Error, Result}; enum Token { Appservice(Box), @@ -332,68 +328,3 @@ where }) } } - -struct XMatrix { - origin: OwnedServerName, - destination: Option, - key: String, // KeyName? - sig: String, -} - -impl Credentials for XMatrix { - const SCHEME: &'static str = "X-Matrix"; - - fn decode(value: &http::HeaderValue) -> Option { - debug_assert!( - value.as_bytes().starts_with(b"X-Matrix "), - "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", - ); - - let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) - .ok()? - .trim_start(); - - let mut origin = None; - let mut destination = None; - let mut key = None; - let mut sig = None; - - for entry in parameters.split_terminator(',') { - let (name, value) = entry.split_once('=')?; - - // It's not at all clear why some fields are quoted and others not in the spec, - // let's simply accept either form for every field. - let value = value - .strip_prefix('"') - .and_then(|rest| rest.strip_suffix('"')) - .unwrap_or(value); - - // FIXME: Catch multiple fields of the same name - match name { - "origin" => origin = Some(value.try_into().ok()?), - "key" => key = Some(value.to_owned()), - "sig" => sig = Some(value.to_owned()), - "destination" => destination = Some(value.to_owned()), - _ => debug!("Unexpected field `{name}` in X-Matrix Authorization header"), - } - } - - Some(Self { - origin: origin?, - key: key?, - sig: sig?, - destination, - }) - } - - fn encode(&self) -> http::HeaderValue { todo!() } -} - -impl IntoResponse for RumaResponse { - fn into_response(self) -> Response { - match self.0.try_into_http_response::() { - Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), - } - } -} diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index 93474816..e32efb5f 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,10 +1,11 @@ +pub(crate) mod axum; +mod xmatrix; + use std::ops::Deref; -use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId}; +use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId}; -use crate::{service::appservice::RegistrationInfo, Error}; - -mod axum; +use crate::service::appservice::RegistrationInfo; /// Extractor for Ruma request structs pub(crate) struct Ruma { @@ -21,14 +22,3 @@ impl Deref for Ruma { fn deref(&self) -> &Self::Target { &self.body } } - -#[derive(Clone)] -pub(crate) struct RumaResponse(pub(crate) T); - -impl From for RumaResponse { - fn from(t: T) -> Self { Self(t) } -} - -impl From for RumaResponse { - fn from(t: Error) -> Self { t.to_response() } -} diff --git a/src/api/ruma_wrapper/xmatrix.rs b/src/api/ruma_wrapper/xmatrix.rs new file mode 100644 index 00000000..74fb7d20 --- /dev/null +++ b/src/api/ruma_wrapper/xmatrix.rs @@ -0,0 +1,61 @@ +use std::str; + +use axum_extra::headers::authorization::Credentials; +use ruma::OwnedServerName; +use tracing::debug; + +pub(crate) struct XMatrix { + pub(crate) origin: OwnedServerName, + pub(crate) destination: Option, + pub(crate) key: String, // KeyName? + pub(crate) sig: String, +} + +impl Credentials for XMatrix { + const SCHEME: &'static str = "X-Matrix"; + + fn decode(value: &http::HeaderValue) -> Option { + debug_assert!( + value.as_bytes().starts_with(b"X-Matrix "), + "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", + ); + + let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) + .ok()? + .trim_start(); + + let mut origin = None; + let mut destination = None; + let mut key = None; + let mut sig = None; + + for entry in parameters.split_terminator(',') { + let (name, value) = entry.split_once('=')?; + + // It's not at all clear why some fields are quoted and others not in the spec, + // let's simply accept either form for every field. + let value = value + .strip_prefix('"') + .and_then(|rest| rest.strip_suffix('"')) + .unwrap_or(value); + + // FIXME: Catch multiple fields of the same name + match name { + "origin" => origin = Some(value.try_into().ok()?), + "key" => key = Some(value.to_owned()), + "sig" => sig = Some(value.to_owned()), + "destination" => destination = Some(value.to_owned()), + _ => debug!("Unexpected field `{name}` in X-Matrix Authorization header"), + } + } + + Some(Self { + origin: origin?, + key: key?, + sig: sig?, + destination, + }) + } + + fn encode(&self) -> http::HeaderValue { todo!() } +} diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 8b3650e0..f9592e3b 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -44,19 +44,23 @@ use ruma::{ }, serde::{Base64, JsonObject, Raw}, to_device::DeviceIdOrAllDevices, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, - OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, RoomVersionId, ServerName, + uint, user_id, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedServerName, + OwnedServerSigningKeyId, OwnedUserId, RoomId, RoomVersionId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; use tracing::{debug, error, trace, warn}; use crate::{ - api::client_server::{self, claim_keys_helper, get_keys_helper}, + client_server::{self, claim_keys_helper, get_keys_helper}, debug_error, - service::pdu::{gen_event_id_canonical_json, PduBuilder}, + service::{ + pdu::{gen_event_id_canonical_json, PduBuilder}, + rooms::event_handler::parse_incoming_pdu, + server_is_ours, user_is_local, + }, services, - utils::{self, server_name::server_is_ours, user_id::user_is_local}, + utils::{self}, Error, PduEvent, Result, Ruma, }; @@ -196,32 +200,6 @@ pub(crate) async fn get_public_rooms_route( }) } -pub(crate) fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; - - let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { - return Err(Error::Err(format!("Server is not in room {room_id}"))); - }; - - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - }; - - Ok((event_id, value, room_id)) -} - /// # `PUT /_matrix/federation/v1/send/{txnId}` /// /// Push EDUs and PDUs to this server. diff --git a/src/bin/Cargo.toml b/src/bin/Cargo.toml new file mode 100644 index 00000000..e50dba24 --- /dev/null +++ b/src/bin/Cargo.toml @@ -0,0 +1,123 @@ +[package] +# TODO: when can we rename to conduwuit? +name = "conduit" +default-run = "conduit" +description.workspace = true +license.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +readme.workspace = true +version.workspace = true +edition.workspace = true +rust-version.workspace = true + +[package.metadata.deb] +name = "conduwuit" +maintainer = "strawberry " +copyright = "2024, strawberry " +license-file = ["LICENSE", "3"] +depends = "$auto, ca-certificates" +extended-description = """\ +a cool hard fork of Conduit, a Matrix homeserver written in Rust""" +section = "net" +priority = "optional" +conf-files = ["/etc/conduwuit/conduwuit.toml"] +maintainer-scripts = "debian/" +systemd-units = { unit-name = "conduwuit", start = false } +assets = [ + ["debian/README.md", "usr/share/doc/conduwuit/README.Debian", "644"], + ["README.md", "usr/share/doc/conduwuit/", "644"], + ["target/release/conduwuit", "usr/sbin/conduwuit", "755"], + ["conduwuit-example.toml", "etc/conduwuit/conduwuit.toml", "640"], +] + +[features] +default = [ + "sentry_telemetry", + "release_max_log_level", +] + +# increases performance, reduces build times, and reduces binary size by not compiling or +# genreating code for log level filters that users will generally not use (debug and trace) +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +sentry_telemetry = [ + "dep:sentry", + "dep:sentry-tracing", + "dep:sentry-tower", +] +# enable the tokio_console server ncompatible with release_max_log_level +tokio_console = [ + "dep:console-subscriber", + "tokio/tracing", +] +perf_measurements = [ + "dep:opentelemetry", + "dep:tracing-flame", + "dep:tracing-opentelemetry", + "dep:opentelemetry_sdk", + "dep:opentelemetry-jaeger", +] +jemalloc = [ + "dep:tikv-jemallocator", +] +panic_trap = [] +mods = [] + +[dependencies] +conduit-router.workspace = true +conduit-admin.workspace = true +conduit-api.workspace = true +conduit-service.workspace = true +conduit-database.workspace = true +conduit-core.workspace = true + +tokio.workspace = true +log.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +clap.workspace = true +num_cpus.workspace = true + +opentelemetry.workspace = true +opentelemetry.optional = true +tracing-flame.workspace = true +tracing-flame.optional = true +tracing-opentelemetry.workspace = true +tracing-opentelemetry.optional = true +opentelemetry_sdk.workspace = true +opentelemetry_sdk.optional = true +opentelemetry-jaeger.workspace = true +opentelemetry-jaeger.optional = true + +sentry.workspace = true +sentry.optional = true +sentry-tracing.workspace = true +sentry-tracing.optional = true +sentry-tower.workspace = true +sentry-tower.optional = true + +tikv-jemallocator.workspace = true +tikv-jemallocator.optional = true + +tokio-metrics.workspace = true +tokio-metrics.optional = true + +console-subscriber.workspace = true +console-subscriber.optional = true + +[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies] +hardened_malloc-rs.workspace = true +hardened_malloc-rs.optional = true + +[lints] +workspace = true + +[[bin]] +name = "conduit" +path = "main.rs" diff --git a/src/bin/main.rs b/src/bin/main.rs new file mode 100644 index 00000000..0d049fcb --- /dev/null +++ b/src/bin/main.rs @@ -0,0 +1,96 @@ +mod mods; +mod server; + +extern crate conduit_core as conduit; + +use std::{cmp, sync::Arc, time::Duration}; + +use conduit::{debug_info, error, utils::clap, Error, Result}; +use server::Server; +use tokio::runtime; + +const WORKER_NAME: &str = "conduwuit:worker"; +const WORKER_MIN: usize = 2; +const WORKER_KEEPALIVE_MS: u64 = 2500; + +fn main() -> Result<(), Error> { + let args = clap::parse(); + let runtime = runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .thread_name(WORKER_NAME) + .worker_threads(cmp::max(WORKER_MIN, num_cpus::get())) + .thread_keep_alive(Duration::from_millis(WORKER_KEEPALIVE_MS)) + .build() + .expect("built runtime"); + + let handle = runtime.handle(); + let server: Arc = Server::build(args, Some(handle))?; + runtime.block_on(async { async_main(server.clone()).await })?; + + // explicit drop here to trace thread and tls dtors + drop(runtime); + + debug_info!("Exit"); + Ok(()) +} + +/// Operate the server normally in release-mode static builds. This will start, +/// run and stop the server within the asynchronous runtime. +#[cfg(not(feature = "mods"))] +async fn async_main(server: Arc) -> Result<(), Error> { + extern crate conduit_router as router; + use tracing::error; + + if let Err(error) = router::start(&server.server).await { + error!("Critical error starting server: {error}"); + return Err(error); + } + + if let Err(error) = router::run(&server.server).await { + error!("Critical error running server: {error}"); + return Err(error); + } + + if let Err(error) = router::stop(&server.server).await { + error!("Critical error stopping server: {error}"); + return Err(error); + } + + debug_info!("Exit runtime"); + Ok(()) +} + +/// Operate the server in developer-mode dynamic builds. This will start, run, +/// and hot-reload portions of the server as-needed before returning for an +/// actual shutdown. This is not available in release-mode or static builds. +#[cfg(feature = "mods")] +async fn async_main(server: Arc) -> Result<(), Error> { + let mut starts = true; + let mut reloads = true; + while reloads { + if let Err(error) = mods::open(&server).await { + error!("Loading router: {error}"); + return Err(error); + } + + let result = mods::run(&server, starts).await; + if let Ok(result) = result { + (starts, reloads) = result; + } + + let force = !reloads || result.is_err(); + if let Err(error) = mods::close(&server, force).await { + error!("Unloading router: {error}"); + return Err(error); + } + + if let Err(error) = result { + error!("{error}"); + return Err(error); + } + } + + debug_info!("Exit runtime"); + Ok(()) +} diff --git a/src/bin/mods.rs b/src/bin/mods.rs new file mode 100644 index 00000000..404fa467 --- /dev/null +++ b/src/bin/mods.rs @@ -0,0 +1,129 @@ +#![cfg(feature = "mods")] +#[cfg(not(any(clippy, debug_assertions, doctest, test)))] +compile_error!("Feature 'mods' is only available in developer builds"); + +use std::{ + future::Future, + pin::Pin, + sync::{atomic::Ordering, Arc}, +}; + +use conduit::{mods, Error, Result}; +use tracing::{debug, error}; + +use crate::Server; + +type RunFuncResult = Pin>>>; +type RunFuncProto = fn(&Arc) -> RunFuncResult; + +const RESTART_THRESH: &str = "conduit_service"; +const MODULE_NAMES: &[&str] = &[ + //"conduit_core", + "conduit_database", + "conduit_service", + "conduit_api", + "conduit_admin", + "conduit_router", +]; + +#[cfg(feature = "panic_trap")] +conduit::mod_init! {{ + conduit::debug::set_panic_trap(); +}} + +pub(crate) async fn run(server: &Arc, starts: bool) -> Result<(bool, bool), Error> { + let main_lock = server.mods.read().await; + let main_mod = (*main_lock).last().expect("main module loaded"); + if starts { + let start = main_mod.get::("start")?; + if let Err(error) = start(&server.server).await { + error!("Starting server: {error}"); + return Err(error); + } + } + let run = main_mod.get::("run")?; + if let Err(error) = run(&server.server).await { + error!("Running server: {error}"); + return Err(error); + } + let reloads = server.server.reload.swap(false, Ordering::AcqRel); + let stops = !reloads || stale(server).await? <= restart_thresh(); + let starts = reloads && stops; + if stops { + let stop = main_mod.get::("stop")?; + if let Err(error) = stop(&server.server).await { + error!("Stopping server: {error}"); + return Err(error); + } + } + + Ok((starts, reloads)) +} + +pub(crate) async fn open(server: &Arc) -> Result { + let mut mods_lock = server.mods.write().await; + let mods: &mut Vec = &mut mods_lock; + debug!( + available = %available(), + loaded = %mods.len(), + "Loading modules", + ); + + for (i, name) in MODULE_NAMES.iter().enumerate() { + if mods.get(i).is_none() { + mods.push(mods::Module::from_name(name)?); + } + } + + Ok(mods.len()) +} + +pub(crate) async fn close(server: &Arc, force: bool) -> Result { + let stale = stale_count(server).await; + let mut mods_lock = server.mods.write().await; + let mods: &mut Vec = &mut mods_lock; + debug!( + available = %available(), + loaded = %mods.len(), + stale = %stale, + force, + "Unloading modules", + ); + + while mods.last().is_some() { + let module = &mods.last().expect("module"); + if force || module.deleted()? { + mods.pop(); + } else { + break; + } + } + + Ok(mods.len()) +} + +async fn stale_count(server: &Arc) -> usize { + let watermark = stale(server).await.unwrap_or(available()); + available() - watermark +} + +async fn stale(server: &Arc) -> Result { + let mods_lock = server.mods.read().await; + let mods: &Vec = &mods_lock; + for (i, module) in mods.iter().enumerate() { + if module.deleted()? { + return Ok(i); + } + } + + Ok(mods.len()) +} + +fn restart_thresh() -> usize { + MODULE_NAMES + .iter() + .position(|&name| name.ends_with(RESTART_THRESH)) + .unwrap_or(MODULE_NAMES.len()) +} + +const fn available() -> usize { MODULE_NAMES.len() } diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 00000000..e63c0dc0 --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,186 @@ +use std::sync::Arc; + +use conduit::{ + conduwuit_version, + config::Config, + info, + log::{LogLevelReloadHandles, ReloadHandle}, + utils::{clap, maximize_fd_limit}, + Error, Result, +}; +use tokio::runtime; +use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry}; + +/// Server runtime state; complete +pub(crate) struct Server { + /// Server runtime state; public portion + pub(crate) server: Arc, + + _tracing_flame_guard: TracingFlameGuard, + + #[cfg(feature = "sentry_telemetry")] + _sentry_guard: Option, + + // Module instances; TODO: move to mods::loaded mgmt vector + #[cfg(feature = "mods")] + pub(crate) mods: tokio::sync::RwLock>, +} + +impl Server { + pub(crate) fn build(args: clap::Args, runtime: Option<&runtime::Handle>) -> Result, Error> { + let config = Config::new(args.config)?; + #[cfg(feature = "sentry_telemetry")] + let sentry_guard = init_sentry(&config); + let (tracing_reload_handle, tracing_flame_guard) = init_tracing(&config); + + config.check()?; + #[cfg(unix)] + maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); + info!( + server_name = %config.server_name, + database_path = ?config.database_path, + log_levels = %config.log, + "{}", + conduwuit_version(), + ); + + Ok(Arc::new(Server { + server: Arc::new(conduit::Server::new(config, runtime.cloned(), tracing_reload_handle)), + + _tracing_flame_guard: tracing_flame_guard, + + #[cfg(feature = "sentry_telemetry")] + _sentry_guard: sentry_guard, + + #[cfg(feature = "mods")] + mods: tokio::sync::RwLock::new(Vec::new()), + })) + } +} + +#[cfg(feature = "sentry_telemetry")] +fn init_sentry(config: &Config) -> Option { + if !config.sentry { + return None; + } + + let sentry_endpoint = config + .sentry_endpoint + .as_ref() + .expect("init_sentry should only be called if sentry is enabled and this is not None") + .as_str(); + + let server_name = if config.sentry_send_server_name { + Some(config.server_name.to_string().into()) + } else { + None + }; + + Some(sentry::init(( + sentry_endpoint, + sentry::ClientOptions { + release: sentry::release_name!(), + traces_sample_rate: config.sentry_traces_sample_rate, + server_name, + ..Default::default() + }, + ))) +} + +#[cfg(feature = "perf_measurements")] +type TracingFlameGuard = Option>>; +#[cfg(not(feature = "perf_measurements"))] +type TracingFlameGuard = (); + +// clippy thinks the filter_layer clones are redundant if the next usage is +// behind a disabled feature. +#[allow(clippy::redundant_clone)] +fn init_tracing(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard) { + let registry = Registry::default(); + let fmt_layer = tracing_subscriber::fmt::Layer::new(); + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); + EnvFilter::try_new("warn").unwrap() + }, + }; + + let mut reload_handles = Vec:: + Send + Sync>>::new(); + let subscriber = registry; + + #[cfg(feature = "tokio_console")] + let subscriber = { + let console_layer = console_subscriber::spawn(); + subscriber.with(console_layer) + }; + + let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone()); + reload_handles.push(Box::new(fmt_reload_handle)); + let subscriber = subscriber.with(fmt_layer.with_filter(fmt_reload_filter)); + + #[cfg(feature = "sentry_telemetry")] + let subscriber = { + let sentry_layer = sentry_tracing::layer(); + let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone()); + reload_handles.push(Box::new(sentry_reload_handle)); + subscriber.with(sentry_layer.with_filter(sentry_reload_filter)) + }; + + #[cfg(feature = "perf_measurements")] + let (subscriber, flame_guard) = { + let (flame_layer, flame_guard) = if config.tracing_flame { + let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) { + Ok(flame_filter) => flame_filter, + Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"), + }; + + let (flame_layer, flame_guard) = + match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) { + Ok(ok) => ok, + Err(e) => { + panic!("failed to initialize tracing-flame: {e}"); + }, + }; + let flame_layer = flame_layer + .with_empty_samples(false) + .with_filter(flame_filter); + (Some(flame_layer), Some(flame_guard)) + } else { + (None, None) + }; + + let jaeger_layer = if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); + let tracer = opentelemetry_jaeger::new_agent_pipeline() + .with_auto_split_batch(true) + .with_service_name("conduwuit") + .install_batch(opentelemetry_sdk::runtime::Tokio) + .unwrap(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + + let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer); + reload_handles.push(Box::new(jaeger_reload_handle)); + Some(telemetry.with_filter(jaeger_reload_filter)) + } else { + None + }; + + let subscriber = subscriber.with(flame_layer).with(jaeger_layer); + (subscriber, flame_guard) + }; + + #[cfg(not(feature = "perf_measurements"))] + #[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))] + let flame_guard = (); + + tracing::subscriber::set_global_default(subscriber).unwrap(); + + #[cfg(all(feature = "tokio_console", feature = "release_max_log_level"))] + tracing::error!( + "'tokio_console' feature and 'release_max_log_level' feature are incompatible, because console-subscriber \ + needs access to trace-level events. 'release_max_log_level' must be disabled to use tokio-console." + ); + + (LogLevelReloadHandles::new(reload_handles), flame_guard) +} diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml new file mode 100644 index 00000000..89ec7248 --- /dev/null +++ b/src/core/Cargo.toml @@ -0,0 +1,133 @@ +[package] +name = "conduit_core" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "rocksdb", + "io_uring", + "jemalloc", + "gzip_compression", + "zstd_compression", + "brotli_compression", + "sentry_telemetry", + "release_max_log_level", +] + +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +sqlite = [ + "dep:rusqlite", + "dep:parking_lot", + "dep:thread_local", +] +rocksdb = [ + "dep:rust-rocksdb", +] +jemalloc = [ + "dep:tikv-jemalloc-sys", + "dep:tikv-jemalloc-ctl", + "dep:tikv-jemallocator", + "rust-rocksdb/jemalloc", +] +jemalloc_prof = [ + "tikv-jemalloc-sys/profiling", +] +hardened_malloc = [ + "dep:hardened_malloc-rs" +] +io_uring = [ + "rust-rocksdb/io-uring", +] +zstd_compression = [ + "rust-rocksdb/zstd", +] +gzip_compression = [ + "reqwest/gzip", +] +brotli_compression = [ + "reqwest/brotli", +] +perf_measurements = [] +sentry_telemetry = [] +mods = [ + "dep:libloading" +] +panic_trap = [] + +[dependencies] +async-trait.workspace = true +axum-server.workspace = true +axum.workspace = true +base64.workspace = true +bytes.workspace = true +clap.workspace = true +cyborgtime.workspace = true +either.workspace = true +figment.workspace = true +futures-util.workspace = true +http-body-util.workspace = true +http.workspace = true +image.workspace = true +infer.workspace = true +ipaddress.workspace = true +itertools.workspace = true +libloading.workspace = true +libloading.optional = true +log.workspace = true +lru-cache.workspace = true +parking_lot.optional = true +parking_lot.workspace = true +rand.workspace = true +regex.workspace = true +reqwest.workspace = true +ring.workspace = true +ruma.workspace = true +rusqlite.optional = true +rusqlite.workspace = true +rust-rocksdb.optional = true +rust-rocksdb.workspace = true +sanitize-filename.workspace = true +serde_json.workspace = true +serde_regex.workspace = true +serde.workspace = true +serde_yaml.workspace = true +sha-1.workspace = true +thiserror.workspace = true +thread_local.optional = true +thread_local.workspace = true +tikv-jemallocator.optional = true +tikv-jemallocator.workspace = true +tikv-jemalloc-ctl.optional = true +tikv-jemalloc-ctl.workspace = true +tikv-jemalloc-sys.optional = true +tikv-jemalloc-sys.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +url.workspace = true +zstd.optional = true +zstd.workspace = true + +[target.'cfg(unix)'.dependencies] +nix.workspace = true + +[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies] +hardened_malloc-rs.workspace = true +hardened_malloc-rs.optional = true + +[lints] +workspace = true diff --git a/src/core/alloc/default.rs b/src/core/alloc/default.rs new file mode 100644 index 00000000..6e4128bf --- /dev/null +++ b/src/core/alloc/default.rs @@ -0,0 +1,9 @@ +//! Default allocator with no special features + +/// Always returns the empty string +#[must_use] +pub fn memory_stats() -> String { Default::default() } + +/// Always returns the empty string +#[must_use] +pub fn memory_usage() -> String { Default::default() } diff --git a/src/core/alloc/hardened.rs b/src/core/alloc/hardened.rs new file mode 100644 index 00000000..4c9563cf --- /dev/null +++ b/src/core/alloc/hardened.rs @@ -0,0 +1,10 @@ +#[global_allocator] +static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc; + +#[must_use] +pub fn memory_usage() -> String { + String::default() //TODO: get usage +} + +#[must_use] +pub fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() } diff --git a/src/alloc/je.rs b/src/core/alloc/je.rs similarity index 95% rename from src/alloc/je.rs rename to src/core/alloc/je.rs index 1a33de79..4092d815 100644 --- a/src/alloc/je.rs +++ b/src/core/alloc/je.rs @@ -7,7 +7,8 @@ use tikv_jemallocator as jemalloc; #[global_allocator] static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc; -pub(crate) fn memory_usage() -> String { +#[must_use] +pub fn memory_usage() -> String { use mallctl::stats; let allocated = stats::allocated::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; let active = stats::active::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; @@ -21,7 +22,8 @@ pub(crate) fn memory_usage() -> String { ) } -pub(crate) fn memory_stats() -> String { +#[must_use] +pub fn memory_stats() -> String { const MAX_LENGTH: usize = 65536 - 4096; let opts_s = "d"; diff --git a/src/alloc/mod.rs b/src/core/alloc/mod.rs similarity index 80% rename from src/alloc/mod.rs rename to src/core/alloc/mod.rs index 6b7c89a1..ceb9f498 100644 --- a/src/alloc/mod.rs +++ b/src/core/alloc/mod.rs @@ -2,24 +2,24 @@ // jemalloc #[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))] -mod je; +pub mod je; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))] -pub(crate) use je::{memory_stats, memory_usage}; +pub use je::{memory_stats, memory_usage}; // hardened_malloc #[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))] -mod hardened; +pub mod hardened; #[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))] -pub(crate) use hardened::{memory_stats, memory_usage}; +pub use hardened::{memory_stats, memory_usage}; // default, enabled when none or multiple of the above are enabled #[cfg(any( not(any(feature = "jemalloc", feature = "hardened_malloc")), all(feature = "jemalloc", feature = "hardened_malloc"), ))] -mod default; +pub mod default; #[cfg(any( not(any(feature = "jemalloc", feature = "hardened_malloc")), all(feature = "jemalloc", feature = "hardened_malloc"), ))] -pub(crate) use default::{memory_stats, memory_usage}; +pub use default::{memory_stats, memory_usage}; diff --git a/src/config/check.rs b/src/core/config/check.rs similarity index 98% rename from src/config/check.rs rename to src/core/config/check.rs index 62a874d6..99ca3cfd 100644 --- a/src/config/check.rs +++ b/src/core/config/check.rs @@ -3,9 +3,9 @@ use std::path::Path; // not unix specific, just only for UNIX sockets stuff and use tracing::{debug, error, info, warn}; -use crate::{utils::error::Error, Config}; +use crate::{error::Error, Config}; -pub(crate) fn check(config: &Config) -> Result<(), Error> { +pub fn check(config: &Config) -> Result<(), Error> { config.warn_deprecated(); config.warn_unknown_key(); diff --git a/src/config/mod.rs b/src/core/config/mod.rs similarity index 80% rename from src/config/mod.rs rename to src/core/config/mod.rs index 8b84a9b1..4b85fb53 100644 --- a/src/config/mod.rs +++ b/src/core/config/mod.rs @@ -22,11 +22,12 @@ use serde::{de::IgnoredAny, Deserialize}; use tracing::{debug, error, warn}; use url::Url; -use self::{check::check, proxy::ProxyConfig}; -use crate::utils::error::Error; +pub use self::check::check; +use self::proxy::ProxyConfig; +use crate::error::Error; -pub(crate) mod check; -mod proxy; +pub mod check; +pub mod proxy; #[derive(Deserialize, Clone, Debug)] #[serde(transparent)] @@ -38,310 +39,310 @@ struct ListeningPort { /// all the config options for conduwuit #[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] -pub(crate) struct Config { +pub struct Config { /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) #[serde(default = "default_address")] - pub(crate) address: IpAddr, + pub address: IpAddr, /// default TCP port(s) conduwuit will listen on #[serde(default = "default_port")] port: ListeningPort, - pub(crate) tls: Option, - pub(crate) unix_socket_path: Option, + pub tls: Option, + pub unix_socket_path: Option, #[serde(default = "default_unix_socket_perms")] - pub(crate) unix_socket_perms: u32, - pub(crate) server_name: OwnedServerName, + pub unix_socket_perms: u32, + pub server_name: OwnedServerName, #[serde(default = "default_database_backend")] - pub(crate) database_backend: String, - pub(crate) database_path: PathBuf, - pub(crate) database_backup_path: Option, + pub database_backend: String, + pub database_path: PathBuf, + pub database_backup_path: Option, #[serde(default = "default_database_backups_to_keep")] - pub(crate) database_backups_to_keep: i16, + pub database_backups_to_keep: i16, #[serde(default = "default_db_cache_capacity_mb")] - pub(crate) db_cache_capacity_mb: f64, + pub db_cache_capacity_mb: f64, #[serde(default = "default_new_user_displayname_suffix")] - pub(crate) new_user_displayname_suffix: String, + pub new_user_displayname_suffix: String, #[serde(default)] - pub(crate) allow_check_for_updates: bool, + pub allow_check_for_updates: bool, #[serde(default = "default_pdu_cache_capacity")] - pub(crate) pdu_cache_capacity: u32, + pub pdu_cache_capacity: u32, #[serde(default = "default_conduit_cache_capacity_modifier")] - pub(crate) conduit_cache_capacity_modifier: f64, + pub conduit_cache_capacity_modifier: f64, #[serde(default = "default_auth_chain_cache_capacity")] - pub(crate) auth_chain_cache_capacity: u32, + pub auth_chain_cache_capacity: u32, #[serde(default = "default_shorteventid_cache_capacity")] - pub(crate) shorteventid_cache_capacity: u32, + pub shorteventid_cache_capacity: u32, #[serde(default = "default_eventidshort_cache_capacity")] - pub(crate) eventidshort_cache_capacity: u32, + pub eventidshort_cache_capacity: u32, #[serde(default = "default_shortstatekey_cache_capacity")] - pub(crate) shortstatekey_cache_capacity: u32, + pub shortstatekey_cache_capacity: u32, #[serde(default = "default_statekeyshort_cache_capacity")] - pub(crate) statekeyshort_cache_capacity: u32, + pub statekeyshort_cache_capacity: u32, #[serde(default = "default_server_visibility_cache_capacity")] - pub(crate) server_visibility_cache_capacity: u32, + pub server_visibility_cache_capacity: u32, #[serde(default = "default_user_visibility_cache_capacity")] - pub(crate) user_visibility_cache_capacity: u32, + pub user_visibility_cache_capacity: u32, #[serde(default = "default_stateinfo_cache_capacity")] - pub(crate) stateinfo_cache_capacity: u32, + pub stateinfo_cache_capacity: u32, #[serde(default = "default_roomid_spacehierarchy_cache_capacity")] - pub(crate) roomid_spacehierarchy_cache_capacity: u32, + pub roomid_spacehierarchy_cache_capacity: u32, #[serde(default = "default_cleanup_second_interval")] - pub(crate) cleanup_second_interval: u32, + pub cleanup_second_interval: u32, #[serde(default = "default_dns_cache_entries")] - pub(crate) dns_cache_entries: u32, + pub dns_cache_entries: u32, #[serde(default = "default_dns_min_ttl")] - pub(crate) dns_min_ttl: u64, + pub dns_min_ttl: u64, #[serde(default = "default_dns_min_ttl_nxdomain")] - pub(crate) dns_min_ttl_nxdomain: u64, + pub dns_min_ttl_nxdomain: u64, #[serde(default = "default_dns_attempts")] - pub(crate) dns_attempts: u16, + pub dns_attempts: u16, #[serde(default = "default_dns_timeout")] - pub(crate) dns_timeout: u64, + pub dns_timeout: u64, #[serde(default = "true_fn")] - pub(crate) dns_tcp_fallback: bool, + pub dns_tcp_fallback: bool, #[serde(default = "true_fn")] - pub(crate) query_all_nameservers: bool, + pub query_all_nameservers: bool, #[serde(default)] - pub(crate) query_over_tcp_only: bool, + pub query_over_tcp_only: bool, #[serde(default = "default_ip_lookup_strategy")] - pub(crate) ip_lookup_strategy: u8, + pub ip_lookup_strategy: u8, #[serde(default = "default_max_request_size")] - pub(crate) max_request_size: u32, + pub max_request_size: u32, #[serde(default = "default_max_fetch_prev_events")] - pub(crate) max_fetch_prev_events: u16, + pub max_fetch_prev_events: u16, #[serde(default = "default_request_conn_timeout")] - pub(crate) request_conn_timeout: u64, + pub request_conn_timeout: u64, #[serde(default = "default_request_timeout")] - pub(crate) request_timeout: u64, + pub request_timeout: u64, #[serde(default = "default_request_total_timeout")] - pub(crate) request_total_timeout: u64, + pub request_total_timeout: u64, #[serde(default = "default_request_idle_timeout")] - pub(crate) request_idle_timeout: u64, + pub request_idle_timeout: u64, #[serde(default = "default_request_idle_per_host")] - pub(crate) request_idle_per_host: u16, + pub request_idle_per_host: u16, #[serde(default = "default_well_known_conn_timeout")] - pub(crate) well_known_conn_timeout: u64, + pub well_known_conn_timeout: u64, #[serde(default = "default_well_known_timeout")] - pub(crate) well_known_timeout: u64, + pub well_known_timeout: u64, #[serde(default = "default_federation_timeout")] - pub(crate) federation_timeout: u64, + pub federation_timeout: u64, #[serde(default = "default_federation_idle_timeout")] - pub(crate) federation_idle_timeout: u64, + pub federation_idle_timeout: u64, #[serde(default = "default_federation_idle_per_host")] - pub(crate) federation_idle_per_host: u16, + pub federation_idle_per_host: u16, #[serde(default = "default_sender_timeout")] - pub(crate) sender_timeout: u64, + pub sender_timeout: u64, #[serde(default = "default_sender_idle_timeout")] - pub(crate) sender_idle_timeout: u64, + pub sender_idle_timeout: u64, #[serde(default = "default_sender_retry_backoff_limit")] - pub(crate) sender_retry_backoff_limit: u64, + pub sender_retry_backoff_limit: u64, #[serde(default = "default_appservice_timeout")] - pub(crate) appservice_timeout: u64, + pub appservice_timeout: u64, #[serde(default = "default_appservice_idle_timeout")] - pub(crate) appservice_idle_timeout: u64, + pub appservice_idle_timeout: u64, #[serde(default = "default_pusher_idle_timeout")] - pub(crate) pusher_idle_timeout: u64, + pub pusher_idle_timeout: u64, #[serde(default)] - pub(crate) allow_registration: bool, + pub allow_registration: bool, #[serde(default)] - pub(crate) yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, - pub(crate) registration_token: Option, + pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, + pub registration_token: Option, #[serde(default = "true_fn")] - pub(crate) allow_encryption: bool, + pub allow_encryption: bool, #[serde(default = "true_fn")] - pub(crate) allow_federation: bool, + pub allow_federation: bool, #[serde(default)] - pub(crate) allow_public_room_directory_over_federation: bool, + pub allow_public_room_directory_over_federation: bool, #[serde(default)] - pub(crate) allow_public_room_directory_without_auth: bool, + pub allow_public_room_directory_without_auth: bool, #[serde(default)] - pub(crate) lockdown_public_room_directory: bool, + pub lockdown_public_room_directory: bool, #[serde(default)] - pub(crate) allow_device_name_federation: bool, + pub allow_device_name_federation: bool, #[serde(default = "true_fn")] - pub(crate) allow_profile_lookup_federation_requests: bool, + pub allow_profile_lookup_federation_requests: bool, #[serde(default = "true_fn")] - pub(crate) allow_room_creation: bool, + pub allow_room_creation: bool, #[serde(default = "true_fn")] - pub(crate) allow_unstable_room_versions: bool, + pub allow_unstable_room_versions: bool, #[serde(default = "default_default_room_version")] - pub(crate) default_room_version: RoomVersionId, + pub default_room_version: RoomVersionId, #[serde(default)] - pub(crate) well_known: WellKnownConfig, + pub well_known: WellKnownConfig, #[serde(default)] #[cfg(feature = "perf_measurements")] - pub(crate) allow_jaeger: bool, + pub allow_jaeger: bool, #[serde(default)] #[cfg(feature = "perf_measurements")] - pub(crate) tracing_flame: bool, + pub tracing_flame: bool, #[serde(default = "default_tracing_flame_filter")] #[cfg(feature = "perf_measurements")] - pub(crate) tracing_flame_filter: String, + pub tracing_flame_filter: String, #[serde(default = "default_tracing_flame_output_path")] #[cfg(feature = "perf_measurements")] - pub(crate) tracing_flame_output_path: String, + pub tracing_flame_output_path: String, #[serde(default)] - pub(crate) proxy: ProxyConfig, - pub(crate) jwt_secret: Option, + pub proxy: ProxyConfig, + pub jwt_secret: Option, #[serde(default = "default_trusted_servers")] - pub(crate) trusted_servers: Vec, + pub trusted_servers: Vec, #[serde(default = "true_fn")] - pub(crate) query_trusted_key_servers_first: bool, + pub query_trusted_key_servers_first: bool, #[serde(default = "default_log")] - pub(crate) log: String, + pub log: String, #[serde(default)] - pub(crate) turn_username: String, + pub turn_username: String, #[serde(default)] - pub(crate) turn_password: String, + pub turn_password: String, #[serde(default = "Vec::new")] - pub(crate) turn_uris: Vec, + pub turn_uris: Vec, #[serde(default)] - pub(crate) turn_secret: String, + pub turn_secret: String, #[serde(default = "default_turn_ttl")] - pub(crate) turn_ttl: u64, + pub turn_ttl: u64, #[serde(default = "Vec::new")] - pub(crate) auto_join_rooms: Vec, + pub auto_join_rooms: Vec, #[serde(default)] - pub(crate) auto_deactivate_banned_room_attempts: bool, + pub auto_deactivate_banned_room_attempts: bool, #[serde(default = "default_rocksdb_log_level")] - pub(crate) rocksdb_log_level: String, + pub rocksdb_log_level: String, #[serde(default)] - pub(crate) rocksdb_log_stderr: bool, + pub rocksdb_log_stderr: bool, #[serde(default = "default_rocksdb_max_log_file_size")] - pub(crate) rocksdb_max_log_file_size: usize, + pub rocksdb_max_log_file_size: usize, #[serde(default = "default_rocksdb_log_time_to_roll")] - pub(crate) rocksdb_log_time_to_roll: usize, + pub rocksdb_log_time_to_roll: usize, #[serde(default)] - pub(crate) rocksdb_optimize_for_spinning_disks: bool, + pub rocksdb_optimize_for_spinning_disks: bool, #[serde(default = "true_fn")] - pub(crate) rocksdb_direct_io: bool, + pub rocksdb_direct_io: bool, #[serde(default = "default_rocksdb_parallelism_threads")] - pub(crate) rocksdb_parallelism_threads: usize, + pub rocksdb_parallelism_threads: usize, #[serde(default = "default_rocksdb_max_log_files")] - pub(crate) rocksdb_max_log_files: usize, + pub rocksdb_max_log_files: usize, #[serde(default = "default_rocksdb_compression_algo")] - pub(crate) rocksdb_compression_algo: String, + pub rocksdb_compression_algo: String, #[serde(default = "default_rocksdb_compression_level")] - pub(crate) rocksdb_compression_level: i32, + pub rocksdb_compression_level: i32, #[serde(default = "default_rocksdb_bottommost_compression_level")] - pub(crate) rocksdb_bottommost_compression_level: i32, + pub rocksdb_bottommost_compression_level: i32, #[serde(default)] - pub(crate) rocksdb_bottommost_compression: bool, + pub rocksdb_bottommost_compression: bool, #[serde(default = "default_rocksdb_recovery_mode")] - pub(crate) rocksdb_recovery_mode: u8, + pub rocksdb_recovery_mode: u8, #[serde(default)] - pub(crate) rocksdb_repair: bool, + pub rocksdb_repair: bool, #[serde(default)] - pub(crate) rocksdb_read_only: bool, + pub rocksdb_read_only: bool, #[serde(default)] - pub(crate) rocksdb_periodic_cleanup: bool, + pub rocksdb_periodic_cleanup: bool, #[serde(default)] - pub(crate) rocksdb_compaction_prio_idle: bool, + pub rocksdb_compaction_prio_idle: bool, #[serde(default = "true_fn")] - pub(crate) rocksdb_compaction_ioprio_idle: bool, + pub rocksdb_compaction_ioprio_idle: bool, - pub(crate) emergency_password: Option, + pub emergency_password: Option, #[serde(default = "default_notification_push_path")] - pub(crate) notification_push_path: String, + pub notification_push_path: String, #[serde(default = "true_fn")] - pub(crate) allow_local_presence: bool, + pub allow_local_presence: bool, #[serde(default = "true_fn")] - pub(crate) allow_incoming_presence: bool, + pub allow_incoming_presence: bool, #[serde(default = "true_fn")] - pub(crate) allow_outgoing_presence: bool, + pub allow_outgoing_presence: bool, #[serde(default = "default_presence_idle_timeout_s")] - pub(crate) presence_idle_timeout_s: u64, + pub presence_idle_timeout_s: u64, #[serde(default = "default_presence_offline_timeout_s")] - pub(crate) presence_offline_timeout_s: u64, + pub presence_offline_timeout_s: u64, #[serde(default = "true_fn")] - pub(crate) presence_timeout_remote_users: bool, + pub presence_timeout_remote_users: bool, #[serde(default = "true_fn")] - pub(crate) allow_incoming_read_receipts: bool, + pub allow_incoming_read_receipts: bool, #[serde(default = "true_fn")] - pub(crate) allow_outgoing_read_receipts: bool, + pub allow_outgoing_read_receipts: bool, #[serde(default = "true_fn")] - pub(crate) allow_outgoing_typing: bool, + pub allow_outgoing_typing: bool, #[serde(default = "true_fn")] - pub(crate) allow_incoming_typing: bool, + pub allow_incoming_typing: bool, #[serde(default = "default_typing_federation_timeout_s")] - pub(crate) typing_federation_timeout_s: u64, + pub typing_federation_timeout_s: u64, #[serde(default = "default_typing_client_timeout_min_s")] - pub(crate) typing_client_timeout_min_s: u64, + pub typing_client_timeout_min_s: u64, #[serde(default = "default_typing_client_timeout_max_s")] - pub(crate) typing_client_timeout_max_s: u64, + pub typing_client_timeout_max_s: u64, #[serde(default)] - pub(crate) zstd_compression: bool, + pub zstd_compression: bool, #[serde(default)] - pub(crate) gzip_compression: bool, + pub gzip_compression: bool, #[serde(default)] - pub(crate) brotli_compression: bool, + pub brotli_compression: bool, #[serde(default)] - pub(crate) allow_guest_registration: bool, + pub allow_guest_registration: bool, #[serde(default)] - pub(crate) log_guest_registrations: bool, + pub log_guest_registrations: bool, #[serde(default)] - pub(crate) allow_guests_auto_join_rooms: bool, + pub allow_guests_auto_join_rooms: bool, #[serde(default = "Vec::new")] - pub(crate) prevent_media_downloads_from: Vec, + pub prevent_media_downloads_from: Vec, #[serde(default = "Vec::new")] - pub(crate) forbidden_remote_server_names: Vec, + pub forbidden_remote_server_names: Vec, #[serde(default = "Vec::new")] - pub(crate) forbidden_remote_room_directory_server_names: Vec, + pub forbidden_remote_room_directory_server_names: Vec, #[serde(default = "default_ip_range_denylist")] - pub(crate) ip_range_denylist: Vec, + pub ip_range_denylist: Vec, #[serde(default = "Vec::new")] - pub(crate) url_preview_domain_contains_allowlist: Vec, + pub url_preview_domain_contains_allowlist: Vec, #[serde(default = "Vec::new")] - pub(crate) url_preview_domain_explicit_allowlist: Vec, + pub url_preview_domain_explicit_allowlist: Vec, #[serde(default = "Vec::new")] - pub(crate) url_preview_domain_explicit_denylist: Vec, + pub url_preview_domain_explicit_denylist: Vec, #[serde(default = "Vec::new")] - pub(crate) url_preview_url_contains_allowlist: Vec, + pub url_preview_url_contains_allowlist: Vec, #[serde(default = "default_url_preview_max_spider_size")] - pub(crate) url_preview_max_spider_size: usize, + pub url_preview_max_spider_size: usize, #[serde(default)] - pub(crate) url_preview_check_root_domain: bool, + pub url_preview_check_root_domain: bool, #[serde(default = "RegexSet::empty")] #[serde(with = "serde_regex")] - pub(crate) forbidden_alias_names: RegexSet, + pub forbidden_alias_names: RegexSet, #[serde(default = "RegexSet::empty")] #[serde(with = "serde_regex")] - pub(crate) forbidden_usernames: RegexSet, + pub forbidden_usernames: RegexSet, #[serde(default = "true_fn")] - pub(crate) startup_netburst: bool, + pub startup_netburst: bool, #[serde(default = "default_startup_netburst_keep")] - pub(crate) startup_netburst_keep: i64, + pub startup_netburst_keep: i64, #[serde(default)] - pub(crate) block_non_admin_invites: bool, + pub block_non_admin_invites: bool, #[serde(default)] - pub(crate) sentry: bool, + pub sentry: bool, #[serde(default = "default_sentry_endpoint")] - pub(crate) sentry_endpoint: Option, + pub sentry_endpoint: Option, #[serde(default)] - pub(crate) sentry_send_server_name: bool, + pub sentry_send_server_name: bool, #[serde(default = "default_sentry_traces_sample_rate")] - pub(crate) sentry_traces_sample_rate: f32, + pub sentry_traces_sample_rate: f32, #[serde(flatten)] #[allow(clippy::zero_sized_map_values)] // this is a catchall, the map shouldn't be zero at runtime @@ -349,24 +350,24 @@ pub(crate) struct Config { } #[derive(Clone, Debug, Deserialize)] -pub(crate) struct TlsConfig { - pub(crate) certs: String, - pub(crate) key: String, +pub struct TlsConfig { + pub certs: String, + pub key: String, #[serde(default)] /// Whether to listen and allow for HTTP and HTTPS connections (insecure!) /// Only works / does something if the `axum_dual_protocol` feature flag was /// built - pub(crate) dual_protocol: bool, + pub dual_protocol: bool, } #[derive(Clone, Debug, Deserialize, Default)] -pub(crate) struct WellKnownConfig { - pub(crate) client: Option, - pub(crate) server: Option, - pub(crate) support_page: Option, - pub(crate) support_role: Option, - pub(crate) support_email: Option, - pub(crate) support_mxid: Option, +pub struct WellKnownConfig { + pub client: Option, + pub server: Option, + pub support_page: Option, + pub support_role: Option, + pub support_email: Option, + pub support_mxid: Option, } const DEPRECATED_KEYS: &[&str] = &[ @@ -382,7 +383,7 @@ const DEPRECATED_KEYS: &[&str] = &[ impl Config { /// Initialize config - pub(crate) fn new(path: Option) -> Result { + pub fn new(path: Option) -> Result { let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { Figment::new() .merge(Toml::file(config_file_env).nested()) @@ -469,7 +470,7 @@ impl Config { } #[must_use] - pub(crate) fn get_bind_addrs(&self) -> Vec { + pub fn get_bind_addrs(&self) -> Vec { match &self.port.ports { Left(port) => { // Left is only 1 value, so make a vec with 1 value only @@ -489,7 +490,7 @@ impl Config { } } - pub(crate) fn check(&self) -> Result<(), Error> { check(self) } + pub fn check(&self) -> Result<(), Error> { check(self) } } impl fmt::Display for Config { @@ -1027,7 +1028,8 @@ fn default_rocksdb_compression_level() -> i32 { 32767 } fn default_rocksdb_bottommost_compression_level() -> i32 { 32767 } // I know, it's a great name -pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } +#[must_use] +pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } fn default_ip_range_denylist() -> Vec { vec![ diff --git a/src/config/proxy.rs b/src/core/config/proxy.rs similarity index 95% rename from src/config/proxy.rs rename to src/core/config/proxy.rs index 691c3394..f41a92f6 100644 --- a/src/config/proxy.rs +++ b/src/core/config/proxy.rs @@ -30,7 +30,7 @@ use crate::Result; /// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. #[derive(Clone, Default, Debug, Deserialize)] #[serde(rename_all = "snake_case")] -pub(crate) enum ProxyConfig { +pub enum ProxyConfig { #[default] None, Global { @@ -40,7 +40,7 @@ pub(crate) enum ProxyConfig { ByDomain(Vec), } impl ProxyConfig { - pub(crate) fn to_proxy(&self) -> Result> { + pub fn to_proxy(&self) -> Result> { Ok(match self.clone() { ProxyConfig::None => None, ProxyConfig::Global { @@ -55,7 +55,7 @@ impl ProxyConfig { } #[derive(Clone, Debug, Deserialize)] -pub(crate) struct PartialProxyConfig { +pub struct PartialProxyConfig { #[serde(deserialize_with = "crate::utils::deserialize_from_str")] url: Url, #[serde(default)] @@ -64,7 +64,8 @@ pub(crate) struct PartialProxyConfig { exclude: Vec, } impl PartialProxyConfig { - pub(crate) fn for_url(&self, url: &Url) -> Option<&Url> { + #[must_use] + pub fn for_url(&self, url: &Url) -> Option<&Url> { let domain = url.domain()?; let mut included_because = None; // most specific reason it was included let mut excluded_because = None; // most specific reason it was excluded diff --git a/src/utils/debug.rs b/src/core/debug.rs similarity index 66% rename from src/utils/debug.rs rename to src/core/debug.rs index 3974ae67..fa998265 100644 --- a/src/utils/debug.rs +++ b/src/core/debug.rs @@ -1,3 +1,7 @@ +#![allow(dead_code)] // this is a developer's toolbox + +use std::{panic, panic::PanicInfo}; + /// Log event at given level in debug-mode (when debug-assertions are enabled). /// In release-mode it becomes DEBUG level, and possibly subject to elision. /// @@ -43,3 +47,32 @@ macro_rules! debug_info { $crate::debug_event!(tracing::Level::INFO, $($x)+ ); } } + +pub fn set_panic_trap() { + let next = panic::take_hook(); + panic::set_hook(Box::new(move |info| { + panic_handler(info, &next); + })); +} + +#[inline(always)] +fn panic_handler(info: &PanicInfo<'_>, next: &dyn Fn(&PanicInfo<'_>)) { + trap(); + next(info); +} + +#[inline(always)] +#[allow(unexpected_cfgs)] +pub fn trap() { + #[cfg(core_intrinsics)] + //SAFETY: embeds llvm intrinsic for hardware breakpoint + unsafe { + std::intrinsics::breakpoint(); + } + + #[cfg(all(not(core_intrinsics), target_arch = "x86_64"))] + //SAFETY: embeds instruction for hardware breakpoint + unsafe { + std::arch::asm!("int3"); + } +} diff --git a/src/utils/error.rs b/src/core/error.rs similarity index 78% rename from src/utils/error.rs rename to src/core/error.rs index 04ed6bf3..5a671da4 100644 --- a/src/utils/error.rs +++ b/src/core/error.rs @@ -1,10 +1,16 @@ use std::{convert::Infallible, fmt}; +use axum::response::{IntoResponse, Response}; +use bytes::BytesMut; use http::StatusCode; +use http_body_util::Full; use ruma::{ - api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::{UiaaInfo, UiaaResponse}, + api::{ + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::{UiaaInfo, UiaaResponse}, + }, + OutgoingResponse, }, OwnedServerName, }; @@ -15,12 +21,10 @@ use ErrorKind::{ TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, WrongRoomKeysVersion, }; -use crate::RumaResponse; - -pub(crate) type Result = std::result::Result; +pub type Result = std::result::Result; #[derive(Error)] -pub(crate) enum Error { +pub enum Error { #[cfg(feature = "sqlite")] #[error("There was a problem with the connection to the sqlite database: {source}")] Sqlite { @@ -83,17 +87,70 @@ pub(crate) enum Error { } impl Error { - pub(crate) fn bad_database(message: &'static str) -> Self { + pub fn bad_database(message: &'static str) -> Self { error!("BadDatabase: {}", message); Self::BadDatabase(message) } - pub(crate) fn bad_config(message: &str) -> Self { + pub fn bad_config(message: &str) -> Self { error!("BadConfig: {}", message); Self::BadConfig(message.to_owned()) } - pub(crate) fn to_response(&self) -> RumaResponse { + /// Returns the Matrix error code / error kind + pub fn error_code(&self) -> ErrorKind { + if let Self::Federation(_, error) = self { + return error.error_kind().unwrap_or_else(|| &Unknown).clone(); + } + + match self { + Self::BadRequest(kind, _) => kind.clone(), + _ => Unknown, + } + } + + /// Sanitizes public-facing errors that can leak sensitive information. + pub fn sanitized_error(&self) -> String { + let db_error = String::from("Database or I/O error occurred."); + + match self { + #[cfg(feature = "sqlite")] + Self::Sqlite { + .. + } => db_error, + #[cfg(feature = "rocksdb")] + Self::RocksDb { + .. + } => db_error, + Self::Io { + .. + } => db_error, + _ => self.to_string(), + } + } +} + +impl From for Error { + fn from(i: Infallible) -> Self { match i {} } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) } +} + +#[derive(Clone)] +pub struct RumaResponse(pub T); + +impl From for RumaResponse { + fn from(t: T) -> Self { Self(t) } +} + +impl From for RumaResponse { + fn from(t: Error) -> Self { t.to_response() } +} + +impl Error { + pub fn to_response(&self) -> RumaResponse { if let Self::Uiaa(uiaainfo) = self { return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); } @@ -147,48 +204,17 @@ impl Error { status_code, })) } +} - /// Returns the Matrix error code / error kind - pub(crate) fn error_code(&self) -> ErrorKind { - if let Self::Federation(_, error) = self { - return error.error_kind().unwrap_or_else(|| &Unknown).clone(); - } +impl ::axum::response::IntoResponse for Error { + fn into_response(self) -> ::axum::response::Response { self.to_response().into_response() } +} - match self { - Self::BadRequest(kind, _) => kind.clone(), - _ => Unknown, - } - } - - /// Sanitizes public-facing errors that can leak sensitive information. - pub(crate) fn sanitized_error(&self) -> String { - let db_error = String::from("Database or I/O error occurred."); - - match self { - #[cfg(feature = "sqlite")] - Self::Sqlite { - .. - } => db_error, - #[cfg(feature = "rocksdb")] - Self::RocksDb { - .. - } => db_error, - Self::Io { - .. - } => db_error, - _ => self.to_string(), +impl IntoResponse for RumaResponse { + fn into_response(self) -> Response { + match self.0.try_into_http_response::() { + Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), } } } - -impl From for Error { - fn from(i: Infallible) -> Self { match i {} } -} - -impl axum::response::IntoResponse for Error { - fn into_response(self) -> axum::response::Response { self.to_response().into_response() } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) } -} diff --git a/src/core/log.rs b/src/core/log.rs new file mode 100644 index 00000000..d31ca194 --- /dev/null +++ b/src/core/log.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use tracing_subscriber::{reload, EnvFilter}; + +/// We need to store a reload::Handle value, but can't name it's type explicitly +/// because the S type parameter depends on the subscriber's previous layers. In +/// our case, this includes unnameable 'impl Trait' types. +/// +/// This is fixed[1] in the unreleased tracing-subscriber from the master +/// branch, which removes the S parameter. Unfortunately can't use it without +/// pulling in a version of tracing that's incompatible with the rest of our +/// deps. +/// +/// To work around this, we define an trait without the S paramter that forwards +/// to the reload::Handle::reload method, and then store the handle as a trait +/// object. +/// +/// [1]: +pub trait ReloadHandle { + fn reload(&self, new_value: L) -> Result<(), reload::Error>; +} + +impl ReloadHandle for reload::Handle { + fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) } +} + +struct LogLevelReloadHandlesInner { + handles: Vec + Send + Sync>>, +} + +/// Wrapper to allow reloading the filter on several several +/// [`tracing_subscriber::reload::Handle`]s at once, with the same value. +#[derive(Clone)] +pub struct LogLevelReloadHandles { + inner: Arc, +} + +impl LogLevelReloadHandles { + #[must_use] + pub fn new(handles: Vec + Send + Sync>>) -> LogLevelReloadHandles { + LogLevelReloadHandles { + inner: Arc::new(LogLevelReloadHandlesInner { + handles, + }), + } + } + + pub fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> { + for handle in &self.inner.handles { + handle.reload(new_value.clone())?; + } + Ok(()) + } +} + +#[macro_export] +macro_rules! error { + ( $($x:tt)+ ) => { tracing::error!( $($x)+ ); } +} + +#[macro_export] +macro_rules! warn { + ( $($x:tt)+ ) => { tracing::warn!( $($x)+ ); } +} + +#[macro_export] +macro_rules! info { + ( $($x:tt)+ ) => { tracing::info!( $($x)+ ); } +} + +#[macro_export] +macro_rules! debug { + ( $($x:tt)+ ) => { tracing::debug!( $($x)+ ); } +} + +#[macro_export] +macro_rules! trace { + ( $($x:tt)+ ) => { tracing::trace!( $($x)+ ); } +} diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 00000000..8e0a481b --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,27 @@ +pub mod alloc; +pub mod config; +pub mod debug; +pub mod error; +pub mod log; +pub mod mods; +pub mod pducount; +pub mod server; +pub mod utils; + +pub use config::Config; +pub use error::{Error, Result, RumaResponse}; +pub use pducount::PduCount; +pub use server::Server; +pub use utils::conduwuit_version; + +#[cfg(not(feature = "mods"))] +mod mods { + #[macro_export] + macro_rules! mod_ctor { + () => {}; + } + #[macro_export] + macro_rules! mod_dtor { + () => {}; + } +} diff --git a/src/core/mods/canary.rs b/src/core/mods/canary.rs new file mode 100644 index 00000000..6095608c --- /dev/null +++ b/src/core/mods/canary.rs @@ -0,0 +1,28 @@ +use std::sync::atomic::{AtomicI32, Ordering}; + +const ORDERING: Ordering = Ordering::Relaxed; +static STATIC_DTORS: AtomicI32 = AtomicI32::new(0); + +/// Called by Module::unload() to indicate module is about to be unloaded and +/// static destruction is intended. This will allow verifying it actually took +/// place. +pub(crate) fn prepare() { + let count = STATIC_DTORS.fetch_sub(1, ORDERING); + debug_assert!(count <= 0, "STATIC_DTORS should not be greater than zero."); +} + +/// Called by static destructor of a module. This call should only be found +/// inside a mod_fini! macro. Do not call from anywhere else. +#[inline(always)] +pub fn report() { let _count = STATIC_DTORS.fetch_add(1, ORDERING); } + +/// Called by Module::unload() (see check()) with action in case a check() +/// failed. This can allow a stuck module to be noted while allowing for other +/// independent modules to be diagnosed. +pub(crate) fn check_and_reset() -> bool { STATIC_DTORS.swap(0, ORDERING) == 0 } + +/// Called by Module::unload() after unload to verify static destruction took +/// place. A call to prepare() must be made prior to Module::unload() and making +/// this call. +#[allow(dead_code)] +pub(crate) fn check() -> bool { STATIC_DTORS.load(ORDERING) == 0 } diff --git a/src/core/mods/macros.rs b/src/core/mods/macros.rs new file mode 100644 index 00000000..aa0999c9 --- /dev/null +++ b/src/core/mods/macros.rs @@ -0,0 +1,44 @@ +#[macro_export] +macro_rules! mod_ctor { + ( $($body:block)? ) => { + $crate::mod_init! {{ + $crate::debug_info!("Module loaded"); + $($body)? + }} + } +} + +#[macro_export] +macro_rules! mod_dtor { + ( $($body:block)? ) => { + $crate::mod_fini! {{ + $crate::debug_info!("Module unloading"); + $($body)? + $crate::mods::canary::report(); + }} + } +} + +#[macro_export] +macro_rules! mod_init { + ($body:block) => { + #[used] + #[cfg_attr(target_family = "unix", link_section = ".init_array")] + static MOD_INIT: extern "C" fn() = { _mod_init }; + + #[cfg_attr(target_family = "unix", link_section = ".text.startup")] + extern "C" fn _mod_init() -> () $body + }; +} + +#[macro_export] +macro_rules! mod_fini { + ($body:block) => { + #[used] + #[cfg_attr(target_family = "unix", link_section = ".fini_array")] + static MOD_FINI: extern "C" fn() = { _mod_fini }; + + #[cfg_attr(target_family = "unix", link_section = ".text.startup")] + extern "C" fn _mod_fini() -> () $body + }; +} diff --git a/src/core/mods/mod.rs b/src/core/mods/mod.rs new file mode 100644 index 00000000..e60a0f5e --- /dev/null +++ b/src/core/mods/mod.rs @@ -0,0 +1,11 @@ +#![cfg(feature = "mods")] + +pub(crate) use libloading::os::unix::{Library, Symbol}; + +pub mod canary; +pub mod macros; +pub mod module; +pub mod new; +pub mod path; + +pub use module::Module; diff --git a/src/core/mods/module.rs b/src/core/mods/module.rs new file mode 100644 index 00000000..ff181e4f --- /dev/null +++ b/src/core/mods/module.rs @@ -0,0 +1,74 @@ +use std::{ + ffi::{CString, OsString}, + time::SystemTime, +}; + +use super::{canary, new, path, Library, Symbol}; +use crate::{error, Result}; + +pub struct Module { + handle: Option, + loaded: SystemTime, + path: OsString, +} + +impl Module { + pub fn from_name(name: &str) -> Result { Self::from_path(path::from_name(name)?) } + + pub fn from_path(path: OsString) -> Result { + Ok(Self { + handle: Some(new::from_path(&path)?), + loaded: SystemTime::now(), + path, + }) + } + + pub fn unload(&mut self) { + canary::prepare(); + self.close(); + if !canary::check_and_reset() { + let name = self.name().expect("Module is named"); + error!("Module {name:?} is stuck and failed to unload."); + } + } + + pub(crate) fn close(&mut self) { + if let Some(handle) = self.handle.take() { + handle.close().expect("Module handle closed"); + } + } + + pub fn get(&self, name: &str) -> Result> { + let cname = CString::new(name.to_owned()).expect("terminated string from provided name"); + let handle = self + .handle + .as_ref() + .expect("backing library loaded by this instance"); + // SAFETY: Calls dlsym(3) on unix platforms. This might not have to be unsafe + // if wrapped in libloading with_dlerror(). + let sym = unsafe { handle.get::(cname.as_bytes()) }; + let sym = sym.expect("symbol found; binding successful"); + + Ok(sym) + } + + pub fn deleted(&self) -> Result { + let mtime = path::mtime(self.path())?; + let res = mtime > self.loaded; + + Ok(res) + } + + pub fn name(&self) -> Result { path::to_name(self.path()) } + + #[must_use] + pub fn path(&self) -> &OsString { &self.path } +} + +impl Drop for Module { + fn drop(&mut self) { + if self.handle.is_some() { + self.unload(); + } + } +} diff --git a/src/core/mods/new.rs b/src/core/mods/new.rs new file mode 100644 index 00000000..99d7756a --- /dev/null +++ b/src/core/mods/new.rs @@ -0,0 +1,23 @@ +use std::ffi::OsStr; + +use super::{path, Library}; +use crate::{Error, Result}; + +const OPEN_FLAGS: i32 = libloading::os::unix::RTLD_LAZY | libloading::os::unix::RTLD_GLOBAL; + +pub fn from_name(name: &str) -> Result { + let path = path::from_name(name)?; + from_path(&path) +} + +pub fn from_path(path: &OsStr) -> Result { + //SAFETY: Calls dlopen(3) on unix platforms. This might not have to be unsafe + // if wrapped in with_dlerror. + let lib = unsafe { Library::open(Some(path), OPEN_FLAGS) }; + if let Err(e) = lib { + let name = path::to_name(path)?; + return Err(Error::Err(format!("Loading module {name:?} failed: {e}"))); + } + + Ok(lib.expect("module loaded")) +} diff --git a/src/core/mods/path.rs b/src/core/mods/path.rs new file mode 100644 index 00000000..cde251b3 --- /dev/null +++ b/src/core/mods/path.rs @@ -0,0 +1,40 @@ +use std::{ + env::current_exe, + ffi::{OsStr, OsString}, + path::{Path, PathBuf}, + time::SystemTime, +}; + +use libloading::library_filename; + +use crate::Result; + +pub fn from_name(name: &str) -> Result { + let root = PathBuf::new(); + let exe_path = current_exe()?; + let exe_dir = exe_path.parent().unwrap_or(&root); + let mut mod_path = exe_dir.to_path_buf(); + let mod_file = library_filename(name); + mod_path.push(mod_file); + + Ok(mod_path.into_os_string()) +} + +pub fn to_name(path: &OsStr) -> Result { + let path = Path::new(path); + let name = path + .file_stem() + .expect("path file stem") + .to_str() + .expect("name string"); + let name = name.strip_prefix("lib").unwrap_or(name).to_owned(); + + Ok(name) +} + +pub fn mtime(path: &OsStr) -> Result { + let meta = std::fs::metadata(path)?; + let mtime = meta.modified()?; + + Ok(mtime) +} diff --git a/src/core/pducount.rs b/src/core/pducount.rs new file mode 100644 index 00000000..8adb4ca5 --- /dev/null +++ b/src/core/pducount.rs @@ -0,0 +1,51 @@ +use std::cmp::Ordering; + +use ruma::api::client::error::ErrorKind; + +use crate::{Error, Result}; + +#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] +pub enum PduCount { + Backfilled(u64), + Normal(u64), +} + +impl PduCount { + #[must_use] + pub fn min() -> Self { Self::Backfilled(u64::MAX) } + + #[must_use] + pub fn max() -> Self { Self::Normal(u64::MAX) } + + pub fn try_from_string(token: &str) -> Result { + if let Some(stripped_token) = token.strip_prefix('-') { + stripped_token.parse().map(PduCount::Backfilled) + } else { + token.parse().map(PduCount::Normal) + } + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) + } + + #[must_use] + pub fn stringify(&self) -> String { + match self { + PduCount::Backfilled(x) => format!("-{x}"), + PduCount::Normal(x) => x.to_string(), + } + } +} + +impl PartialOrd for PduCount { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for PduCount { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), + (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), + (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, + (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, + } + } +} diff --git a/src/core/server.rs b/src/core/server.rs new file mode 100644 index 00000000..4bfff340 --- /dev/null +++ b/src/core/server.rs @@ -0,0 +1,72 @@ +use std::{ + sync::{ + atomic::{AtomicBool, AtomicU32}, + Mutex, + }, + time::SystemTime, +}; + +use tokio::runtime; + +use crate::{config::Config, log::LogLevelReloadHandles}; + +/// Server runtime state; public portion +pub struct Server { + /// Server-wide configuration instance + pub config: Config, + + /// Timestamp server was started; used for uptime. + pub started: SystemTime, + + /// Reload/shutdown signal channel. Called from the signal handler or admin + /// command to initiate shutdown. + pub shutdown: Mutex>, + + /// Reload/shutdown desired indicator; when false, shutdown is desired. This + /// is an observable used on shutdown and modifying is not recommended. + pub reload: AtomicBool, + + /// Reload/shutdown pending indicator; server is shutting down. This is an + /// observable used on shutdown and should not be modified. + pub interrupt: AtomicBool, + + /// Handle to the runtime + pub runtime: Option, + + /// Log level reload handles. + pub tracing_reload_handle: LogLevelReloadHandles, + + /// TODO: move stats + pub requests_spawn_active: AtomicU32, + pub requests_spawn_finished: AtomicU32, + pub requests_handle_active: AtomicU32, + pub requests_handle_finished: AtomicU32, + pub requests_panic: AtomicU32, +} + +impl Server { + #[must_use] + pub fn new(config: Config, runtime: Option, tracing_reload_handle: LogLevelReloadHandles) -> Self { + Self { + config, + started: SystemTime::now(), + shutdown: Mutex::new(None), + reload: AtomicBool::new(false), + interrupt: AtomicBool::new(false), + runtime, + tracing_reload_handle, + requests_spawn_active: AtomicU32::new(0), + requests_spawn_finished: AtomicU32::new(0), + requests_handle_active: AtomicU32::new(0), + requests_handle_finished: AtomicU32::new(0), + requests_panic: AtomicU32::new(0), + } + } + + #[inline] + pub fn runtime(&self) -> &runtime::Handle { + self.runtime + .as_ref() + .expect("runtime handle available in Server") + } +} diff --git a/src/utils/clap.rs b/src/core/utils/clap.rs similarity index 73% rename from src/utils/clap.rs rename to src/core/utils/clap.rs index 4c88d836..c1dcb586 100644 --- a/src/utils/clap.rs +++ b/src/core/utils/clap.rs @@ -2,19 +2,19 @@ use std::path::PathBuf; -use clap::Parser; +pub use clap::Parser; use super::conduwuit_version; /// Commandline arguments #[derive(Parser, Debug)] #[clap(version = conduwuit_version(), about, long_about = None)] -pub(crate) struct Args { +pub struct Args { #[arg(short, long)] /// Optional argument to the path of a conduwuit config TOML file - pub(crate) config: Option, + pub config: Option, } /// Parse commandline arguments into structured data #[must_use] -pub(crate) fn parse() -> Args { Args::parse() } +pub fn parse() -> Args { Args::parse() } diff --git a/src/utils/content_disposition.rs b/src/core/utils/content_disposition.rs similarity index 93% rename from src/utils/content_disposition.rs rename to src/core/utils/content_disposition.rs index 9ef93bbf..85828be7 100644 --- a/src/utils/content_disposition.rs +++ b/src/core/utils/content_disposition.rs @@ -17,8 +17,9 @@ const IMAGE_SVG_XML: &str = "image/svg+xml"; /// /// TODO: add a "strict" function for comparing the Content-Type with what we /// detected: `file_type.mime_type() != content_type` +#[must_use] #[tracing::instrument(skip(buf))] -pub(crate) fn content_disposition_type(buf: &[u8], content_type: &Option) -> &'static str { +pub fn content_disposition_type(buf: &[u8], content_type: &Option) -> &'static str { let Some(file_type) = infer::get(buf) else { return ATTACHMENT; }; @@ -41,8 +42,9 @@ pub(crate) fn content_disposition_type(buf: &[u8], content_type: &Option /// /// SVG is special-cased due to the MIME type being classified as `text/xml` but /// browsers need `image/svg+xml` +#[must_use] #[tracing::instrument(skip(buf))] -pub(crate) fn make_content_type(buf: &[u8], content_type: &Option) -> &'static str { +pub fn make_content_type(buf: &[u8], content_type: &Option) -> &'static str { let Some(file_type) = infer::get(buf) else { debug_info!("Failed to infer the file's contents"); return APPLICATION_OCTET_STREAM; @@ -62,7 +64,7 @@ pub(crate) fn make_content_type(buf: &[u8], content_type: &Option) -> &' /// sanitises the file name for the Content-Disposition using /// `sanitize_filename` crate #[tracing::instrument] -pub(crate) fn sanitise_filename(filename: String) -> String { +pub fn sanitise_filename(filename: String) -> String { let options = sanitize_filename::Options { truncate: false, ..Default::default() @@ -79,7 +81,7 @@ pub(crate) fn sanitise_filename(filename: String) -> String { /// /// else: `Content-Disposition: attachment/inline` #[tracing::instrument(skip(file))] -pub(crate) fn make_content_disposition( +pub fn make_content_disposition( file: &[u8], content_type: &Option, content_disposition: Option, ) -> String { let filename = content_disposition.map_or_else(String::new, |content_disposition| { diff --git a/src/core/utils/defer.rs b/src/core/utils/defer.rs new file mode 100644 index 00000000..2762d4fa --- /dev/null +++ b/src/core/utils/defer.rs @@ -0,0 +1,22 @@ +#[macro_export] +macro_rules! defer { + ($body:block) => { + struct _Defer_ + where + F: FnMut(), + { + closure: F, + } + + impl Drop for _Defer_ + where + F: FnMut(), + { + fn drop(&mut self) { (self.closure)(); } + } + + let _defer_ = _Defer_ { + closure: || $body, + }; + }; +} diff --git a/src/utils/mod.rs b/src/core/utils/mod.rs similarity index 72% rename from src/utils/mod.rs rename to src/core/utils/mod.rs index 2ceb0ff5..1cdb6727 100644 --- a/src/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,10 +1,3 @@ -pub(crate) mod clap; -pub(crate) mod content_disposition; -pub(crate) mod debug; -pub(crate) mod error; -pub(crate) mod server_name; -pub(crate) mod user_id; - use std::{ cmp, cmp::Ordering, @@ -13,24 +6,29 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; -use argon2::{password_hash::SaltString, PasswordHasher}; use rand::prelude::*; use ring::digest; use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId}; +use tracing::debug; -use crate::{services, Error, Result}; +use crate::{Error, Result}; -pub(crate) fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } +pub mod clap; +pub mod content_disposition; +pub mod defer; +pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } + +#[must_use] #[allow(clippy::as_conversions)] -pub(crate) fn millis_since_unix_epoch() -> u64 { +pub fn millis_since_unix_epoch() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time is valid") .as_millis() as u64 } -pub(crate) fn increment(old: Option<&[u8]>) -> Vec { +pub fn increment(old: Option<&[u8]>) -> Vec { let number = match old.map(TryInto::try_into) { Some(Ok(bytes)) => { let number = u64::from_be_bytes(bytes); @@ -42,7 +40,8 @@ pub(crate) fn increment(old: Option<&[u8]>) -> Vec { number.to_be_bytes().to_vec() } -pub(crate) fn generate_keypair() -> Vec { +#[must_use] +pub fn generate_keypair() -> Vec { let mut value = random_string(8).as_bytes().to_vec(); value.push(0xFF); value.extend_from_slice( @@ -52,25 +51,25 @@ pub(crate) fn generate_keypair() -> Vec { } /// Parses the bytes into an u64. -pub(crate) fn u64_from_bytes(bytes: &[u8]) -> Result { +pub fn u64_from_bytes(bytes: &[u8]) -> Result { let array: [u8; 8] = bytes.try_into()?; Ok(u64::from_be_bytes(array)) } /// Parses the bytes into a string. -pub(crate) fn string_from_bytes(bytes: &[u8]) -> Result { +pub fn string_from_bytes(bytes: &[u8]) -> Result { String::from_utf8(bytes.to_vec()) } /// Parses a `OwnedUserId` from bytes. -pub(crate) fn user_id_from_bytes(bytes: &[u8]) -> Result { +pub fn user_id_from_bytes(bytes: &[u8]) -> Result { OwnedUserId::try_from( string_from_bytes(bytes).map_err(|_| Error::bad_database("Failed to parse string from bytes"))?, ) .map_err(|_| Error::bad_database("Failed to parse user id from bytes")) } -pub(crate) fn random_string(length: usize) -> String { +pub fn random_string(length: usize) -> String { thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(length) @@ -78,25 +77,16 @@ pub(crate) fn random_string(length: usize) -> String { .collect() } -/// Calculate a new hash for the given password -pub(crate) fn calculate_password_hash(password: &str) -> Result { - let salt = SaltString::generate(thread_rng()); - services() - .globals - .argon - .hash_password(password.as_bytes(), &salt) - .map(|it| it.to_string()) -} - #[tracing::instrument(skip(keys))] -pub(crate) fn calculate_hash(keys: &[&[u8]]) -> Vec { +pub fn calculate_hash(keys: &[&[u8]]) -> Vec { // We only hash the pdu's event ids, not the whole pdu let bytes = keys.join(&0xFF); let hash = digest::digest(&digest::SHA256, &bytes); hash.as_ref().to_owned() } -pub(crate) fn common_elements( +#[allow(clippy::impl_trait_in_params)] +pub fn common_elements( mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, ) -> Option>> { let first_iterator = iterators.next()?; @@ -123,7 +113,7 @@ pub(crate) fn common_elements( /// `CanonicalJsonObject`. /// /// `value` must serialize to an `serde_json::Value::Object`. -pub(crate) fn to_canonical_object(value: T) -> Result { +pub fn to_canonical_object(value: T) -> Result { use serde::ser::Error; match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? { @@ -132,7 +122,7 @@ pub(crate) fn to_canonical_object(value: T) -> Result, T: FromStr, E: fmt::Display>( +pub fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr, E: fmt::Display>( deserializer: D, ) -> Result { struct Visitor, E>(std::marker::PhantomData); @@ -158,7 +148,7 @@ pub(crate) fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: From /// Wrapper struct which will emit the HTML-escaped version of the contained /// string when passed to a format string. -pub(crate) struct HtmlEscape<'a>(pub(crate) &'a str); +pub struct HtmlEscape<'a>(pub &'a str); impl fmt::Display for HtmlEscape<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -196,7 +186,8 @@ impl fmt::Display for HtmlEscape<'_> { /// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string /// to include it in parenthesis after the SemVer version. A common value are /// git commit hashes. -pub(crate) fn conduwuit_version() -> String { +#[must_use] +pub fn conduwuit_version() -> String { match option_env!("CONDUWUIT_VERSION_EXTRA") { Some(extra) => { if extra.is_empty() { @@ -222,7 +213,7 @@ pub(crate) fn conduwuit_version() -> String { /// Any further elements are replaced by an ellipsis. /// /// See also [`debug_slice_truncated()`], -pub(crate) struct TruncatedDebugSlice<'a, T> { +pub struct TruncatedDebugSlice<'a, T> { inner: &'a [T], max_len: usize, } @@ -243,11 +234,12 @@ impl fmt::Debug for TruncatedDebugSlice<'_, T> { /// See [`TruncatedDebugSlice`]. Useful for `#[instrument]`: /// /// ``` -/// #[tracing::instrument(fields( -/// foos = debug_slice_truncated(foos, N) -/// ))] +/// use conduit_core::utils::debug_slice_truncated; +/// +/// #[tracing::instrument(fields(foos = debug_slice_truncated(foos, 42)))] +/// fn bar(foos: &[&str]); /// ``` -pub(crate) fn debug_slice_truncated( +pub fn debug_slice_truncated( slice: &[T], max_len: usize, ) -> tracing::field::DebugValue> { tracing::field::debug(TruncatedDebugSlice { @@ -255,3 +247,24 @@ pub(crate) fn debug_slice_truncated( max_len, }) } + +/// This is needed for opening lots of file descriptors, which tends to +/// happen more often when using RocksDB and making lots of federation +/// connections at startup. The soft limit is usually 1024, and the hard +/// limit is usually 512000; I've personally seen it hit >2000. +/// +/// * +/// * +#[cfg(unix)] +pub fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { + use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; + + let (soft_limit, hard_limit) = getrlimit(NOFILE)?; + if soft_limit < hard_limit { + setrlimit(NOFILE, hard_limit, hard_limit)?; + assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); + debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); + } + + Ok(()) +} diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml new file mode 100644 index 00000000..990a303b --- /dev/null +++ b/src/database/Cargo.toml @@ -0,0 +1,81 @@ +[package] +name = "conduit_database" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "rocksdb", + "io_uring", + "jemalloc", + "zstd_compression", + "release_max_log_level", +] + +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +sqlite = [ + "dep:rusqlite", + "dep:parking_lot", + "dep:thread_local", +] +rocksdb = [ + "dep:rust-rocksdb", +] +jemalloc = [ + "dep:tikv-jemalloc-sys", + "dep:tikv-jemalloc-ctl", + "dep:tikv-jemallocator", + "rust-rocksdb/jemalloc", +] +jemalloc_prof = [ + "tikv-jemalloc-sys/profiling", +] +io_uring = [ + "rust-rocksdb/io-uring", +] +zstd_compression = [ + "rust-rocksdb/zstd", +] + +[dependencies] +chrono.workspace = true +conduit-core.workspace = true +futures-util.workspace = true +log.workspace = true +lru-cache.workspace = true +num_cpus.workspace = true +parking_lot.optional = true +parking_lot.workspace = true +ruma.workspace = true +rusqlite.optional = true +rusqlite.workspace = true +rust-rocksdb.optional = true +rust-rocksdb.workspace = true +thread_local.optional = true +thread_local.workspace = true +tikv-jemallocator.optional = true +tikv-jemallocator.workspace = true +tikv-jemalloc-ctl.optional = true +tikv-jemalloc-ctl.workspace = true +tikv-jemalloc-sys.optional = true +tikv-jemalloc-sys.workspace = true +tokio.workspace = true +tracing.workspace = true +zstd.optional = true +zstd.workspace = true + +[lints] +workspace = true diff --git a/src/database/cork.rs b/src/database/cork.rs index db7dfac2..752260a6 100644 --- a/src/database/cork.rs +++ b/src/database/cork.rs @@ -2,14 +2,14 @@ use std::sync::Arc; use super::KeyValueDatabaseEngine; -pub(crate) struct Cork { +pub struct Cork { db: Arc, flush: bool, sync: bool, } impl Cork { - pub(crate) fn new(db: &Arc, flush: bool, sync: bool) -> Self { + pub fn new(db: &Arc, flush: bool, sync: bool) -> Self { db.cork().unwrap(); Cork { db: db.clone(), diff --git a/src/database/kvdatabase.rs b/src/database/kvdatabase.rs new file mode 100644 index 00000000..ddbb8e92 --- /dev/null +++ b/src/database/kvdatabase.rs @@ -0,0 +1,320 @@ +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + path::Path, + sync::{Arc, Mutex, RwLock}, +}; + +use conduit::{Config, Error, PduCount, Result, Server}; +use lru_cache::LruCache; +use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId}; +use tracing::debug; + +use crate::{KeyValueDatabaseEngine, KvTree}; + +pub struct KeyValueDatabase { + pub db: Arc, + + //pub globals: globals::Globals, + pub global: Arc, + pub server_signingkeys: Arc, + + pub roomid_inviteviaservers: Arc, + + //pub users: users::Users, + pub userid_password: Arc, + pub userid_displayname: Arc, + pub userid_avatarurl: Arc, + pub userid_blurhash: Arc, + pub userdeviceid_token: Arc, + pub userdeviceid_metadata: Arc, // This is also used to check if a device exists + pub userid_devicelistversion: Arc, // DevicelistVersion = u64 + pub token_userdeviceid: Arc, + + pub onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId + pub userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count + pub keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count + pub keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) + pub userid_masterkeyid: Arc, + pub userid_selfsigningkeyid: Arc, + pub userid_usersigningkeyid: Arc, + + pub userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + pub todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + pub userid_presenceid: Arc, // UserId => Count + pub presenceid_presence: Arc, // Count + UserId => Presence + + //pub uiaa: uiaa::Uiaa, + pub userdevicesessionid_uiaainfo: Arc, // User-interactive authentication + pub userdevicesessionid_uiaarequest: RwLock>, + + //pub edus: RoomEdus, + pub readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId + pub roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count + pub roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count + + //pub rooms: rooms::Rooms, + pub pduid_pdu: Arc, // PduId = ShortRoomId + Count + pub eventid_pduid: Arc, + pub roomid_pduleaves: Arc, + pub alias_roomid: Arc, + pub aliasid_alias: Arc, // AliasId = RoomId + Count + pub publicroomids: Arc, + + pub threadid_userids: Arc, // ThreadId = RoomId + Count + + pub tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount + + /// Participating servers in a room. + pub roomserverids: Arc, // RoomServerId = RoomId + ServerName + pub serverroomids: Arc, // ServerRoomId = ServerName + RoomId + + pub userroomid_joined: Arc, + pub roomuserid_joined: Arc, + pub roomid_joinedcount: Arc, + pub roomid_invitedcount: Arc, + pub roomuseroncejoinedids: Arc, + pub userroomid_invitestate: Arc, // InviteState = Vec> + pub roomuserid_invitecount: Arc, // InviteCount = Count + pub userroomid_leftstate: Arc, + pub roomuserid_leftcount: Arc, + + pub disabledroomids: Arc, // Rooms where incoming federation handling is disabled + + pub bannedroomids: Arc, // Rooms where local users are not allowed to join + + pub lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId + + pub userroomid_notificationcount: Arc, // NotifyCount = u64 + pub userroomid_highlightcount: Arc, // HightlightCount = u64 + pub roomuserid_lastnotificationread: Arc, // LastNotificationRead = u64 + + /// Remember the current state hash of a room. + pub roomid_shortstatehash: Arc, + pub roomsynctoken_shortstatehash: Arc, + /// Remember the state hash at events in the past. + pub shorteventid_shortstatehash: Arc, + pub statekey_shortstatekey: Arc, /* StateKey = EventType + StateKey, ShortStateKey = + * Count */ + pub shortstatekey_statekey: Arc, + + pub roomid_shortroomid: Arc, + + pub shorteventid_eventid: Arc, + pub eventid_shorteventid: Arc, + + pub statehash_shortstatehash: Arc, + pub shortstatehash_statediff: Arc, /* StateDiff = parent (or 0) + + * (shortstatekey+shorteventid++) + 0_u64 + + * (shortstatekey+shorteventid--) */ + + pub shorteventid_authchain: Arc, + + /// RoomId + EventId -> outlier PDU. + /// Any pdu that has passed the steps 1-8 in the incoming event + /// /federation/send/txn. + pub eventid_outlierpdu: Arc, + pub softfailedeventids: Arc, + + /// ShortEventId + ShortEventId -> (). + pub tofrom_relation: Arc, + /// RoomId + EventId -> Parent PDU EventId. + pub referencedevents: Arc, + + //pub account_data: account_data::AccountData, + pub roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type + pub roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type + + //pub media: media::Media, + pub mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType + pub url_previews: Arc, + pub mediaid_user: Arc, + //pub key_backups: key_backups::KeyBackups, + pub backupid_algorithm: Arc, // BackupId = UserId + Version(Count) + pub backupid_etag: Arc, // BackupId = UserId + Version(Count) + pub backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId + + //pub transaction_ids: transaction_ids::TransactionIds, + pub userdevicetxnid_response: Arc, /* Response can be empty (/sendToDevice) or the event id + * (/send) */ + //pub sending: sending::Sending, + pub servername_educount: Arc, // EduCount: Count of last EDU sync + pub servernameevent_data: Arc, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId + + * PduId / Id (for edus), Data = EDU content */ + pub servercurrentevent_data: Arc, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId + * / Id (for edus), Data = EDU content */ + + //pub appservice: appservice::Appservice, + pub id_appserviceregistrations: Arc, + + //pub pusher: pusher::PushData, + pub senderkey_pusher: Arc, + + pub auth_chain_cache: Mutex, Arc<[u64]>>>, + pub our_real_users_cache: RwLock>>>, + pub appservice_in_room_cache: RwLock>>, + pub lasttimelinecount_cache: Mutex>, +} + +impl KeyValueDatabase { + /// Load an existing database or create a new one. + #[allow(clippy::too_many_lines)] + pub async fn load_or_create(server: &Arc) -> Result { + let config = &server.config; + check_db_setup(config)?; + let builder = build(config)?; + Ok(Self { + db: builder.clone(), + userid_password: builder.open_tree("userid_password")?, + userid_displayname: builder.open_tree("userid_displayname")?, + userid_avatarurl: builder.open_tree("userid_avatarurl")?, + userid_blurhash: builder.open_tree("userid_blurhash")?, + userdeviceid_token: builder.open_tree("userdeviceid_token")?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, + token_userdeviceid: builder.open_tree("token_userdeviceid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, + keychangeid_userid: builder.open_tree("keychangeid_userid")?, + keyid_key: builder.open_tree("keyid_key")?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + userfilterid_filter: builder.open_tree("userfilterid_filter")?, + todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_presenceid: builder.open_tree("userid_presenceid")?, + presenceid_presence: builder.open_tree("presenceid_presence")?, + + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?, + pduid_pdu: builder.open_tree("pduid_pdu")?, + eventid_pduid: builder.open_tree("eventid_pduid")?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + + alias_roomid: builder.open_tree("alias_roomid")?, + aliasid_alias: builder.open_tree("aliasid_alias")?, + publicroomids: builder.open_tree("publicroomids")?, + + threadid_userids: builder.open_tree("threadid_userids")?, + + tokenids: builder.open_tree("tokenids")?, + + roomserverids: builder.open_tree("roomserverids")?, + serverroomids: builder.open_tree("serverroomids")?, + userroomid_joined: builder.open_tree("userroomid_joined")?, + roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, + roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + + disabledroomids: builder.open_tree("disabledroomids")?, + + bannedroomids: builder.open_tree("bannedroomids")?, + + lazyloadedids: builder.open_tree("lazyloadedids")?, + + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, + roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, + + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, + shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + + shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, + shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, + roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, + softfailedeventids: builder.open_tree("softfailedeventids")?, + + tofrom_relation: builder.open_tree("tofrom_relation")?, + referencedevents: builder.open_tree("referencedevents")?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, + roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, + mediaid_file: builder.open_tree("mediaid_file")?, + url_previews: builder.open_tree("url_previews")?, + mediaid_user: builder.open_tree("mediaid_user")?, + backupid_algorithm: builder.open_tree("backupid_algorithm")?, + backupid_etag: builder.open_tree("backupid_etag")?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, + servername_educount: builder.open_tree("servername_educount")?, + servernameevent_data: builder.open_tree("servernameevent_data")?, + servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, + senderkey_pusher: builder.open_tree("senderkey_pusher")?, + global: builder.open_tree("global")?, + server_signingkeys: builder.open_tree("server_signingkeys")?, + + roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?, + + auth_chain_cache: Mutex::new(LruCache::new( + (f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, + )), + our_real_users_cache: RwLock::new(HashMap::new()), + appservice_in_room_cache: RwLock::new(HashMap::new()), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }) + } +} + +fn build(config: &Config) -> Result> { + match &*config.database_backend { + "sqlite" => { + debug!("Got sqlite database backend"); + #[cfg(not(feature = "sqlite"))] + return Err(Error::bad_config("Database backend not found.")); + #[cfg(feature = "sqlite")] + Ok(Arc::new(Arc::::open(config)?)) + }, + "rocksdb" => { + debug!("Got rocksdb database backend"); + #[cfg(not(feature = "rocksdb"))] + return Err(Error::bad_config("Database backend not found.")); + #[cfg(feature = "rocksdb")] + Ok(Arc::new(Arc::::open(config)?)) + }, + _ => Err(Error::bad_config( + "Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.", + )), + } +} + +fn check_db_setup(config: &Config) -> Result<()> { + let path = Path::new(&config.database_path); + + let sqlite_exists = path.join("conduit.db").exists(); + let rocksdb_exists = path.join("IDENTITY").exists(); + + if sqlite_exists && rocksdb_exists { + return Err(Error::bad_config("Multiple databases at database_path detected.")); + } + + if sqlite_exists && config.database_backend != "sqlite" { + return Err(Error::bad_config( + "Found sqlite at database_path, but is not specified in config.", + )); + } + + if rocksdb_exists && config.database_backend != "rocksdb" { + return Err(Error::bad_config( + "Found rocksdb at database_path, but is not specified in config.", + )); + } + + Ok(()) +} diff --git a/src/database/kvengine.rs b/src/database/kvengine.rs index c67a7e98..1b27c571 100644 --- a/src/database/kvengine.rs +++ b/src/database/kvengine.rs @@ -3,7 +3,7 @@ use std::{error::Error, sync::Arc}; use super::{Config, KvTree}; use crate::Result; -pub(crate) trait KeyValueDatabaseEngine: Send + Sync { +pub trait KeyValueDatabaseEngine: Send + Sync { fn open(config: &Config) -> Result where Self: Sized; diff --git a/src/database/kvtree.rs b/src/database/kvtree.rs index 52e5b146..009e45d5 100644 --- a/src/database/kvtree.rs +++ b/src/database/kvtree.rs @@ -2,7 +2,7 @@ use std::{future::Future, pin::Pin}; use crate::Result; -pub(crate) trait KvTree: Send + Sync { +pub trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; #[allow(dead_code)] diff --git a/src/database/mod.rs b/src/database/mod.rs index 4fbc3f97..506f7ac1 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,573 +1,23 @@ -mod cork; -mod key_value; +pub mod cork; +mod kvdatabase; mod kvengine; mod kvtree; -mod migrations; #[cfg(feature = "rocksdb")] -mod rocksdb; +pub(crate) mod rocksdb; #[cfg(feature = "sqlite")] -mod sqlite; +pub(crate) mod sqlite; #[cfg(any(feature = "sqlite", feature = "rocksdb"))] pub(crate) mod watchers; -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - fs::{self}, - path::Path, - sync::{Arc, Mutex, RwLock}, - time::Duration, -}; - -pub(crate) use cork::Cork; -pub(crate) use kvengine::KeyValueDatabaseEngine; -pub(crate) use kvtree::KvTree; -use lru_cache::LruCache; -use ruma::{ - events::{ - push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, GlobalAccountDataEvent, - GlobalAccountDataEventType, - }, - push::Ruleset, - CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId, -}; -use serde::Deserialize; -#[cfg(unix)] -use tokio::signal::unix::{signal, SignalKind}; -use tokio::time::{interval, Instant}; -use tracing::{debug, error, warn}; - -use crate::{ - database::migrations::migrations, service::rooms::timeline::PduCount, services, Config, Error, - LogLevelReloadHandles, Result, Services, SERVICES, -}; - -pub(crate) struct KeyValueDatabase { - db: Arc, - - //pub(crate) globals: globals::Globals, - pub(crate) global: Arc, - pub(crate) server_signingkeys: Arc, - - pub(crate) roomid_inviteviaservers: Arc, - - //pub(crate) users: users::Users, - pub(crate) userid_password: Arc, - pub(crate) userid_displayname: Arc, - pub(crate) userid_avatarurl: Arc, - pub(crate) userid_blurhash: Arc, - pub(crate) userdeviceid_token: Arc, - pub(crate) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(crate) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(crate) token_userdeviceid: Arc, - - pub(crate) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(crate) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(crate) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(crate) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(crate) userid_masterkeyid: Arc, - pub(crate) userid_selfsigningkeyid: Arc, - pub(crate) userid_usersigningkeyid: Arc, - - pub(crate) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - pub(crate) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count - pub(crate) userid_presenceid: Arc, // UserId => Count - pub(crate) presenceid_presence: Arc, // Count + UserId => Presence - - //pub(crate) uiaa: uiaa::Uiaa, - pub(crate) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(crate) userdevicesessionid_uiaarequest: - RwLock>, - - //pub(crate) edus: RoomEdus, - pub(crate) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId - pub(crate) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count - pub(crate) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count - - //pub(crate) rooms: rooms::Rooms, - pub(crate) pduid_pdu: Arc, // PduId = ShortRoomId + Count - pub(crate) eventid_pduid: Arc, - pub(crate) roomid_pduleaves: Arc, - pub(crate) alias_roomid: Arc, - pub(crate) aliasid_alias: Arc, // AliasId = RoomId + Count - pub(crate) publicroomids: Arc, - - pub(crate) threadid_userids: Arc, // ThreadId = RoomId + Count - - pub(crate) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount - - /// Participating servers in a room. - pub(crate) roomserverids: Arc, // RoomServerId = RoomId + ServerName - pub(crate) serverroomids: Arc, // ServerRoomId = ServerName + RoomId - - pub(crate) userroomid_joined: Arc, - pub(crate) roomuserid_joined: Arc, - pub(crate) roomid_joinedcount: Arc, - pub(crate) roomid_invitedcount: Arc, - pub(crate) roomuseroncejoinedids: Arc, - pub(crate) userroomid_invitestate: Arc, // InviteState = Vec> - pub(crate) roomuserid_invitecount: Arc, // InviteCount = Count - pub(crate) userroomid_leftstate: Arc, - pub(crate) roomuserid_leftcount: Arc, - - pub(crate) disabledroomids: Arc, // Rooms where incoming federation handling is disabled - - pub(crate) bannedroomids: Arc, // Rooms where local users are not allowed to join - - pub(crate) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId - - pub(crate) userroomid_notificationcount: Arc, // NotifyCount = u64 - pub(crate) userroomid_highlightcount: Arc, // HightlightCount = u64 - pub(crate) roomuserid_lastnotificationread: Arc, // LastNotificationRead = u64 - - /// Remember the current state hash of a room. - pub(crate) roomid_shortstatehash: Arc, - pub(crate) roomsynctoken_shortstatehash: Arc, - /// Remember the state hash at events in the past. - pub(crate) shorteventid_shortstatehash: Arc, - pub(crate) statekey_shortstatekey: Arc, /* StateKey = EventType + StateKey, ShortStateKey = - * Count */ - pub(crate) shortstatekey_statekey: Arc, - - pub(crate) roomid_shortroomid: Arc, - - pub(crate) shorteventid_eventid: Arc, - pub(crate) eventid_shorteventid: Arc, - - pub(crate) statehash_shortstatehash: Arc, - pub(crate) shortstatehash_statediff: Arc, /* StateDiff = parent (or 0) + - * (shortstatekey+shorteventid++) + 0_u64 + - * (shortstatekey+shorteventid--) */ - - pub(crate) shorteventid_authchain: Arc, - - /// RoomId + EventId -> outlier PDU. - /// Any pdu that has passed the steps 1-8 in the incoming event - /// /federation/send/txn. - pub(crate) eventid_outlierpdu: Arc, - pub(crate) softfailedeventids: Arc, - - /// ShortEventId + ShortEventId -> (). - pub(crate) tofrom_relation: Arc, - /// RoomId + EventId -> Parent PDU EventId. - pub(crate) referencedevents: Arc, - - //pub(crate) account_data: account_data::AccountData, - pub(crate) roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type - pub(crate) roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type - - //pub(crate) media: media::Media, - pub(crate) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType - pub(crate) url_previews: Arc, - pub(crate) mediaid_user: Arc, - //pub(crate) key_backups: key_backups::KeyBackups, - pub(crate) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) - pub(crate) backupid_etag: Arc, // BackupId = UserId + Version(Count) - pub(crate) backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId - - //pub(crate) transaction_ids: transaction_ids::TransactionIds, - pub(crate) userdevicetxnid_response: Arc, /* Response can be empty (/sendToDevice) or the event id - * (/send) */ - //pub(crate) sending: sending::Sending, - pub(crate) servername_educount: Arc, // EduCount: Count of last EDU sync - pub(crate) servernameevent_data: Arc, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId + - * PduId / Id (for edus), Data = EDU content */ - pub(crate) servercurrentevent_data: Arc, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId - * / Id (for edus), Data = EDU content */ - - //pub(crate) appservice: appservice::Appservice, - pub(crate) id_appserviceregistrations: Arc, - - //pub(crate) pusher: pusher::PushData, - pub(crate) senderkey_pusher: Arc, - - pub(crate) auth_chain_cache: Mutex, Arc<[u64]>>>, - pub(crate) our_real_users_cache: RwLock>>>, - pub(crate) appservice_in_room_cache: RwLock>>, - pub(crate) lasttimelinecount_cache: Mutex>, -} - -#[derive(Deserialize)] -struct CheckForUpdatesResponseEntry { - id: u64, - date: String, - message: String, -} -#[derive(Deserialize)] -struct CheckForUpdatesResponse { - updates: Vec, -} - -impl KeyValueDatabase { - /// Load an existing database or create a new one. - #[allow(clippy::too_many_lines)] - pub(crate) async fn load_or_create(config: Config, tracing_reload_handler: LogLevelReloadHandles) -> Result<()> { - Self::check_db_setup(&config)?; - - if !Path::new(&config.database_path).exists() { - debug!("Database path does not exist, assuming this is a new setup and creating it"); - fs::create_dir_all(&config.database_path).map_err(|e| { - error!("Failed to create database path: {e}"); - Error::bad_config( - "Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \ - create the database folder yourself or allow conduwuit the permissions to create directories and \ - files.", - ) - })?; - } - - let builder: Arc = match &*config.database_backend { - "sqlite" => { - debug!("Got sqlite database backend"); - #[cfg(not(feature = "sqlite"))] - return Err(Error::bad_config("Database backend not found.")); - #[cfg(feature = "sqlite")] - Arc::new(Arc::::open(&config)?) - }, - "rocksdb" => { - debug!("Got rocksdb database backend"); - #[cfg(not(feature = "rocksdb"))] - return Err(Error::bad_config("Database backend not found.")); - #[cfg(feature = "rocksdb")] - Arc::new(Arc::::open(&config)?) - }, - _ => { - return Err(Error::bad_config( - "Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.", - )); - }, - }; - - let db_raw = Box::new(Self { - db: builder.clone(), - userid_password: builder.open_tree("userid_password")?, - userid_displayname: builder.open_tree("userid_displayname")?, - userid_avatarurl: builder.open_tree("userid_avatarurl")?, - userid_blurhash: builder.open_tree("userid_blurhash")?, - userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, - token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: builder.open_tree("keychangeid_userid")?, - keyid_key: builder.open_tree("keyid_key")?, - userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, - userfilterid_filter: builder.open_tree("userfilterid_filter")?, - todeviceid_events: builder.open_tree("todeviceid_events")?, - userid_presenceid: builder.open_tree("userid_presenceid")?, - presenceid_presence: builder.open_tree("presenceid_presence")?, - - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt - roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?, - pduid_pdu: builder.open_tree("pduid_pdu")?, - eventid_pduid: builder.open_tree("eventid_pduid")?, - roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, - - alias_roomid: builder.open_tree("alias_roomid")?, - aliasid_alias: builder.open_tree("aliasid_alias")?, - publicroomids: builder.open_tree("publicroomids")?, - - threadid_userids: builder.open_tree("threadid_userids")?, - - tokenids: builder.open_tree("tokenids")?, - - roomserverids: builder.open_tree("roomserverids")?, - serverroomids: builder.open_tree("serverroomids")?, - userroomid_joined: builder.open_tree("userroomid_joined")?, - roomuserid_joined: builder.open_tree("roomuserid_joined")?, - roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, - roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, - - disabledroomids: builder.open_tree("disabledroomids")?, - - bannedroomids: builder.open_tree("bannedroomids")?, - - lazyloadedids: builder.open_tree("lazyloadedids")?, - - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, - - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, - - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, - - roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, - eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, - shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, - - eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, - softfailedeventids: builder.open_tree("softfailedeventids")?, - - tofrom_relation: builder.open_tree("tofrom_relation")?, - referencedevents: builder.open_tree("referencedevents")?, - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - mediaid_file: builder.open_tree("mediaid_file")?, - url_previews: builder.open_tree("url_previews")?, - mediaid_user: builder.open_tree("mediaid_user")?, - backupid_algorithm: builder.open_tree("backupid_algorithm")?, - backupid_etag: builder.open_tree("backupid_etag")?, - backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - servername_educount: builder.open_tree("servername_educount")?, - servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - senderkey_pusher: builder.open_tree("senderkey_pusher")?, - global: builder.open_tree("global")?, - server_signingkeys: builder.open_tree("server_signingkeys")?, - - roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?, - - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - auth_chain_cache: Mutex::new(LruCache::new( - (f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - our_real_users_cache: RwLock::new(HashMap::new()), - appservice_in_room_cache: RwLock::new(HashMap::new()), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }); - - let db = Box::leak(db_raw); - - let services_raw = Box::new(Services::build(db, &config, tracing_reload_handler)?); - - // This is the first and only time we initialize the SERVICE static - *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); - - migrations(db, &config).await?; - - services().admin.start_handler(); - - // Set emergency access for the conduit user - match set_emergency_access() { - Ok(pwd_set) => { - if pwd_set { - warn!( - "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ - account recovery!" - ); - services() - .admin - .send_message(RoomMessageEventContent::text_plain( - "The Conduit account emergency password is set! Please unset it as soon as you finish \ - admin account recovery!", - )) - .await; - } - }, - Err(e) => { - error!("Could not set the configured emergency password for the conduit user: {}", e); - }, - }; - - services().sending.start_handler(); - - if config.allow_local_presence { - services().presence.start_handler(); - } - - Self::start_cleanup_task().await; - if services().globals.allow_check_for_updates() { - Self::start_check_for_updates_task().await; - } - - Ok(()) - } - - fn check_db_setup(config: &Config) -> Result<()> { - let path = Path::new(&config.database_path); - - let sqlite_exists = path.join("conduit.db").exists(); - let rocksdb_exists = path.join("IDENTITY").exists(); - - if sqlite_exists && rocksdb_exists { - return Err(Error::bad_config("Multiple databases at database_path detected.")); - } - - if sqlite_exists && config.database_backend != "sqlite" { - return Err(Error::bad_config( - "Found sqlite at database_path, but is not specified in config.", - )); - } - - if rocksdb_exists && config.database_backend != "rocksdb" { - return Err(Error::bad_config( - "Found rocksdb at database_path, but is not specified in config.", - )); - } - - Ok(()) - } - - #[tracing::instrument] - async fn start_check_for_updates_task() { - let timer_interval = Duration::from_secs(7200); // 2 hours - - tokio::spawn(async move { - let mut i = interval(timer_interval); - - loop { - tokio::select! { - _ = i.tick() => { - debug!(target: "start_check_for_updates_task", "Timer ticked"); - }, - } - - _ = Self::try_handle_updates().await; - } - }); - } - - async fn try_handle_updates() -> Result<()> { - let response = services() - .globals - .client - .default - .get("https://pupbrain.dev/check-for-updates/stable") - .send() - .await?; - - let response = serde_json::from_str::(&response.text().await?).map_err(|e| { - error!("Bad check for updates response: {e}"); - Error::BadServerResponse("Bad version check response") - })?; - - let mut last_update_id = services().globals.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > services().globals.last_check_for_updates_id()? { - error!("{}", update.message); - services() - .admin - .send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))) - .await; - } - } - services() - .globals - .update_check_for_updates_id(last_update_id)?; - - Ok(()) - } - - #[tracing::instrument] - async fn start_cleanup_task() { - let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval)); - - tokio::spawn(async move { - let mut i = interval(timer_interval); - - #[cfg(unix)] - let mut hangup = signal(SignalKind::hangup()).expect("Failed to register SIGHUP signal receiver"); - #[cfg(unix)] - let mut ctrl_c = signal(SignalKind::interrupt()).expect("Failed to register SIGINT signal receiver"); - #[cfg(unix)] - let mut terminate = signal(SignalKind::terminate()).expect("Failed to register SIGTERM signal receiver"); - - loop { - #[cfg(unix)] - tokio::select! { - _ = i.tick() => { - debug!(target: "database-cleanup", "Timer ticked"); - } - _ = hangup.recv() => { - debug!(target: "database-cleanup","Received SIGHUP"); - } - _ = ctrl_c.recv() => { - debug!(target: "database-cleanup", "Received Ctrl+C"); - } - _ = terminate.recv() => { - debug!(target: "database-cleanup","Received SIGTERM"); - } - } - - #[cfg(not(unix))] - { - i.tick().await; - debug!(target: "database-cleanup", "Timer ticked") - } - - Self::perform_cleanup(); - } - }); - } - - fn perform_cleanup() { - if !services().globals.config.rocksdb_periodic_cleanup { - return; - } - - let start = Instant::now(); - if let Err(e) = services().globals.cleanup() { - error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e); - } else { - debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed()); - } - } - - #[allow(dead_code)] - fn flush(&self) -> Result<()> { - let start = std::time::Instant::now(); - - let res = self.db.flush(); - - debug!("flush: took {:?}", start.elapsed()); - - res - } -} - -/// Sets the emergency password and push rules for the @conduit account in case -/// emergency password is set -fn set_emergency_access() -> Result { - let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is a valid UserId"); - - services() - .users - .set_password(&conduit_user, services().globals.emergency_password().as_deref())?; - - let (ruleset, res) = match services().globals.emergency_password() { - Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), - None => (Ruleset::new(), Ok(false)), - }; - - services().account_data.update( - None, - &conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; - - res -} +extern crate conduit_core as conduit; +pub(crate) use conduit::{Config, Result}; +pub use cork::Cork; +pub use kvdatabase::KeyValueDatabase; +pub use kvengine::KeyValueDatabaseEngine; +pub use kvtree::KvTree; + +conduit::mod_ctor! {} +conduit::mod_dtor! {} diff --git a/src/database/rocksdb/kvtree.rs b/src/database/rocksdb/kvtree.rs index 4761624b..02a0f3bf 100644 --- a/src/database/rocksdb/kvtree.rs +++ b/src/database/rocksdb/kvtree.rs @@ -1,9 +1,8 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use rust_rocksdb::WriteBatchWithTransaction; +use conduit::{utils, Result}; -use super::{watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree}; -use crate::{utils, Result}; +use super::{rust_rocksdb::WriteBatchWithTransaction, watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree}; pub(crate) struct RocksDbEngineTree<'a> { pub(crate) db: Arc, diff --git a/src/database/rocksdb/mod.rs b/src/database/rocksdb/mod.rs index 4ef8f9a7..3f39c292 100644 --- a/src/database/rocksdb/mod.rs +++ b/src/database/rocksdb/mod.rs @@ -1,3 +1,8 @@ +// no_link to prevent double-inclusion of librocksdb.a here and with +// libconduit_core.so +#[no_link] +extern crate rust_rocksdb; + use std::{ collections::HashMap, sync::{atomic::AtomicU32, Arc}, @@ -6,12 +11,12 @@ use std::{ use chrono::{DateTime, Utc}; use rust_rocksdb::{ backup::{BackupEngine, BackupEngineOptions}, + perf::get_memory_usage_stats, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode as Db, Env, MultiThreaded, Options, }; use tracing::{debug, error, info, warn}; -use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; -use crate::Result; +use crate::{watchers::Watchers, Config, KeyValueDatabaseEngine, KvTree, Result}; pub(crate) mod kvtree; pub(crate) mod opts; @@ -22,13 +27,13 @@ use opts::{cf_options, db_options}; use super::watchers; pub(crate) struct Engine { - rocks: Db, + config: Config, row_cache: Cache, col_cache: HashMap, - old_cfs: Vec, opts: Options, env: Env, - config: Config, + old_cfs: Vec, + rocks: Db, corks: AtomicU32, } @@ -79,13 +84,13 @@ impl KeyValueDatabaseEngine for Arc { load_time.elapsed() ); Ok(Arc::new(Engine { - rocks: db, + config: config.clone(), row_cache, col_cache, - old_cfs: cfs, opts: db_opts, env: db_env, - config: config.clone(), + old_cfs: cfs, + rocks: db, corks: AtomicU32::new(0), })) } @@ -135,7 +140,7 @@ impl KeyValueDatabaseEngine for Arc { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn memory_usage(&self) -> Result { let mut res = String::new(); - let stats = rust_rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?; + let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?; _ = std::fmt::write( &mut res, format_args!( @@ -258,3 +263,20 @@ impl KeyValueDatabaseEngine for Arc { #[allow(dead_code)] fn clear_caches(&self) {} } + +impl Drop for Engine { + fn drop(&mut self) { + debug!("Waiting for background tasks to finish..."); + const BLOCKING: bool = true; + self.rocks.cancel_all_background_work(BLOCKING); + + debug!("Shutting down background threads"); + self.env.set_high_priority_background_threads(0); + self.env.set_low_priority_background_threads(0); + self.env.set_bottom_priority_background_threads(0); + self.env.set_background_threads(0); + + debug!("Joining background threads..."); + self.env.join_all_threads(); + } +} diff --git a/src/database/rocksdb/opts.rs b/src/database/rocksdb/opts.rs index 78b6db95..b417b126 100644 --- a/src/database/rocksdb/opts.rs +++ b/src/database/rocksdb/opts.rs @@ -1,14 +1,14 @@ #![allow(dead_code)] - use std::collections::HashMap; -use rust_rocksdb::{ - BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType, DBRecoveryMode, Env, LogLevel, Options, - UniversalCompactOptions, UniversalCompactionStopStyle, +use super::{ + rust_rocksdb::{ + BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType, DBRecoveryMode, Env, LogLevel, Options, + UniversalCompactOptions, UniversalCompactionStopStyle, + }, + Config, }; -use super::Config; - /// Create database-wide options suitable for opening the database. This also /// sets our default column options in case of opening a column with the same /// resulting value. Note that we require special per-column options on some diff --git a/src/database/sqlite/mod.rs b/src/database/sqlite/mod.rs index c43d2fe0..4e8c079e 100644 --- a/src/database/sqlite/mod.rs +++ b/src/database/sqlite/mod.rs @@ -6,13 +6,13 @@ use std::{ sync::Arc, }; +use conduit::{Config, Result}; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use thread_local::ThreadLocal; use tracing::debug; use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; -use crate::{database::Config, Result}; thread_local! { static READ_CONNECTION: RefCell> = const { RefCell::new(None) }; @@ -224,7 +224,7 @@ impl KvTree for SqliteTable { guard.execute("BEGIN", [])?; for key in iter { let old = self.get_with_guard(&guard, &key)?; - let new = crate::utils::increment(old.as_deref()); + let new = conduit::utils::increment(old.as_deref()); self.insert_with_guard(&guard, &key, &new)?; } guard.execute("COMMIT", [])?; @@ -307,7 +307,7 @@ impl KvTree for SqliteTable { let old = self.get_with_guard(&guard, key)?; - let new = crate::utils::increment(old.as_deref()); + let new = conduit::utils::increment(old.as_deref()); self.insert_with_guard(&guard, key, &new)?; diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 7f40038b..00000000 --- a/src/main.rs +++ /dev/null @@ -1,503 +0,0 @@ -#[cfg(unix)] -use std::fs::Permissions; // not unix specific, just only for UNIX sockets stuff and *nix container checks -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for UNIX sockets stuff and *nix - * container checks */ -// Not async due to services() being used in many closures, and async closures -// are not stable as of writing This is the case for every other occurence of -// sync Mutex/RwLock, except for database related ones -use std::sync::{Arc, RwLock}; -use std::{io, net::SocketAddr, time::Duration}; - -use api::ruma_wrapper::{Ruma, RumaResponse}; -use axum::Router; -use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -#[cfg(feature = "axum_dual_protocol")] -use axum_server_dual_protocol::ServerExt; -use config::Config; -use database::KeyValueDatabase; -use service::{pdu::PduEvent, Services}; -use tokio::{ - signal, - sync::oneshot::{self, Sender}, - task::JoinSet, -}; -use tracing::{debug, error, info, warn}; -use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry}; -use utils::{ - clap, - error::{Error, Result}, -}; - -pub(crate) mod alloc; -mod api; -mod config; -mod database; -mod router; -mod service; -mod utils; - -pub(crate) static SERVICES: RwLock>> = RwLock::new(None); - -#[must_use] -pub(crate) fn services() -> &'static Services<'static> { - SERVICES - .read() - .unwrap() - .expect("SERVICES should be initialized when this is called") -} - -pub(crate) struct Server { - config: Config, - - runtime: tokio::runtime::Runtime, - - tracing_reload_handle: LogLevelReloadHandles, - - #[cfg(feature = "sentry_telemetry")] - _sentry_guard: Option, - - _tracing_flame_guard: TracingFlameGuard, -} - -fn main() -> Result<(), Error> { - let args = clap::parse(); - let conduwuit: Server = init(args)?; - - conduwuit - .runtime - .block_on(async { async_main(&conduwuit).await }) -} - -async fn async_main(server: &Server) -> Result<(), Error> { - if let Err(error) = start(server).await { - error!("Critical error starting server: {error}"); - return Err(Error::Err(format!("{error}"))); - } - - if let Err(error) = run(server).await { - error!("Critical error running server: {error}"); - return Err(Error::Err(format!("{error}"))); - } - - if let Err(error) = stop(server).await { - error!("Critical error stopping server: {error}"); - return Err(Error::Err(format!("{error}"))); - } - - Ok(()) -} - -async fn run(server: &Server) -> io::Result<()> { - let app = router::build(server).await?; - let (tx, rx) = oneshot::channel::<()>(); - let handle = ServerHandle::new(); - tokio::spawn(shutdown(handle.clone(), tx)); - - #[cfg(unix)] - if server.config.unix_socket_path.is_some() { - return run_unix_socket_server(server, app, rx).await; - } - - let addrs = server.config.get_bind_addrs(); - if server.config.tls.is_some() { - return run_tls_server(server, app, handle, addrs).await; - } - - let mut join_set = JoinSet::new(); - for addr in &addrs { - join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); - } - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - info!("Listening on {:?}", addrs); - join_set.join_next().await; - - Ok(()) -} - -async fn run_tls_server( - server: &Server, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, -) -> io::Result<()> { - let tls = server.config.tls.as_ref().unwrap(); - - debug!( - "Using direct TLS. Certificate path {} and certificate private key path {}", - &tls.certs, &tls.key - ); - info!( - "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS." - ); - let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - - if cfg!(feature = "axum_dual_protocol") { - info!( - "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \ - take affect if `dual_protocol` is enabled in `[global.tls]`" - ); - } - - let mut join_set = JoinSet::new(); - - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - #[cfg(feature = "axum_dual_protocol")] - for addr in &addrs { - join_set.spawn( - axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) - .set_upgrade(false) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } else { - for addr in &addrs { - join_set.spawn( - bind_rustls(*addr, conf.clone()) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - warn!( - "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)", - addrs, &tls.certs - ); - } else { - info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); - } - - join_set.join_next().await; - - Ok(()) -} - -#[cfg(unix)] -#[allow(unused_variables)] -async fn run_unix_socket_server( - server: &Server, app: axum::routing::IntoMakeService, rx: oneshot::Receiver<()>, -) -> io::Result<()> { - let path = server.config.unix_socket_path.as_ref().unwrap(); - - if path.exists() { - warn!( - "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", - path.display() - ); - tokio::fs::remove_file(&path).await?; - } - - tokio::fs::create_dir_all(path.parent().unwrap()).await?; - - let socket_perms = server.config.unix_socket_perms.to_string(); - let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); - tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms)) - .await - .unwrap(); - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - let bind = tokio::net::UnixListener::bind(path)?; - info!("Listening at {:?}", path); - - Ok(()) -} - -async fn shutdown(handle: ServerHandle, tx: Sender<()>) -> Result<()> { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install SIGTERM handler") - .recv() - .await; - }; - - let sig: &str; - #[cfg(unix)] - tokio::select! { - () = ctrl_c => { sig = "Ctrl+C"; }, - () = terminate => { sig = "SIGTERM"; }, - } - #[cfg(not(unix))] - tokio::select! { - _ = ctrl_c => { sig = "Ctrl+C"; }, - } - - warn!("Received {}, shutting down...", sig); - handle.graceful_shutdown(Some(Duration::from_secs(180))); - services().globals.shutdown(); - - #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); - - tx.send(()).expect( - "failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \ - system may not be in an okay/ideal state.)", - ); - - Ok(()) -} - -async fn stop(_server: &Server) -> io::Result<()> { - info!("Shutdown complete."); - - Ok(()) -} - -/// Async initializations -async fn start(server: &Server) -> Result<(), Error> { - KeyValueDatabase::load_or_create(server.config.clone(), server.tracing_reload_handle.clone()).await?; - - Ok(()) -} - -/// Non-async initializations -fn init(args: clap::Args) -> Result { - let config = Config::new(args.config)?; - - #[cfg(feature = "sentry_telemetry")] - let sentry_guard = if config.sentry { - Some(init_sentry(&config)) - } else { - None - }; - - let (tracing_reload_handle, tracing_flame_guard) = init_tracing(&config); - - config.check()?; - - info!( - server_name = ?config.server_name, - database_path = ?config.database_path, - log_levels = ?config.log, - "{}", - utils::conduwuit_version(), - ); - - #[cfg(unix)] - maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); - - Ok(Server { - config, - - runtime: tokio::runtime::Builder::new_multi_thread() - .enable_io() - .enable_time() - .thread_name("conduwuit:worker") - .worker_threads(std::cmp::max(2, num_cpus::get())) - .build() - .unwrap(), - - tracing_reload_handle, - - #[cfg(feature = "sentry_telemetry")] - _sentry_guard: sentry_guard, - _tracing_flame_guard: tracing_flame_guard, - }) -} - -#[cfg(feature = "sentry_telemetry")] -fn init_sentry(config: &Config) -> sentry::ClientInitGuard { - sentry::init(( - config - .sentry_endpoint - .as_ref() - .expect("init_sentry should only be called if sentry is enabled and this is not None") - .as_str(), - sentry::ClientOptions { - release: sentry::release_name!(), - traces_sample_rate: config.sentry_traces_sample_rate, - server_name: if config.sentry_send_server_name { - Some(config.server_name.to_string().into()) - } else { - None - }, - ..Default::default() - }, - )) -} - -/// We need to store a reload::Handle value, but can't name it's type explicitly -/// because the S type parameter depends on the subscriber's previous layers. In -/// our case, this includes unnameable 'impl Trait' types. -/// -/// This is fixed[1] in the unreleased tracing-subscriber from the master -/// branch, which removes the S parameter. Unfortunately can't use it without -/// pulling in a version of tracing that's incompatible with the rest of our -/// deps. -/// -/// To work around this, we define an trait without the S paramter that forwards -/// to the reload::Handle::reload method, and then store the handle as a trait -/// object. -/// -/// [1]: -trait ReloadHandle { - fn reload(&self, new_value: L) -> Result<(), reload::Error>; -} - -impl ReloadHandle for reload::Handle { - fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) } -} - -struct LogLevelReloadHandlesInner { - handles: Vec + Send + Sync>>, -} - -/// Wrapper to allow reloading the filter on several several -/// [`tracing_subscriber::reload::Handle`]s at once, with the same value. -#[derive(Clone)] -struct LogLevelReloadHandles { - inner: Arc, -} - -impl LogLevelReloadHandles { - fn new(handles: Vec + Send + Sync>>) -> LogLevelReloadHandles { - LogLevelReloadHandles { - inner: Arc::new(LogLevelReloadHandlesInner { - handles, - }), - } - } - - fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> { - for handle in &self.inner.handles { - handle.reload(new_value.clone())?; - } - Ok(()) - } -} - -#[cfg(feature = "perf_measurements")] -type TracingFlameGuard = Option>>; -#[cfg(not(feature = "perf_measurements"))] -type TracingFlameGuard = (); - -// clippy thinks the filter_layer clones are redundant if the next usage is -// behind a disabled feature. -#[allow(clippy::redundant_clone)] -fn init_tracing(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard) { - let registry = Registry::default(); - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new("warn").unwrap() - }, - }; - - let mut reload_handles = Vec:: + Send + Sync>>::new(); - let subscriber = registry; - - #[cfg(feature = "tokio_console")] - let subscriber = { - let console_layer = console_subscriber::spawn(); - subscriber.with(console_layer) - }; - - let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(fmt_reload_handle)); - let subscriber = subscriber.with(fmt_layer.with_filter(fmt_reload_filter)); - - #[cfg(feature = "sentry_telemetry")] - let subscriber = { - let sentry_layer = sentry_tracing::layer(); - let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(sentry_reload_handle)); - subscriber.with(sentry_layer.with_filter(sentry_reload_filter)) - }; - - #[cfg(feature = "perf_measurements")] - let (subscriber, flame_guard) = { - let (flame_layer, flame_guard) = if config.tracing_flame { - let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) { - Ok(flame_filter) => flame_filter, - Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"), - }; - - let (flame_layer, flame_guard) = - match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) { - Ok(ok) => ok, - Err(e) => { - panic!("failed to initialize tracing-flame: {e}"); - }, - }; - let flame_layer = flame_layer - .with_empty_samples(false) - .with_filter(flame_filter); - (Some(flame_layer), Some(flame_guard)) - } else { - (None, None) - }; - - let jaeger_layer = if config.allow_jaeger { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); - let tracer = opentelemetry_jaeger::new_agent_pipeline() - .with_auto_split_batch(true) - .with_service_name("conduwuit") - .install_batch(opentelemetry_sdk::runtime::Tokio) - .unwrap(); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - - let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer); - reload_handles.push(Box::new(jaeger_reload_handle)); - Some(telemetry.with_filter(jaeger_reload_filter)) - } else { - None - }; - - let subscriber = subscriber.with(flame_layer).with(jaeger_layer); - (subscriber, flame_guard) - }; - - #[cfg(not(feature = "perf_measurements"))] - #[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))] - let flame_guard = (); - - tracing::subscriber::set_global_default(subscriber).unwrap(); - - #[cfg(all(feature = "tokio_console", feature = "release_max_log_level"))] - error!( - "'tokio_console' feature and 'release_max_log_level' feature are incompatible, because console-subscriber \ - needs access to trace-level events. 'release_max_log_level' must be disabled to use tokio-console." - ); - - (LogLevelReloadHandles::new(reload_handles), flame_guard) -} - -/// This is needed for opening lots of file descriptors, which tends to -/// happen more often when using RocksDB and making lots of federation -/// connections at startup. The soft limit is usually 1024, and the hard -/// limit is usually 512000; I've personally seen it hit >2000. -/// -/// * -/// * -#[cfg(unix)] -fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { - use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; - - let (soft_limit, hard_limit) = getrlimit(NOFILE)?; - if soft_limit < hard_limit { - setrlimit(NOFILE, hard_limit, hard_limit)?; - assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); - debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); - } - - Ok(()) -} diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml new file mode 100644 index 00000000..dae9d14c --- /dev/null +++ b/src/router/Cargo.toml @@ -0,0 +1,85 @@ +[package] +name = "conduit_router" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "systemd", + "sentry_telemetry", + "gzip_compression", + "zstd_compression", + "brotli_compression", + "release_max_log_level", +] + +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +sentry_telemetry = [ + "dep:sentry", + "dep:sentry-tracing", + "dep:sentry-tower", +] +zstd_compression = [ + "tower-http/compression-zstd", +] +gzip_compression = [ + "tower-http/compression-gzip", +] +brotli_compression = [ + "tower-http/compression-br", +] +systemd = [ + "dep:sd-notify", +] +axum_dual_protocol = [ + "dep:axum-server-dual-protocol" +] + +[dependencies] +axum-server-dual-protocol.optional = true +axum-server-dual-protocol.workspace = true +axum-server.workspace = true +axum.workspace = true +conduit-admin.workspace = true +conduit-api.workspace = true +conduit-core.workspace = true +conduit-database.workspace = true +conduit-service.workspace = true +log.workspace = true +tokio.workspace = true +tower.workspace = true +tracing.workspace = true +bytes.workspace = true +clap.workspace = true +http-body-util.workspace = true +http.workspace = true +regex.workspace = true +ruma.workspace = true +sentry.optional = true +sentry-tower.optional = true +sentry-tower.workspace = true +sentry-tracing.optional = true +sentry-tracing.workspace = true +sentry.workspace = true +serde_json.workspace = true +tower-http.workspace = true + +[target.'cfg(unix)'.dependencies] +sd-notify.workspace = true +sd-notify.optional = true + +[lints] +workspace = true diff --git a/src/router/layers.rs b/src/router/layers.rs new file mode 100644 index 00000000..a6e30e98 --- /dev/null +++ b/src/router/layers.rs @@ -0,0 +1,190 @@ +use std::{any::Any, io, sync::Arc, time::Duration}; + +use axum::{ + extract::{DefaultBodyLimit, MatchedPath}, + Router, +}; +use conduit::Server; +use http::{ + header::{self, HeaderName}, + HeaderValue, Method, StatusCode, +}; +use tower::ServiceBuilder; +use tower_http::{ + catch_panic::CatchPanicLayer, + cors::{self, CorsLayer}, + set_header::SetResponseHeaderLayer, + trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer}, + ServiceBuilderExt as _, +}; +use tracing::Level; + +use crate::{request, router}; + +pub(crate) fn build(server: &Arc) -> io::Result> { + let layers = ServiceBuilder::new(); + + #[cfg(feature = "sentry_telemetry")] + let layers = layers.layer(sentry_tower::NewSentryLayer::>::new_from_top()); + + #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] + let layers = layers.layer(compression_layer(server)); + + let layers = layers + .sensitive_headers([header::AUTHORIZATION]) + .sensitive_request_headers([HeaderName::from_static("x-forwarded-for")].into()) + .layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::spawn)) + .layer( + TraceLayer::new_for_http() + .make_span_with(tracing_span::<_>) + .on_failure(DefaultOnFailure::new().level(Level::ERROR)) + .on_request(DefaultOnRequest::new().level(Level::TRACE)) + .on_response(DefaultOnResponse::new().level(Level::DEBUG)), + ) + .layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::handle)) + .layer(SetResponseHeaderLayer::if_not_present( + HeaderName::from_static("origin-agent-cluster"), // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster + HeaderValue::from_static("?1"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + header::X_CONTENT_TYPE_OPTIONS, + HeaderValue::from_static("nosniff"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + header::X_XSS_PROTECTION, + HeaderValue::from_static("0"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + header::X_FRAME_OPTIONS, + HeaderValue::from_static("DENY"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + HeaderName::from_static("permissions-policy"), + HeaderValue::from_static("interest-cohort=(),browsing-topics=()"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + header::CONTENT_SECURITY_POLICY, + HeaderValue::from_static( + "sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \ + style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';", + ), + )) + .layer(cors_layer(server)) + .layer(body_limit_layer(server)) + .layer(CatchPanicLayer::custom(catch_panic)); + + let routes = router::build(server); + let layers = routes.layer(layers); + + Ok(layers.into_make_service()) +} + +#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] +fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { + let mut compression_layer = tower_http::compression::CompressionLayer::new(); + + #[cfg(feature = "zstd_compression")] + { + if server.config.zstd_compression { + compression_layer = compression_layer.zstd(true); + } else { + compression_layer = compression_layer.no_zstd(); + }; + }; + + #[cfg(feature = "gzip_compression")] + { + if server.config.gzip_compression { + compression_layer = compression_layer.gzip(true); + } else { + compression_layer = compression_layer.no_gzip(); + }; + }; + + #[cfg(feature = "brotli_compression")] + { + if server.config.brotli_compression { + compression_layer = compression_layer.br(true); + } else { + compression_layer = compression_layer.no_br(); + }; + }; + + compression_layer +} + +fn cors_layer(_server: &Server) -> CorsLayer { + const METHODS: [Method; 7] = [ + Method::GET, + Method::HEAD, + Method::PATCH, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]; + + let headers: [HeaderName; 5] = [ + header::ORIGIN, + HeaderName::from_lowercase(b"x-requested-with").unwrap(), + header::CONTENT_TYPE, + header::ACCEPT, + header::AUTHORIZATION, + ]; + + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(METHODS) + .allow_headers(headers) + .max_age(Duration::from_secs(86400)) +} + +fn body_limit_layer(server: &Server) -> DefaultBodyLimit { + DefaultBodyLimit::max( + server + .config + .max_request_size + .try_into() + .expect("failed to convert max request size"), + ) +} + +#[allow(clippy::needless_pass_by_value)] +#[tracing::instrument(skip_all)] +fn catch_panic(err: Box) -> http::Response> { + conduit_service::services() + .server + .requests_panic + .fetch_add(1, std::sync::atomic::Ordering::Release); + + let details = if let Some(s) = err.downcast_ref::() { + s.clone() + } else if let Some(s) = err.downcast_ref::<&str>() { + s.to_string() + } else { + "Unknown internal server error occurred.".to_owned() + }; + + let body = serde_json::json!({ + "errcode": "M_UNKNOWN", + "error": "M_UNKNOWN: Internal server error occurred", + "details": details, + }) + .to_string(); + + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(header::CONTENT_TYPE, "application/json") + .body(http_body_util::Full::from(body)) + .expect("Failed to create response for our panic catcher?") +} + +fn tracing_span(request: &http::Request) -> tracing::Span { + let path = if let Some(path) = request.extensions().get::() { + path.as_str() + } else { + request.uri().path() + }; + + tracing::info_span!("router:", %path) +} diff --git a/src/router/mod.rs b/src/router/mod.rs index aa928685..6467d5ee 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,258 +1,29 @@ -use std::{any::Any, io, sync::atomic, time::Duration}; +pub(crate) mod layers; +pub(crate) mod request; +pub(crate) mod router; +pub(crate) mod run; +pub(crate) mod serve; -use axum::{ - extract::{DefaultBodyLimit, MatchedPath}, - response::IntoResponse, - Router, -}; -use http::{ - header::{self, HeaderName, HeaderValue}, - Method, StatusCode, Uri, -}; -use ruma::api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, -}; -use tower::ServiceBuilder; -use tower_http::{ - catch_panic::CatchPanicLayer, - cors::{self, CorsLayer}, - set_header::SetResponseHeaderLayer, - trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer}, - ServiceBuilderExt as _, -}; -use tracing::{debug, error, trace, Level}; +extern crate conduit_core as conduit; -use super::{api::ruma_wrapper::RumaResponse, debug_error, services, utils::error::Result, Server}; +use std::{future::Future, pin::Pin, sync::Arc}; -mod routes; +use conduit::{Result, Server}; -pub(crate) async fn build(server: &Server) -> io::Result> { - let base_middlewares = ServiceBuilder::new(); - #[cfg(feature = "sentry_telemetry")] - let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); +conduit::mod_ctor! {} +conduit::mod_dtor! {} - let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); - let permissions_policy = HeaderName::from_static("permissions-policy"); - let origin_agent_cluster = HeaderName::from_static("origin-agent-cluster"); // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster - - let middlewares = base_middlewares - .sensitive_headers([header::AUTHORIZATION]) - .sensitive_request_headers([x_forwarded_for].into()) - .layer(axum::middleware::from_fn(request_spawn)) - .layer( - TraceLayer::new_for_http() - .make_span_with(tracing_span::<_>) - .on_failure(DefaultOnFailure::new().level(Level::ERROR)) - .on_request(DefaultOnRequest::new().level(Level::TRACE)) - .on_response(DefaultOnResponse::new().level(Level::DEBUG)), - ) - .layer(axum::middleware::from_fn(request_handle)) - .layer(SetResponseHeaderLayer::if_not_present( - origin_agent_cluster, - HeaderValue::from_static("?1"), - )) - .layer(SetResponseHeaderLayer::if_not_present( - header::X_CONTENT_TYPE_OPTIONS, - HeaderValue::from_static("nosniff"), - )) - .layer(SetResponseHeaderLayer::if_not_present( - header::X_XSS_PROTECTION, - HeaderValue::from_static("0"), - )) - .layer(SetResponseHeaderLayer::if_not_present( - header::X_FRAME_OPTIONS, - HeaderValue::from_static("DENY"), - )) - .layer(SetResponseHeaderLayer::if_not_present( - permissions_policy, - HeaderValue::from_static("interest-cohort=(),browsing-topics=()"), - )) - .layer(SetResponseHeaderLayer::if_not_present( - header::CONTENT_SECURITY_POLICY, - HeaderValue::from_static( - "sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \ - style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';", - ), - )) - .layer(cors_layer(server)) - .layer(DefaultBodyLimit::max( - server - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - )) - .layer(CatchPanicLayer::custom(catch_panic_layer)); - - #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] - { - Ok(routes::routes(&server.config) - .layer(compression_layer(server)) - .layer(middlewares) - .into_make_service()) - } - #[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))] - { - Ok(routes::routes().layer(middlewares).into_make_service()) - } +#[no_mangle] +pub extern "Rust" fn start(server: &Arc) -> Pin>>> { + Box::pin(run::start(server.clone())) } -#[tracing::instrument(skip_all, name = "spawn")] -async fn request_spawn( - req: http::Request, next: axum::middleware::Next, -) -> Result { - if services().globals.shutdown.load(atomic::Ordering::Relaxed) { - return Err(StatusCode::SERVICE_UNAVAILABLE); - } - - let fut = next.run(req); - let task = tokio::spawn(fut); - task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) +#[no_mangle] +pub extern "Rust" fn stop(server: &Arc) -> Pin>>> { + Box::pin(run::stop(server.clone())) } -#[tracing::instrument(skip_all, name = "handle")] -async fn request_handle( - req: http::Request, next: axum::middleware::Next, -) -> Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let result = next.run(req).await; - request_result(&method, &uri, result) -} - -fn request_result( - method: &Method, uri: &Uri, result: axum::response::Response, -) -> Result { - request_result_log(method, uri, &result); - match result.status() { - StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result), - _ => Ok(result), - } -} - -#[allow(clippy::unnecessary_wraps)] -fn request_result_403( - _method: &Method, _uri: &Uri, result: &axum::response::Response, -) -> Result { - let error = UiaaResponse::MatrixError(RumaError { - status_code: result.status(), - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), - }, - }); - - Ok(RumaResponse(error).into_response()) -} - -fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { - let status = result.status(); - let reason = status.canonical_reason().unwrap_or("Unknown Reason"); - let code = status.as_u16(); - if status.is_server_error() { - error!(method = ?method, uri = ?uri, "{code} {reason}"); - } else if status.is_client_error() { - debug_error!(method = ?method, uri = ?uri, "{code} {reason}"); - } else if status.is_redirection() { - debug!(method = ?method, uri = ?uri, "{code} {reason}"); - } else { - trace!(method = ?method, uri = ?uri, "{code} {reason}"); - } -} - -/// Cross-Origin-Resource-Sharing header as defined by spec: -/// -fn cors_layer(_server: &Server) -> CorsLayer { - const METHODS: [Method; 7] = [ - Method::GET, - Method::HEAD, - Method::PATCH, - Method::POST, - Method::PUT, - Method::DELETE, - Method::OPTIONS, - ]; - - let headers: [HeaderName; 5] = [ - header::ORIGIN, - HeaderName::from_lowercase(b"x-requested-with").unwrap(), - header::CONTENT_TYPE, - header::ACCEPT, - header::AUTHORIZATION, - ]; - - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods(METHODS) - .allow_headers(headers) - .max_age(Duration::from_secs(86400)) -} - -#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] -fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { - let mut compression_layer = tower_http::compression::CompressionLayer::new(); - - #[cfg(feature = "zstd_compression")] - { - if server.config.zstd_compression { - compression_layer = compression_layer.zstd(true); - } else { - compression_layer = compression_layer.no_zstd(); - }; - }; - - #[cfg(feature = "gzip_compression")] - { - if server.config.gzip_compression { - compression_layer = compression_layer.gzip(true); - } else { - compression_layer = compression_layer.no_gzip(); - }; - }; - - #[cfg(feature = "brotli_compression")] - { - if server.config.brotli_compression { - compression_layer = compression_layer.br(true); - } else { - compression_layer = compression_layer.no_br(); - }; - }; - - compression_layer -} - -fn tracing_span(request: &http::Request) -> tracing::Span { - let path = if let Some(path) = request.extensions().get::() { - path.as_str() - } else { - request.uri().path() - }; - - tracing::info_span!("router:", %path) -} - -#[allow(clippy::needless_pass_by_value)] -fn catch_panic_layer(err: Box) -> http::Response> { - let details = if let Some(s) = err.downcast_ref::() { - s.clone() - } else if let Some(s) = err.downcast_ref::<&str>() { - s.to_string() - } else { - "Unknown internal server error occurred.".to_owned() - }; - - let body = serde_json::json!({ - "errcode": "M_UNKNOWN", - "error": "M_UNKNOWN: Internal server error occurred", - "details": details, - }) - .to_string(); - - http::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(header::CONTENT_TYPE, "application/json") - .body(http_body_util::Full::from(body)) - .expect("Failed to create response for our panic catcher?") +#[no_mangle] +pub extern "Rust" fn run(server: &Arc) -> Pin>>> { + Box::pin(run::run(server.clone())) } diff --git a/src/router/request.rs b/src/router/request.rs new file mode 100644 index 00000000..25bdd98e --- /dev/null +++ b/src/router/request.rs @@ -0,0 +1,102 @@ +use std::sync::{atomic::Ordering, Arc}; + +use axum::{extract::State, response::IntoResponse}; +use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server}; +use http::{Method, StatusCode, Uri}; +use ruma::api::client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, +}; +use tracing::{debug, error, trace}; + +#[tracing::instrument(skip_all)] +pub(crate) async fn spawn( + State(server): State>, req: http::Request, next: axum::middleware::Next, +) -> Result { + if server.interrupt.load(Ordering::Relaxed) { + debug_warn!("unavailable pending shutdown"); + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + + let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed); + trace!(active, "enter"); + defer! {{ + let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); + trace!(active, finished, "leave"); + }}; + + let fut = next.run(req); + let task = server.runtime().spawn(fut); + task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) +} + +#[tracing::instrument(skip_all, name = "handle")] +pub(crate) async fn handle( + State(server): State>, req: http::Request, next: axum::middleware::Next, +) -> Result { + if server.interrupt.load(Ordering::Relaxed) { + debug_warn!( + method = %req.method(), + uri = %req.uri(), + "unavailable pending shutdown" + ); + + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + + let active = server + .requests_handle_active + .fetch_add(1, Ordering::Relaxed); + trace!(active, "enter"); + defer! {{ + let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed); + trace!(active, finished, "leave"); + }}; + + let method = req.method().clone(); + let uri = req.uri().clone(); + let result = next.run(req).await; + handle_result(&method, &uri, result) +} + +fn handle_result( + method: &Method, uri: &Uri, result: axum::response::Response, +) -> Result { + handle_result_log(method, uri, &result); + match result.status() { + StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result), + _ => Ok(result), + } +} + +#[allow(clippy::unnecessary_wraps)] +fn handle_result_403( + _method: &Method, _uri: &Uri, result: &axum::response::Response, +) -> Result { + let error = UiaaResponse::MatrixError(RumaError { + status_code: result.status(), + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), + }, + }); + + Ok(RumaResponse(error).into_response()) +} + +fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { + let status = result.status(); + let reason = status.canonical_reason().unwrap_or("Unknown Reason"); + let code = status.as_u16(); + if status.is_server_error() { + error!(method = ?method, uri = ?uri, "{code} {reason}"); + } else if status.is_client_error() { + debug_error!(method = ?method, uri = ?uri, "{code} {reason}"); + } else if status.is_redirection() { + debug!(method = ?method, uri = ?uri, "{code} {reason}"); + } else { + trace!(method = ?method, uri = ?uri, "{code} {reason}"); + } +} diff --git a/src/router/router.rs b/src/router/router.rs new file mode 100644 index 00000000..4f45df54 --- /dev/null +++ b/src/router/router.rs @@ -0,0 +1,20 @@ +use std::sync::Arc; + +use axum::{response::IntoResponse, routing::get, Router}; +use conduit::{Error, Server}; +use http::Uri; +use ruma::api::client::error::ErrorKind; + +extern crate conduit_api as api; + +pub(crate) fn build(server: &Arc) -> Router { + let router = Router::new().fallback(not_found).route("/", get(it_works)); + + api::router::build(router, server) +} + +async fn not_found(_uri: Uri) -> impl IntoResponse { + Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") +} + +async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } diff --git a/src/router/run.rs b/src/router/run.rs new file mode 100644 index 00000000..ee75ffec --- /dev/null +++ b/src/router/run.rs @@ -0,0 +1,185 @@ +use std::{sync::Arc, time::Duration}; + +use axum_server::Handle as ServerHandle; +use tokio::{ + signal, + sync::oneshot::{self, Sender}, +}; +use tracing::{debug, info, warn}; + +extern crate conduit_admin as admin; +extern crate conduit_core as conduit; +extern crate conduit_database as database; +extern crate conduit_service as service; + +use std::sync::atomic::Ordering; + +use conduit::{debug_info, trace, Error, Result, Server}; +use database::KeyValueDatabase; +use service::{services, Services}; + +use crate::{layers, serve}; + +/// Main loop base +#[tracing::instrument(skip_all)] +pub(crate) async fn run(server: Arc) -> Result<(), Error> { + let config = &server.config; + let app = layers::build(&server)?; + let addrs = config.get_bind_addrs(); + + // Install the admin room callback here for now + _ = services().admin.handle.lock().await.insert(admin::handle); + + // Setup shutdown/signal handling + let handle = ServerHandle::new(); + _ = server + .shutdown + .lock() + .expect("locked") + .insert(handle.clone()); + + server.interrupt.store(false, Ordering::Release); + let (tx, rx) = oneshot::channel::<()>(); + let sigs = server.runtime().spawn(sighandle(server.clone(), tx)); + + // Prepare to serve http clients + let res; + // Serve clients + if cfg!(unix) && config.unix_socket_path.is_some() { + res = serve::unix_socket(&server, app, rx).await; + } else if config.tls.is_some() { + res = serve::tls(&server, app, handle.clone(), addrs).await; + } else { + res = serve::plain(&server, app, handle.clone(), addrs).await; + } + + // Join the signal handler before we leave. + sigs.abort(); + _ = sigs.await; + + // Reset the axum handle instance; this should be reusable and might be + // reload-survivable but better to be safe than sorry. + _ = server.shutdown.lock().expect("locked").take(); + + // Remove the admin room callback + _ = services().admin.handle.lock().await.take(); + + debug_info!("Finished"); + Ok(res?) +} + +/// Async initializations +#[tracing::instrument(skip_all)] +pub(crate) async fn start(server: Arc) -> Result<(), Error> { + debug!("Starting..."); + let d = Arc::new(KeyValueDatabase::load_or_create(&server).await?); + let s = Box::new(Services::build(server, d.clone()).await?); + _ = service::SERVICES + .write() + .expect("write locked") + .insert(Box::leak(s)); + services().start().await?; + + #[cfg(feature = "systemd")] + #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + + debug!("Started"); + Ok(()) +} + +/// Async destructions +#[tracing::instrument(skip_all)] +pub(crate) async fn stop(_server: Arc) -> Result<(), Error> { + debug!("Shutting down..."); + + // Wait for all completions before dropping or we'll lose them to the module + // unload and explode. + services().shutdown().await; + + // Deactivate services(). Any further use will panic the caller. + let s = service::SERVICES + .write() + .expect("write locked") + .take() + .unwrap(); + + let s = std::ptr::from_ref(s) as *mut Services; + //SAFETY: Services was instantiated in start() and leaked into the SERVICES + // global perusing as 'static for the duration of run_server(). Now we reclaim + // it to drop it before unloading the module. If this is not done there will be + // multiple instances after module reload. + let s = unsafe { Box::from_raw(s) }; + debug!("Cleaning up..."); + // Drop it so we encounter any trouble before the infolog message + drop(s); + + #[cfg(feature = "systemd")] + #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); + + info!("Shutdown complete."); + Ok(()) +} + +#[tracing::instrument(skip_all)] +async fn sighandle(server: Arc, tx: Sender<()>) -> Result<(), Error> { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + + let reload = cfg!(unix) && cfg!(debug_assertions); + server.reload.store(reload, Ordering::Release); + }; + + #[cfg(unix)] + let ctrl_bs = async { + signal::unix::signal(signal::unix::SignalKind::quit()) + .expect("failed to install Ctrl+\\ handler") + .recv() + .await; + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler") + .recv() + .await; + }; + + debug!("Installed signal handlers"); + let sig: &str; + #[cfg(unix)] + tokio::select! { + () = ctrl_c => { sig = "Ctrl+C"; }, + () = ctrl_bs => { sig = "Ctrl+\\"; }, + () = terminate => { sig = "SIGTERM"; }, + } + + #[cfg(not(unix))] + tokio::select! { + _ = ctrl_c => { sig = "Ctrl+C"; }, + } + + warn!("Received {}", sig); + server.interrupt.store(true, Ordering::Release); + services().globals.rotate.fire(); + tx.send(()) + .expect("failed sending shutdown transaction to oneshot channel"); + + if let Some(handle) = server.shutdown.lock().expect("locked").as_ref() { + let pending = server.requests_spawn_active.load(Ordering::Relaxed); + if pending > 0 { + let timeout = Duration::from_secs(36); + trace!(pending, ?timeout, "Notifying for graceful shutdown"); + handle.graceful_shutdown(Some(timeout)); + } else { + debug!(pending, "Notifying for immediate shutdown"); + handle.shutdown(); + } + } + + Ok(()) +} diff --git a/src/router/serve.rs b/src/router/serve.rs new file mode 100644 index 00000000..37ed9902 --- /dev/null +++ b/src/router/serve.rs @@ -0,0 +1,137 @@ +#[cfg(unix)] +use std::fs::Permissions; // only for UNIX sockets stuff and *nix container checks +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt as _; +use std::{ + io, + net::SocketAddr, + sync::{atomic::Ordering, Arc}, +}; + +use axum::Router; +use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; +#[cfg(feature = "axum_dual_protocol")] +use axum_server_dual_protocol::ServerExt; +use conduit::{debug_info, Server}; +use tokio::{ + sync::oneshot::{self}, + task::JoinSet, +}; +use tracing::{debug, info, warn}; + +pub(crate) async fn plain( + server: &Arc, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, +) -> io::Result<()> { + let mut join_set = JoinSet::new(); + for addr in &addrs { + join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime()); + } + + info!("Listening on {addrs:?}"); + while join_set.join_next().await.is_some() {} + + let spawn_active = server.requests_spawn_active.load(Ordering::Relaxed); + let handle_active = server.requests_handle_active.load(Ordering::Relaxed); + debug_info!( + spawn_finished = server.requests_spawn_finished.load(Ordering::Relaxed), + handle_finished = server.requests_handle_finished.load(Ordering::Relaxed), + panics = server.requests_panic.load(Ordering::Relaxed), + spawn_active, + handle_active, + "Stopped listening on {addrs:?}", + ); + + debug_assert!(spawn_active == 0, "active request tasks are not joined"); + debug_assert!(handle_active == 0, "active request handles still pending"); + + Ok(()) +} + +pub(crate) async fn tls( + server: &Arc, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, +) -> io::Result<()> { + let config = &server.config; + let tls = config.tls.as_ref().expect("TLS configuration"); + + debug!( + "Using direct TLS. Certificate path {} and certificate private key path {}", + &tls.certs, &tls.key + ); + info!( + "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS." + ); + let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; + + if cfg!(feature = "axum_dual_protocol") { + info!( + "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \ + take effect if `dual_protocol` is enabled in `[global.tls]`" + ); + } + + let mut join_set = JoinSet::new(); + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + #[cfg(feature = "axum_dual_protocol")] + for addr in &addrs { + join_set.spawn_on( + axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) + .set_upgrade(false) + .handle(handle.clone()) + .serve(app.clone()), + server.runtime(), + ); + } + } else { + for addr in &addrs { + join_set.spawn_on( + bind_rustls(*addr, conf.clone()) + .handle(handle.clone()) + .serve(app.clone()), + server.runtime(), + ); + } + } + + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + warn!( + "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)", + addrs, &tls.certs + ); + } else { + info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); + } + + while join_set.join_next().await.is_some() {} + + Ok(()) +} + +#[cfg(unix)] +#[allow(unused_variables)] +pub(crate) async fn unix_socket( + server: &Arc, app: axum::routing::IntoMakeService, rx: oneshot::Receiver<()>, +) -> io::Result<()> { + let config = &server.config; + let path = config.unix_socket_path.as_ref().unwrap(); + + if path.exists() { + warn!( + "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", + path.display() + ); + tokio::fs::remove_file(&path).await?; + } + + tokio::fs::create_dir_all(path.parent().unwrap()).await?; + + let socket_perms = config.unix_socket_perms.to_string(); + let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); + tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms)) + .await + .unwrap(); + + let bind = tokio::net::UnixListener::bind(path)?; + info!("Listening at {:?}", path); + + Ok(()) +} diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml new file mode 100644 index 00000000..580706ff --- /dev/null +++ b/src/service/Cargo.toml @@ -0,0 +1,115 @@ +[package] +name = "conduit_service" +version.workspace = true +edition.workspace = true + +[lib] +path = "mod.rs" +crate-type = [ + "rlib", +# "dylib", +] + +[features] +default = [ + "rocksdb", + "io_uring", + "jemalloc", + "gzip_compression", + "zstd_compression", + "brotli_compression", + "release_max_log_level", +] + +dev_release_log_level = [] +release_max_log_level = [ + "tracing/max_level_trace", + "tracing/release_max_level_info", + "log/max_level_trace", + "log/release_max_level_info", +] +sqlite = [ + "dep:rusqlite", + "dep:parking_lot", + "dep:thread_local", +] +rocksdb = [ + "dep:rust-rocksdb", +] +jemalloc = [ + "dep:tikv-jemalloc-sys", + "dep:tikv-jemalloc-ctl", + "dep:tikv-jemallocator", + "rust-rocksdb/jemalloc", +] +io_uring = [ + "rust-rocksdb/io-uring", +] +zstd_compression = [ + "rust-rocksdb/zstd", +] +gzip_compression = [ + "reqwest/gzip", +] +brotli_compression = [ + "reqwest/brotli", +] +sha256_media = [ + "dep:sha2", +] + +[dependencies] +argon2.workspace = true +async-trait.workspace = true +base64.workspace = true +bytes.workspace = true +clap.workspace = true +conduit-core.workspace = true +conduit-database.workspace = true +cyborgtime.workspace = true +futures-util.workspace = true +hickory-resolver.workspace = true +hmac.workspace = true +http.workspace = true +image.workspace = true +ipaddress.workspace = true +itertools.workspace = true +jsonwebtoken.workspace = true +log.workspace = true +loole.workspace = true +lru-cache.workspace = true +parking_lot.optional = true +parking_lot.workspace = true +rand.workspace = true +regex.workspace = true +reqwest.workspace = true +ruma-identifiers-validation.workspace = true +ruma.workspace = true +rusqlite.optional = true +rusqlite.workspace = true +rust-rocksdb.optional = true +rust-rocksdb.workspace = true +serde_json.workspace = true +serde.workspace = true +serde_yaml.workspace = true +sha-1.workspace = true +sha2.optional = true +sha2.workspace = true +thread_local.optional = true +thread_local.workspace = true +tikv-jemallocator.optional = true +tikv-jemallocator.workspace = true +tikv-jemalloc-ctl.optional = true +tikv-jemalloc-ctl.workspace = true +tikv-jemalloc-sys.optional = true +tikv-jemalloc-sys.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +url.workspace = true +webpage.workspace = true +zstd.optional = true +zstd.workspace = true + +[lints] +workspace = true diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index b538ef76..492c500c 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -8,7 +8,7 @@ use ruma::{ use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Places one event in the account data of the user and removes the /// previous entry. fn update( diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index e8de80e6..7e17e145 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,8 +1,8 @@ mod data; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, @@ -11,15 +11,15 @@ use ruma::{ use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Places one event in the account data of the user and removes the /// previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub(crate) fn update( + pub fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { @@ -28,7 +28,7 @@ impl Service { /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, event_type))] - pub(crate) fn get( + pub fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, ) -> Result>> { self.db.get(room_id, user_id, event_type) @@ -36,7 +36,7 @@ impl Service { /// Returns all changes to the account data that happened after `since`. #[tracing::instrument(skip(self, room_id, user_id, since))] - pub(crate) fn changes_since( + pub fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result>> { self.db.changes_since(room_id, user_id, since) diff --git a/src/service/admin/mod.rs b/src/service/admin.rs similarity index 51% rename from src/service/admin/mod.rs rename to src/service/admin.rs index 8907869e..b44c9d0c 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin.rs @@ -1,11 +1,9 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc}; -use clap::Parser; -use regex::Regex; +use conduit::{Error, Result}; use ruma::{ api::client::error::ErrorKind, events::{ - relation::InReplyTo, room::{ canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, @@ -13,133 +11,62 @@ use ruma::{ history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, - message::{Relation::Reply, RoomMessageEventContent}, + message::RoomMessageEventContent, name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent, topic::RoomTopicEventContent, }, TimelineEventType, }, - EventId, MxcUri, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + EventId, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomVersionId, UserId, }; use serde_json::value::to_raw_value; -use tokio::sync::{Mutex, MutexGuard}; +use tokio::{sync::Mutex, task::JoinHandle}; use tracing::{error, warn}; -use self::{fsck::FsckCommand, tester::TesterCommands}; -use super::pdu::PduBuilder; -use crate::{ - service::admin::{ - appservice::AppserviceCommand, debug::DebugCommand, federation::FederationCommand, media::MediaCommand, - query::QueryCommand, room::RoomCommand, server::ServerCommand, user::UserCommand, - }, - services, Error, Result, -}; +use crate::{pdu::PduBuilder, services}; -pub(crate) mod appservice; -pub(crate) mod debug; -pub(crate) mod federation; -pub(crate) mod fsck; -pub(crate) mod media; -pub(crate) mod query; -pub(crate) mod room; -pub(crate) mod server; -pub(crate) mod tester; -pub(crate) mod user; +pub type HandlerResult = Pin> + Send>>; +pub type Handler = fn(AdminRoomEvent, OwnedRoomId, OwnedUserId) -> HandlerResult; -const PAGE_SIZE: usize = 100; - -#[cfg_attr(test, derive(Debug))] -#[derive(Parser)] -#[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] -enum AdminCommand { - #[command(subcommand)] - /// - Commands for managing appservices - Appservices(AppserviceCommand), - - #[command(subcommand)] - /// - Commands for managing local users - Users(UserCommand), - - #[command(subcommand)] - /// - Commands for managing rooms - Rooms(RoomCommand), - - #[command(subcommand)] - /// - Commands for managing federation - Federation(FederationCommand), - - #[command(subcommand)] - /// - Commands for managing the server - Server(ServerCommand), - - #[command(subcommand)] - /// - Commands for managing media - Media(MediaCommand), - - #[command(subcommand)] - /// - Commands for debugging things - Debug(DebugCommand), - - #[command(subcommand)] - /// - Query all the database getters and iterators - Query(QueryCommand), - - #[command(subcommand)] - /// - Query all the database getters and iterators - Fsck(FsckCommand), - - #[command(subcommand)] - Tester(TesterCommands), +pub struct Service { + sender: loole::Sender, + receiver: Mutex>, + handler_join: Mutex>>, + pub handle: Mutex>, } #[derive(Debug)] -pub(crate) enum AdminRoomEvent { +pub enum AdminRoomEvent { ProcessMessage(String, Arc), SendMessage(RoomMessageEventContent), } -pub(crate) struct Service { - pub(crate) sender: loole::Sender, - receiver: Mutex>, -} - impl Service { - pub(crate) fn build() -> Arc { + #[must_use] + pub fn build() -> Arc { let (sender, receiver) = loole::unbounded(); Arc::new(Self { sender, receiver: Mutex::new(receiver), + handler_join: Mutex::new(None), + handle: Mutex::new(None), }) } - pub(crate) fn start_handler(self: &Arc) { - let self2 = Arc::clone(self); - tokio::spawn(async move { - self2 + pub async fn start_handler(self: &Arc) { + let self_ = Arc::clone(self); + let handle = services().server.runtime().spawn(async move { + self_ .handler() .await .expect("Failed to initialize admin room handler"); }); + + _ = self.handler_join.lock().await.insert(handle); } - pub(crate) async fn process_message(&self, room_message: String, event_id: Arc) { - self.send(AdminRoomEvent::ProcessMessage(room_message, event_id)) - .await; - } - - pub(crate) async fn send_message(&self, message_content: RoomMessageEventContent) { - self.send(AdminRoomEvent::SendMessage(message_content)) - .await; - } - - async fn send(&self, message: AdminRoomEvent) { - debug_assert!(!self.sender.is_full(), "channel full"); - debug_assert!(!self.sender.is_closed(), "channel closed"); - self.sender.send(message).expect("message sent"); - } - - async fn handler(&self) -> Result<()> { + async fn handler(self: &Arc) -> Result<()> { let receiver = self.receiver.lock().await; let Ok(Some(admin_room)) = Self::get_admin_room().await else { return Ok(()); @@ -151,246 +78,72 @@ impl Service { debug_assert!(!receiver.is_closed(), "channel closed"); tokio::select! { event = receiver.recv_async() => match event { - Ok(event) => self.handle_event(event, &admin_room, &server_user).await?, - Err(e) => error!("Failed to receive admin room event from channel: {e}"), + Ok(event) => self.receive(event, &admin_room, &server_user).await?, + Err(_e) => return Ok(()), } } } } - async fn handle_event(&self, event: AdminRoomEvent, admin_room: &OwnedRoomId, server_user: &UserId) -> Result<()> { - let (mut message_content, reply) = match event { - AdminRoomEvent::SendMessage(content) => (content, None), - AdminRoomEvent::ProcessMessage(room_message, reply_id) => { - (self.process_admin_message(room_message).await, Some(reply_id)) - }, - }; - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(admin_room.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - if let Some(reply) = reply { - message_content.relates_to = Some(Reply { - in_reply_to: InReplyTo { - event_id: reply.into(), - }, - }); + pub async fn close(&self) { + self.interrupt(); + if let Some(handler_join) = self.handler_join.lock().await.take() { + if let Err(e) = handler_join.await { + error!("Failed to shutdown: {e:?}"); + } } - - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&message_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }; - - if let Err(e) = services() - .rooms - .timeline - .build_and_append_pdu(response_pdu, server_user, admin_room, &state_lock) - .await - { - self.handle_response_error(&e, admin_room, server_user, &state_lock) - .await?; - } - - Ok(()) } - async fn handle_response_error( - &self, e: &Error, admin_room: &OwnedRoomId, server_user: &UserId, state_lock: &MutexGuard<'_, ()>, - ) -> Result<()> { - error!("Failed to build and append admin room response PDU: \"{e}\""); - let error_room_message = RoomMessageEventContent::text_plain(format!( - "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ - successfully, but we could not return the output." - )); + pub fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } + } - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }; + pub async fn send_message(&self, message_content: RoomMessageEventContent) { + self.send(AdminRoomEvent::SendMessage(message_content)) + .await; + } + + pub async fn process_message(&self, room_message: String, event_id: Arc) { + self.send(AdminRoomEvent::ProcessMessage(room_message, event_id)) + .await; + } + + async fn receive(&self, event: AdminRoomEvent, room: &OwnedRoomId, user: &UserId) -> Result<(), Error> { + if let Some(handle) = self.handle.lock().await.as_ref() { + handle(event, room.clone(), user.into()).await + } else { + Err(Error::Err("Admin module is not loaded.".into())) + } + } + + async fn send(&self, message: AdminRoomEvent) { + debug_assert!(!self.sender.is_full(), "channel full"); + debug_assert!(!self.sender.is_closed(), "channel closed"); + self.sender.send(message).expect("message sent"); + } + + /// Gets the room ID of the admin room + /// + /// Errors are propagated from the database, and will have None if there is + /// no admin room + pub async fn get_admin_room() -> Result> { + let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); services() .rooms - .timeline - .build_and_append_pdu(response_pdu, server_user, admin_room, state_lock) - .await?; - - Ok(()) - } - - // Parse and process a message from the admin room - async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { - let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); - let command_line = lines.next().expect("each string has at least one line"); - let body = lines.collect::>(); - - let admin_command = match self.parse_admin_command(command_line) { - Ok(command) => command, - Err(error) => { - let server_name = services().globals.server_name(); - let message = error.replace("server.name", server_name.as_str()); - let html_message = self.usage_to_html(&message, server_name); - - return RoomMessageEventContent::text_html(message, html_message); - }, - }; - - match self.process_admin_command(admin_command, body).await { - Ok(reply_message) => reply_message, - Err(error) => { - let markdown_message = format!("Encountered an error while handling the command:\n```\n{error}\n```",); - let html_message = format!("Encountered an error while handling the command:\n
\n{error}\n
",); - - RoomMessageEventContent::text_html(markdown_message, html_message) - }, - } - } - - // Parse chat messages from the admin room into an AdminCommand object - fn parse_admin_command(&self, command_line: &str) -> Result { - // Note: argv[0] is `@conduit:servername:`, which is treated as the main command - let mut argv = command_line.split_whitespace().collect::>(); - - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. - if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help"); - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes_argv1; - if argv.len() > 1 && argv[1].contains('_') { - command_with_dashes_argv1 = argv[1].replace('_', "-"); - argv[1] = &command_with_dashes_argv1; - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes_argv2; - if argv.len() > 2 && argv[2].contains('_') { - command_with_dashes_argv2 = argv[2].replace('_', "-"); - argv[2] = &command_with_dashes_argv2; - } - - // if the user is using the `query` command (argv[1]), replace the database - // function/table calls with underscores to match the codebase - let command_with_dashes_argv3; - if argv.len() > 3 && argv[1].eq("query") { - command_with_dashes_argv3 = argv[3].replace('_', "-"); - argv[3] = &command_with_dashes_argv3; - } - - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) - } - - async fn process_admin_command(&self, command: AdminCommand, body: Vec<&str>) -> Result { - let reply_message_content = match command { - AdminCommand::Appservices(command) => appservice::process(command, body).await?, - AdminCommand::Media(command) => media::process(command, body).await?, - AdminCommand::Users(command) => user::process(command, body).await?, - AdminCommand::Rooms(command) => room::process(command, body).await?, - AdminCommand::Federation(command) => federation::process(command, body).await?, - AdminCommand::Server(command) => server::process(command, body).await?, - AdminCommand::Debug(command) => debug::process(command, body).await?, - AdminCommand::Query(command) => query::process(command, body).await?, - AdminCommand::Fsck(command) => fsck::process(command, body).await?, - AdminCommand::Tester(command) => tester::process(command, body).await?, - }; - - Ok(reply_message_content) - } - - // Utility to turn clap's `--help` text to HTML. - fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String { - // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: - // subcmdname` - let text = text.replace(&format!("@conduit:{server_name}:-"), &format!("@conduit:{server_name}: ")); - - // For the conduit admin room, subcommands become main commands - let text = text.replace("SUBCOMMAND", "COMMAND"); - let text = text.replace("subcommand", "command"); - - // Escape option names (e.g. ``) since they look like HTML tags - let text = escape_html(&text); - - // Italicize the first line (command name and version text) - let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1\n"); - - // Unmerge wrapped lines - let text = text.replace("\n ", " "); - - // Wrap option names in backticks. The lines look like: - // -V, --version Prints version information - // And are converted to: - // -V, --version: Prints version information - // (?m) enables multi-line mode for ^ and $ - let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$").expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1: $4"); - - // Look for a `[commandbody]` tag. If it exists, use all lines below it that - // start with a `#` in the USAGE section. - let mut text_lines = text.lines().collect::>(); - let mut command_body = String::new(); - - if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { - text_lines.remove(line_index); - - while text_lines - .get(line_index) - .is_some_and(|line| line.starts_with('#')) - { - command_body += if text_lines[line_index].starts_with("# ") { - &text_lines[line_index][2..] - } else { - &text_lines[line_index][1..] - }; - command_body += "[nobr]\n"; - text_lines.remove(line_index); - } - } - - let text = text_lines.join("\n"); - - // Improve the usage section - let text = if command_body.is_empty() { - // Wrap the usage line in code tags - let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$").expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n$1").to_string() - } else { - // Wrap the usage line in a code block, and add a yaml block example - // This makes the usage of e.g. `register-appservice` more accurate - let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n").expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n
$1[nobr]\n[commandbodyblock]
") - .replace("[commandbodyblock]", &command_body) - }; - - // Add HTML line-breaks - - text.replace("\n\n\n", "\n\n") - .replace('\n', "
\n") - .replace("[nobr]
", "") + .alias + .resolve_local_alias(&admin_room_alias) } /// Create the admin room. /// /// Users in this room are considered admins by conduit, and the room can be /// used to issue admin commands by talking to the server user inside it. - pub(crate) async fn create_admin_room(&self) -> Result<()> { + pub async fn create_admin_room(&self) -> Result<()> { let room_id = RoomId::new(services().globals.server_name()); services().rooms.short.get_or_create_shortroomid(&room_id)?; @@ -637,25 +390,10 @@ impl Service { Ok(()) } - /// Gets the room ID of the admin room - /// - /// Errors are propagated from the database, and will have None if there is - /// no admin room - pub(crate) async fn get_admin_room() -> Result> { - let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - - services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias) - } - /// Invite the user to the conduit admin room. /// /// In conduit, this is equivalent to granting admin privileges. - pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { + pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { if let Some(room_id) = Self::get_admin_room().await? { let mutex_state = Arc::clone( services() @@ -754,7 +492,7 @@ impl Service { // Send welcome message services().rooms.timeline.build_and_append_pdu( - PduBuilder { + PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", services().globals.server_name()), @@ -776,54 +514,3 @@ impl Service { } } } - -fn escape_html(s: &str) -> String { - s.replace('&', "&") - .replace('<', "<") - .replace('>', ">") -} - -fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) { - ( - id.clone(), - services() - .rooms - .state_cache - .room_joined_count(id) - .ok() - .flatten() - .unwrap_or(0), - services() - .rooms - .state_accessor - .get_name(id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), - ) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn get_help_short() { get_help_inner("-h"); } - - #[test] - fn get_help_long() { get_help_inner("--help"); } - - #[test] - fn get_help_subcommand() { get_help_inner("help"); } - - fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) - .unwrap_err() - .to_string(); - - // Search for a handful of keywords that suggest the help printed properly - assert!(error.contains("Usage:")); - assert!(error.contains("Commands:")); - assert!(error.contains("Options:")); - } -} diff --git a/src/service/admin/test_cmd/mod.rs b/src/service/admin/test_cmd/mod.rs deleted file mode 100644 index fc49885e..00000000 --- a/src/service/admin/test_cmd/mod.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! test commands generally used for hot lib reloadable functions. -//! see for more details if you are a dev - -//#[cfg(not(feature = "hot_reload"))] -//#[allow(unused_imports)] -//#[allow(clippy::wildcard_imports)] -// non hot reloadable functions (?) -//use hot_lib::*; -#[cfg(feature = "hot_reload")] -#[allow(unused_imports)] -#[allow(clippy::wildcard_imports)] -use hot_lib_funcs::*; -use ruma::events::room::message::RoomMessageEventContent; - -use crate::{debug_error, Result}; - -#[cfg(feature = "hot_reload")] -#[hot_lib_reloader::hot_module(dylib = "lib")] -mod hot_lib_funcs { - // these will be functions from lib.rs, so `use hot_lib_funcs::test_command;` - hot_functions_from_file!("hot_lib/src/lib.rs"); -} - -#[cfg_attr(test, derive(Debug))] -#[derive(clap::Subcommand)] -pub(crate) enum TestCommands { - // !admin test test1 - Test1, -} - -pub(crate) async fn process(command: TestCommands, _body: Vec<&str>) -> Result { - Ok(match command { - TestCommands::Test1 => { - debug_error!("before calling test_command"); - test_command(); - debug_error!("after calling test_command"); - - RoomMessageEventContent::notice_plain(String::from("loaded")) - }, - }) -} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index cb19ebb0..52c8b34d 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -2,7 +2,7 @@ use ruma::api::appservice::Registration; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Registers an appservice and returns the ID to the caller fn register_appservice(&self, yaml: Registration) -> Result; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 1fd46d68..1d3bc98a 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,8 +1,8 @@ mod data; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use futures_util::Future; use regex::RegexSet; use ruma::{ @@ -15,14 +15,15 @@ use crate::{services, Result}; /// Compiled regular expressions for a namespace #[derive(Clone, Debug)] -pub(crate) struct NamespaceRegex { - pub(crate) exclusive: Option, - pub(crate) non_exclusive: Option, +pub struct NamespaceRegex { + pub exclusive: Option, + pub non_exclusive: Option, } impl NamespaceRegex { /// Checks if this namespace has rights to a namespace - pub(crate) fn is_match(&self, heystack: &str) -> bool { + #[must_use] + pub fn is_match(&self, heystack: &str) -> bool { if self.is_exclusive_match(heystack) { return true; } @@ -36,7 +37,8 @@ impl NamespaceRegex { } /// Checks if this namespace has exlusive rights to a namespace - pub(crate) fn is_exclusive_match(&self, heystack: &str) -> bool { + #[must_use] + pub fn is_exclusive_match(&self, heystack: &str) -> bool { if let Some(exclusive) = &self.exclusive { if exclusive.is_match(heystack) { return true; @@ -47,11 +49,13 @@ impl NamespaceRegex { } impl RegistrationInfo { - pub(crate) fn is_user_match(&self, user_id: &UserId) -> bool { + #[must_use] + pub fn is_user_match(&self, user_id: &UserId) -> bool { self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() } - pub(crate) fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { + #[must_use] + pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() } } @@ -88,11 +92,11 @@ impl TryFrom> for NamespaceRegex { /// Appservice registration combined with its compiled regular expressions. #[derive(Clone, Debug)] -pub(crate) struct RegistrationInfo { - pub(crate) registration: Registration, - pub(crate) users: NamespaceRegex, - pub(crate) aliases: NamespaceRegex, - pub(crate) rooms: NamespaceRegex, +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, } impl TryFrom for RegistrationInfo { @@ -108,13 +112,13 @@ impl TryFrom for RegistrationInfo { } } -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, registration_info: RwLock>, } impl Service { - pub(crate) fn build(db: &'static dyn Data) -> Result { + pub fn build(db: Arc) -> Result { let mut registration_info = BTreeMap::new(); // Inserting registrations into cache for appservice in db.all()? { @@ -134,7 +138,7 @@ impl Service { } /// Registers an appservice and returns the ID to the caller - pub(crate) async fn register_appservice(&self, yaml: Registration) -> Result { + pub async fn register_appservice(&self, yaml: Registration) -> Result { //TODO: Check for collisions between exclusive appservice namespaces services() .appservice @@ -151,7 +155,7 @@ impl Service { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub(crate) async fn unregister_appservice(&self, service_name: &str) -> Result<()> { + pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { // removes the appservice registration info services() .appservice @@ -171,7 +175,7 @@ impl Service { Ok(()) } - pub(crate) async fn get_registration(&self, id: &str) -> Option { + pub async fn get_registration(&self, id: &str) -> Option { self.registration_info .read() .await @@ -180,7 +184,7 @@ impl Service { .map(|info| info.registration) } - pub(crate) async fn iter_ids(&self) -> Vec { + pub async fn iter_ids(&self) -> Vec { self.registration_info .read() .await @@ -189,7 +193,7 @@ impl Service { .collect() } - pub(crate) async fn find_from_token(&self, token: &str) -> Option { + pub async fn find_from_token(&self, token: &str) -> Option { self.read() .await .values() @@ -198,7 +202,7 @@ impl Service { } /// Checks if a given user id matches any exclusive appservice regex - pub(crate) async fn is_exclusive_user_id(&self, user_id: &UserId) -> bool { + pub async fn is_exclusive_user_id(&self, user_id: &UserId) -> bool { self.read() .await .values() @@ -206,7 +210,7 @@ impl Service { } /// Checks if a given room alias matches any exclusive appservice regex - pub(crate) async fn is_exclusive_alias(&self, alias: &RoomAliasId) -> bool { + pub async fn is_exclusive_alias(&self, alias: &RoomAliasId) -> bool { self.read() .await .values() @@ -217,16 +221,14 @@ impl Service { /// /// TODO: use this? #[allow(dead_code)] - pub(crate) async fn is_exclusive_room_id(&self, room_id: &RoomId) -> bool { + pub async fn is_exclusive_room_id(&self, room_id: &RoomId) -> bool { self.read() .await .values() .any(|info| info.rooms.is_exclusive_match(room_id.as_str())) } - pub(crate) fn read( - &self, - ) -> impl Future>> { + pub fn read(&self) -> impl Future>> { self.registration_info.read() } } diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs index c652eef9..82747ae7 100644 --- a/src/service/globals/client.rs +++ b/src/service/globals/client.rs @@ -4,18 +4,18 @@ use reqwest::redirect; use crate::{service::globals::resolver, utils::conduwuit_version, Config, Result}; -pub(crate) struct Client { - pub(crate) default: reqwest::Client, - pub(crate) url_preview: reqwest::Client, - pub(crate) well_known: reqwest::Client, - pub(crate) federation: reqwest::Client, - pub(crate) sender: reqwest::Client, - pub(crate) appservice: reqwest::Client, - pub(crate) pusher: reqwest::Client, +pub struct Client { + pub default: reqwest::Client, + pub url_preview: reqwest::Client, + pub well_known: reqwest::Client, + pub federation: reqwest::Client, + pub sender: reqwest::Client, + pub appservice: reqwest::Client, + pub pusher: reqwest::Client, } impl Client { - pub(crate) fn new(config: &Config, resolver: &Arc) -> Client { + pub fn new(config: &Config, resolver: &Arc) -> Client { Client { default: Self::base(config) .unwrap() diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 1d75867c..59ed4534 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -10,7 +10,7 @@ use ruma::{ use crate::{database::Cork, Result}; #[async_trait] -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn next_count(&self) -> Result; fn current_count(&self) -> Result; fn last_check_for_updates_id(&self) -> Result; diff --git a/src/service/globals/emerg_access.rs b/src/service/globals/emerg_access.rs new file mode 100644 index 00000000..bda664d6 --- /dev/null +++ b/src/service/globals/emerg_access.rs @@ -0,0 +1,66 @@ +use conduit::Result; +use ruma::{ + events::{ + push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, GlobalAccountDataEvent, + GlobalAccountDataEventType, + }, + push::Ruleset, + UserId, +}; +use tracing::{error, warn}; + +use crate::services; + +pub(crate) async fn init_emergency_access() { + // Set emergency access for the conduit user + match set_emergency_access() { + Ok(pwd_set) => { + if pwd_set { + warn!( + "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ + account recovery!" + ); + services() + .admin + .send_message(RoomMessageEventContent::text_plain( + "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ + account recovery!", + )) + .await; + } + }, + Err(e) => { + error!("Could not set the configured emergency password for the conduit user: {}", e); + }, + }; +} + +/// Sets the emergency password and push rules for the @conduit account in case +/// emergency password is set +fn set_emergency_access() -> Result { + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is a valid UserId"); + + services() + .users + .set_password(&conduit_user, services().globals.emergency_password().as_deref())?; + + let (ruleset, res) = match services().globals.emergency_password() { + Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), + None => (Ruleset::new(), Ok(false)), + }; + + services().account_data.update( + None, + &conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + )?; + + res +} diff --git a/src/database/migrations.rs b/src/service/globals/migrations.rs similarity index 99% rename from src/database/migrations.rs rename to src/service/globals/migrations.rs index a6726528..0e04caeb 100644 --- a/src/database/migrations.rs +++ b/src/service/globals/migrations.rs @@ -7,6 +7,7 @@ use std::{ }; use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier}; +use database::KeyValueDatabase; use itertools::Itertools; use rand::thread_rng; use ruma::{ @@ -16,7 +17,6 @@ use ruma::{ }; use tracing::{debug, error, info, warn}; -use super::KeyValueDatabase; use crate::{services, utils, Config, Error, Result}; pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result<()> { @@ -567,7 +567,7 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result ); { - let patterns = &config.forbidden_usernames; + let patterns = services().globals.forbidden_usernames(); if !patterns.is_empty() { for user_id in services() .users @@ -592,7 +592,7 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result } { - let patterns = &config.forbidden_alias_names; + let patterns = services().globals.forbidden_alias_names(); if !patterns.is_empty() { for address in services().rooms.metadata.iter_ids() { let room_id = address?; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 4484b816..85fdad0e 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -3,16 +3,13 @@ use std::{ fs, future::Future, path::PathBuf, - sync::{ - atomic::{self, AtomicBool}, - Arc, - }, - time::{Instant, SystemTime}, + sync::Arc, + time::Instant, }; use argon2::Argon2; use base64::{engine::general_purpose, Engine as _}; -pub(crate) use data::Data; +pub use data::Data; use hickory_resolver::TokioAsyncResolver; use ipaddress::IPAddress; use regex::RegexSet; @@ -25,42 +22,46 @@ use ruma::{ DeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomVersionId, ServerName, UserId, }; -use tokio::sync::{broadcast, Mutex, RwLock}; -use tracing::{error, info, trace}; +use tokio::{ + sync::{broadcast, Mutex, RwLock}, + task::JoinHandle, +}; +use tracing::{error, trace}; use url::Url; -use crate::{services, Config, LogLevelReloadHandles, Result}; +use crate::{services, Config, Result}; mod client; -mod data; +pub mod data; +pub(crate) mod emerg_access; +pub(crate) mod migrations; mod resolver; +pub(crate) mod updates; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries -pub(crate) struct Service<'a> { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, - pub(crate) tracing_reload_handle: LogLevelReloadHandles, - pub(crate) config: Config, - pub(crate) cidr_range_denylist: Vec, + pub config: Config, + pub cidr_range_denylist: Vec, keypair: Arc, jwt_decoding_key: Option, - pub(crate) resolver: Arc, - pub(crate) client: client::Client, - pub(crate) stable_room_versions: Vec, - pub(crate) unstable_room_versions: Vec, - pub(crate) bad_event_ratelimiter: Arc>>, - pub(crate) bad_signature_ratelimiter: Arc, RateLimitState>>>, - pub(crate) bad_query_ratelimiter: Arc>>, - pub(crate) roomid_mutex_insert: RwLock>>>, - pub(crate) roomid_mutex_state: RwLock>>>, - pub(crate) roomid_mutex_federation: RwLock>>>, // this lock will be held longer - pub(crate) roomid_federationhandletime: RwLock>, - pub(crate) stateres_mutex: Arc>, - pub(crate) rotate: RotationHandler, - pub(crate) started: SystemTime, - pub(crate) shutdown: AtomicBool, - pub(crate) argon: Argon2<'a>, + pub resolver: Arc, + pub client: client::Client, + pub stable_room_versions: Vec, + pub unstable_room_versions: Vec, + pub bad_event_ratelimiter: Arc>>, + pub bad_signature_ratelimiter: Arc, RateLimitState>>>, + pub bad_query_ratelimiter: Arc>>, + pub roomid_mutex_insert: RwLock>>>, + pub roomid_mutex_state: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub roomid_federationhandletime: RwLock>, + pub updates_handle: Mutex>>, + pub stateres_mutex: Arc>, + pub rotate: RotationHandler, + pub argon: Argon2<'static>, } /// Handles "rotation" of long-polling requests. "Rotation" in this context is @@ -68,7 +69,7 @@ pub(crate) struct Service<'a> { /// /// This is utilized to have sync workers return early and release read locks on /// the database. -pub(crate) struct RotationHandler(broadcast::Sender<()>, ()); +pub struct RotationHandler(broadcast::Sender<()>, ()); impl RotationHandler { fn new() -> Self { @@ -76,25 +77,22 @@ impl RotationHandler { Self(s, ()) } - pub(crate) fn watch(&self) -> impl Future { + pub fn watch(&self) -> impl Future { let mut r = self.0.subscribe(); - async move { _ = r.recv().await; } } - fn fire(&self) { _ = self.0.send(()); } + pub fn fire(&self) { _ = self.0.send(()); } } impl Default for RotationHandler { fn default() -> Self { Self::new() } } -impl Service<'_> { - pub(crate) fn load( - db: &'static dyn Data, config: &Config, tracing_reload_handle: LogLevelReloadHandles, - ) -> Result { +impl Service { + pub fn load(db: Arc, config: &Config) -> Result { let keypair = db.load_keypair(); let keypair = match keypair { @@ -140,7 +138,6 @@ impl Service<'_> { } let mut s = Self { - tracing_reload_handle, db, config: config.clone(), cidr_range_denylist, @@ -157,10 +154,9 @@ impl Service<'_> { roomid_mutex_insert: RwLock::new(HashMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()), roomid_federationhandletime: RwLock::new(HashMap::new()), + updates_handle: Mutex::new(None), stateres_mutex: Arc::new(Mutex::new(())), rotate: RotationHandler::new(), - started: SystemTime::now(), - shutdown: AtomicBool::new(false), argon, }; @@ -178,145 +174,141 @@ impl Service<'_> { } /// Returns this server's keypair. - pub(crate) fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } + pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } #[tracing::instrument(skip(self))] - pub(crate) fn next_count(&self) -> Result { self.db.next_count() } + pub fn next_count(&self) -> Result { self.db.next_count() } #[tracing::instrument(skip(self))] - pub(crate) fn current_count(&self) -> Result { self.db.current_count() } + pub fn current_count(&self) -> Result { self.db.current_count() } #[tracing::instrument(skip(self))] - pub(crate) fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() } + pub fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() } #[tracing::instrument(skip(self))] - pub(crate) fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } - pub(crate) async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { self.db.watch(user_id, device_id).await } - pub(crate) fn cleanup(&self) -> Result<()> { self.db.cleanup() } + pub fn cleanup(&self) -> Result<()> { self.db.cleanup() } /// TODO: use this? #[allow(dead_code)] - pub(crate) fn flush(&self) -> Result<()> { self.db.flush() } + pub fn flush(&self) -> Result<()> { self.db.flush() } - pub(crate) fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } + pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } - pub(crate) fn max_request_size(&self) -> u32 { self.config.max_request_size } + pub fn max_request_size(&self) -> u32 { self.config.max_request_size } - pub(crate) fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } + pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } - pub(crate) fn allow_registration(&self) -> bool { self.config.allow_registration } + pub fn allow_registration(&self) -> bool { self.config.allow_registration } - pub(crate) fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration } + pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration } - pub(crate) fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms } + pub fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms } - pub(crate) fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations } + pub fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations } - pub(crate) fn allow_encryption(&self) -> bool { self.config.allow_encryption } + pub fn allow_encryption(&self) -> bool { self.config.allow_encryption } - pub(crate) fn allow_federation(&self) -> bool { self.config.allow_federation } + pub fn allow_federation(&self) -> bool { self.config.allow_federation } - pub(crate) fn allow_public_room_directory_over_federation(&self) -> bool { + pub fn allow_public_room_directory_over_federation(&self) -> bool { self.config.allow_public_room_directory_over_federation } - pub(crate) fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation } + pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation } - pub(crate) fn allow_room_creation(&self) -> bool { self.config.allow_room_creation } + pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation } - pub(crate) fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions } + pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions } - pub(crate) fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() } + pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() } - pub(crate) fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix } + pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix } - pub(crate) fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates } + pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates } - pub(crate) fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } + pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } - pub(crate) fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } + pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } - pub(crate) fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } + pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } - pub(crate) fn actual_destinations(&self) -> &Arc> { &self.resolver.destinations } + pub fn actual_destinations(&self) -> &Arc> { &self.resolver.destinations } - pub(crate) fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } + pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } - pub(crate) fn turn_password(&self) -> &String { &self.config.turn_password } + pub fn turn_password(&self) -> &String { &self.config.turn_password } - pub(crate) fn turn_ttl(&self) -> u64 { self.config.turn_ttl } + pub fn turn_ttl(&self) -> u64 { self.config.turn_ttl } - pub(crate) fn turn_uris(&self) -> &[String] { &self.config.turn_uris } + pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris } - pub(crate) fn turn_username(&self) -> &String { &self.config.turn_username } + pub fn turn_username(&self) -> &String { &self.config.turn_username } - pub(crate) fn turn_secret(&self) -> &String { &self.config.turn_secret } + pub fn turn_secret(&self) -> &String { &self.config.turn_secret } - pub(crate) fn allow_profile_lookup_federation_requests(&self) -> bool { + pub fn allow_profile_lookup_federation_requests(&self) -> bool { self.config.allow_profile_lookup_federation_requests } - pub(crate) fn notification_push_path(&self) -> &String { &self.config.notification_push_path } + pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path } - pub(crate) fn emergency_password(&self) -> &Option { &self.config.emergency_password } + pub fn emergency_password(&self) -> &Option { &self.config.emergency_password } - pub(crate) fn url_preview_domain_contains_allowlist(&self) -> &Vec { + pub fn url_preview_domain_contains_allowlist(&self) -> &Vec { &self.config.url_preview_domain_contains_allowlist } - pub(crate) fn url_preview_domain_explicit_allowlist(&self) -> &Vec { + pub fn url_preview_domain_explicit_allowlist(&self) -> &Vec { &self.config.url_preview_domain_explicit_allowlist } - pub(crate) fn url_preview_domain_explicit_denylist(&self) -> &Vec { + pub fn url_preview_domain_explicit_denylist(&self) -> &Vec { &self.config.url_preview_domain_explicit_denylist } - pub(crate) fn url_preview_url_contains_allowlist(&self) -> &Vec { - &self.config.url_preview_url_contains_allowlist - } + pub fn url_preview_url_contains_allowlist(&self) -> &Vec { &self.config.url_preview_url_contains_allowlist } - pub(crate) fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size } + pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size } - pub(crate) fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain } + pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain } - pub(crate) fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names } + pub fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names } - pub(crate) fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames } + pub fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames } - pub(crate) fn allow_local_presence(&self) -> bool { self.config.allow_local_presence } + pub fn allow_local_presence(&self) -> bool { self.config.allow_local_presence } - pub(crate) fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence } + pub fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence } - pub(crate) fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence } + pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence } - pub(crate) fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts } + pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts } - pub(crate) fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts } + pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts } - pub(crate) fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { - &self.config.prevent_media_downloads_from - } + pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from } - pub(crate) fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] { + pub fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] { &self.config.forbidden_remote_room_directory_server_names } - pub(crate) fn well_known_support_page(&self) -> &Option { &self.config.well_known.support_page } + pub fn well_known_support_page(&self) -> &Option { &self.config.well_known.support_page } - pub(crate) fn well_known_support_role(&self) -> &Option { &self.config.well_known.support_role } + pub fn well_known_support_role(&self) -> &Option { &self.config.well_known.support_role } - pub(crate) fn well_known_support_email(&self) -> &Option { &self.config.well_known.support_email } + pub fn well_known_support_email(&self) -> &Option { &self.config.well_known.support_email } - pub(crate) fn well_known_support_mxid(&self) -> &Option { &self.config.well_known.support_mxid } + pub fn well_known_support_mxid(&self) -> &Option { &self.config.well_known.support_mxid } - pub(crate) fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites } + pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites } - pub(crate) fn supported_room_versions(&self) -> Vec { + pub fn supported_room_versions(&self) -> Vec { let mut room_versions: Vec = vec![]; room_versions.extend(self.stable_room_versions.clone()); if self.allow_unstable_room_versions() { @@ -332,7 +324,7 @@ impl Service<'_> { /// /// This doesn't actually check that the keys provided are newer than the /// old set. - pub(crate) fn add_signing_key( + pub fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result> { self.db.add_signing_key(origin, new_keys) @@ -340,7 +332,7 @@ impl Service<'_> { /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub(crate) fn signing_keys_for(&self, origin: &ServerName) -> Result> { + pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { let mut keys = self.db.signing_keys_for(origin)?; if origin == self.server_name() { keys.insert( @@ -356,13 +348,11 @@ impl Service<'_> { Ok(keys) } - pub(crate) fn database_version(&self) -> Result { self.db.database_version() } + pub fn database_version(&self) -> Result { self.db.database_version() } - pub(crate) fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.db.bump_database_version(new_version) - } + pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) } - pub(crate) fn get_media_folder(&self) -> PathBuf { + pub fn get_media_folder(&self) -> PathBuf { let mut r = PathBuf::new(); r.push(self.config.database_path.clone()); r.push("media"); @@ -373,7 +363,7 @@ impl Service<'_> { /// flag enabled and database migrated uses SHA256 hash of the base64 key as /// the file name #[cfg(feature = "sha256_media")] - pub(crate) fn get_media_file_new(&self, key: &[u8]) -> PathBuf { + pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf { let mut r = PathBuf::new(); r.push(self.config.database_path.clone()); r.push("media"); @@ -387,7 +377,7 @@ impl Service<'_> { /// old base64 file name media function /// This is the old version of `get_media_file` that uses the full base64 /// key as the filename. - pub(crate) fn get_media_file(&self, key: &[u8]) -> PathBuf { + pub fn get_media_file(&self, key: &[u8]) -> PathBuf { let mut r = PathBuf::new(); r.push(self.config.database_path.clone()); r.push("media"); @@ -395,13 +385,13 @@ impl Service<'_> { r } - pub(crate) fn well_known_client(&self) -> &Option { &self.config.well_known.client } + pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } - pub(crate) fn well_known_server(&self) -> &Option { &self.config.well_known.server } + pub fn well_known_server(&self) -> &Option { &self.config.well_known.server } - pub(crate) fn unix_socket_path(&self) -> &Option { &self.config.unix_socket_path } + pub fn unix_socket_path(&self) -> &Option { &self.config.unix_socket_path } - pub(crate) fn valid_cidr_range(&self, ip: &IPAddress) -> bool { + pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool { for cidr in &self.cidr_range_denylist { if cidr.includes(ip) { return false; @@ -410,24 +400,13 @@ impl Service<'_> { true } - - pub(crate) fn shutdown(&self) { - self.shutdown.store(true, atomic::Ordering::Relaxed); - // On shutdown - - if self.unix_socket_path().is_some() { - match &self.unix_socket_path() { - Some(path) => { - fs::remove_file(path).unwrap(); - }, - None => error!( - "Unable to remove socket file at {:?} during shutdown.", - &self.unix_socket_path() - ), - }; - }; - - info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - services().globals.rotate.fire(); - } } + +#[inline] +#[must_use] +pub fn server_is_ours(server_name: &ServerName) -> bool { server_name == services().globals.config.server_name } + +/// checks if `user_id` is local to us via server_name comparison +#[inline] +#[must_use] +pub fn user_is_local(user_id: &UserId) -> bool { server_is_ours(user_id.server_name()) } diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index 190c8579..d109d389 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -17,21 +17,21 @@ use crate::{service::sending::FedDest, Config, Error}; pub(crate) type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; -pub(crate) struct Resolver { - pub(crate) destinations: Arc>, // actual_destination, host - pub(crate) overrides: Arc>, - pub(crate) resolver: Arc, - pub(crate) hooked: Arc, +pub struct Resolver { + pub destinations: Arc>, // actual_destination, host + pub overrides: Arc>, + pub resolver: Arc, + pub hooked: Arc, } -pub(crate) struct Hooked { - pub(crate) overrides: Arc>, - pub(crate) resolver: Arc, +pub struct Hooked { + pub overrides: Arc>, + pub resolver: Arc, } impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - pub(crate) fn new(config: &Config) -> Self { + pub fn new(config: &Config) -> Self { let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() .map_err(|e| { error!("Failed to set up hickory dns resolver with system config: {}", e); diff --git a/src/service/globals/updates.rs b/src/service/globals/updates.rs new file mode 100644 index 00000000..5f530736 --- /dev/null +++ b/src/service/globals/updates.rs @@ -0,0 +1,76 @@ +use std::time::Duration; + +use ruma::events::room::message::RoomMessageEventContent; +use serde::Deserialize; +use tokio::{task::JoinHandle, time::interval}; +use tracing::{debug, error}; + +use crate::{ + conduit::{Error, Result}, + services, +}; + +#[derive(Deserialize)] +struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, +} +#[derive(Deserialize)] +struct CheckForUpdatesResponse { + updates: Vec, +} + +#[tracing::instrument] +pub async fn start_check_for_updates_task() -> Result> { + let timer_interval = Duration::from_secs(7200); // 2 hours + + Ok(services().server.runtime().spawn(async move { + let mut i = interval(timer_interval); + + loop { + tokio::select! { + _ = i.tick() => { + debug!(target: "start_check_for_updates_task", "Timer ticked"); + }, + } + + _ = try_handle_updates().await; + } + })) +} + +async fn try_handle_updates() -> Result<()> { + let response = services() + .globals + .client + .default + .get("https://pupbrain.dev/check-for-updates/stable") + .send() + .await?; + + let response = serde_json::from_str::(&response.text().await?).map_err(|e| { + error!("Bad check for updates response: {e}"); + Error::BadServerResponse("Bad version check response") + })?; + + let mut last_update_id = services().globals.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > services().globals.last_check_for_updates_id()? { + error!("{}", update.message); + services() + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", + update.date, update.message + ))) + .await; + } + } + services() + .globals + .update_check_for_updates_id(last_update_id)?; + + Ok(()) +} diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index 5aefbd49..ac595a6b 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -8,7 +8,7 @@ use ruma::{ use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result; fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index d071ccc4..abab604c 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,7 +1,7 @@ mod data; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, @@ -10,79 +10,73 @@ use ruma::{ use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { self.db.create_backup(user_id, backup_metadata) } - pub(crate) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { self.db.delete_backup(user_id, version) } - pub(crate) fn update_backup( + pub fn update_backup( &self, user_id: &UserId, version: &str, backup_metadata: &Raw, ) -> Result { self.db.update_backup(user_id, version, backup_metadata) } - pub(crate) fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { self.db.get_latest_backup_version(user_id) } - pub(crate) fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { + pub fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { self.db.get_latest_backup(user_id) } - pub(crate) fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { self.db.get_backup(user_id, version) } - pub(crate) fn add_key( + pub fn add_key( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, ) -> Result<()> { self.db .add_key(user_id, version, room_id, session_id, key_data) } - pub(crate) fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - self.db.count_keys(user_id, version) - } + pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } - pub(crate) fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - self.db.get_etag(user_id, version) - } + pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { self.db.get_etag(user_id, version) } - pub(crate) fn get_all(&self, user_id: &UserId, version: &str) -> Result> { + pub fn get_all(&self, user_id: &UserId, version: &str) -> Result> { self.db.get_all(user_id, version) } - pub(crate) fn get_room( + pub fn get_room( &self, user_id: &UserId, version: &str, room_id: &RoomId, ) -> Result>> { self.db.get_room(user_id, version, room_id) } - pub(crate) fn get_session( + pub fn get_session( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, ) -> Result>> { self.db.get_session(user_id, version, room_id, session_id) } - pub(crate) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { self.db.delete_all_keys(user_id, version) } - pub(crate) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { self.db.delete_room_keys(user_id, version, room_id) } - pub(crate) fn delete_room_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<()> { + pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { self.db .delete_room_key(user_id, version, room_id, session_id) } diff --git a/src/database/key_value/account_data.rs b/src/service/key_value/account_data.rs similarity index 96% rename from src/database/key_value/account_data.rs rename to src/service/key_value/account_data.rs index d67f8881..981f1b8c 100644 --- a/src/database/key_value/account_data.rs +++ b/src/service/key_value/account_data.rs @@ -8,9 +8,9 @@ use ruma::{ }; use tracing::warn; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::account_data::Data for KeyValueDatabase { +impl crate::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the /// previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] diff --git a/src/database/key_value/appservice.rs b/src/service/key_value/appservice.rs similarity index 92% rename from src/database/key_value/appservice.rs rename to src/service/key_value/appservice.rs index ead37c2b..f030d5e7 100644 --- a/src/database/key_value/appservice.rs +++ b/src/service/key_value/appservice.rs @@ -1,8 +1,8 @@ use ruma::api::appservice::Registration; -use crate::{database::KeyValueDatabase, service, utils, Error, Result}; +use crate::{utils, Error, KeyValueDatabase, Result}; -impl service::appservice::Data for KeyValueDatabase { +impl crate::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller fn register_appservice(&self, yaml: Registration) -> Result { let id = yaml.id.as_str(); diff --git a/src/database/key_value/globals.rs b/src/service/key_value/globals.rs similarity index 96% rename from src/database/key_value/globals.rs rename to src/service/key_value/globals.rs index 4ed07eba..e56f9feb 100644 --- a/src/database/key_value/globals.rs +++ b/src/service/key_value/globals.rs @@ -8,17 +8,15 @@ use ruma::{ signatures::Ed25519KeyPair, DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; +use tracing::trace; -use crate::{ - database::{Cork, KeyValueDatabase}, - service, services, utils, Error, Result, -}; +use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result}; const COUNTER: &[u8] = b"c"; const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; #[async_trait] -impl service::globals::Data for KeyValueDatabase { +impl crate::globals::Data for KeyValueDatabase { fn next_count(&self) -> Result { utils::u64_from_bytes(&self.global.increment(COUNTER)?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) @@ -47,6 +45,7 @@ impl service::globals::Data for KeyValueDatabase { } #[allow(unused_qualifications)] // async traits + #[tracing::instrument(skip(self))] async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); @@ -132,7 +131,9 @@ impl service::globals::Data for KeyValueDatabase { futures.push(Box::pin(services().globals.rotate.watch())); // Wait until one of them finds something + trace!(futures = futures.len(), "watch started"); futures.next().await; + trace!(futures = futures.len(), "watch finished"); Ok(()) } diff --git a/src/database/key_value/key_backups.rs b/src/service/key_value/key_backups.rs similarity index 98% rename from src/database/key_value/key_backups.rs rename to src/service/key_value/key_backups.rs index 7ed1da4c..82bbdd48 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/service/key_value/key_backups.rs @@ -9,9 +9,9 @@ use ruma::{ OwnedRoomId, RoomId, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::key_backups::Data for KeyValueDatabase { +impl crate::key_backups::Data for KeyValueDatabase { fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { let version = services().globals.next_count()?.to_string(); diff --git a/src/database/key_value/media.rs b/src/service/key_value/media.rs similarity index 97% rename from src/database/key_value/media.rs rename to src/service/key_value/media.rs index d86b50f1..07d8dde3 100644 --- a/src/database/key_value/media.rs +++ b/src/service/key_value/media.rs @@ -1,14 +1,9 @@ use ruma::api::client::error::ErrorKind; use tracing::debug; -use crate::{ - database::KeyValueDatabase, - service::{self, media::UrlPreviewData}, - utils::string_from_bytes, - Error, Result, -}; +use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result}; -impl service::media::Data for KeyValueDatabase { +impl crate::media::Data for KeyValueDatabase { fn create_file_metadata( &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, diff --git a/src/database/key_value/mod.rs b/src/service/key_value/mod.rs similarity index 100% rename from src/database/key_value/mod.rs rename to src/service/key_value/mod.rs diff --git a/src/database/key_value/presence.rs b/src/service/key_value/presence.rs similarity index 96% rename from src/database/key_value/presence.rs rename to src/service/key_value/presence.rs index 17068f90..9defd06d 100644 --- a/src/database/key_value/presence.rs +++ b/src/service/key_value/presence.rs @@ -1,15 +1,14 @@ +use conduit::debug_info; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use crate::{ - database::KeyValueDatabase, - debug_info, - service::{self, presence::Presence}, + presence::Presence, services, utils::{self, user_id_from_bytes}, - Error, Result, + Error, KeyValueDatabase, Result, }; -impl service::presence::Data for KeyValueDatabase { +impl crate::presence::Data for KeyValueDatabase { fn get_presence(&self, user_id: &UserId) -> Result> { if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { let count = utils::u64_from_bytes(&count_bytes) diff --git a/src/database/key_value/pusher.rs b/src/service/key_value/pusher.rs similarity index 94% rename from src/database/key_value/pusher.rs rename to src/service/key_value/pusher.rs index 851831ec..876b531c 100644 --- a/src/database/key_value/pusher.rs +++ b/src/service/key_value/pusher.rs @@ -3,9 +3,9 @@ use ruma::{ UserId, }; -use crate::{database::KeyValueDatabase, service, utils, Error, Result}; +use crate::{utils, Error, KeyValueDatabase, Result}; -impl service::pusher::Data for KeyValueDatabase { +impl crate::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { match &pusher { set_pusher::v3::PusherAction::Post(data) => { diff --git a/src/database/key_value/rooms/alias.rs b/src/service/key_value/rooms/alias.rs similarity index 94% rename from src/database/key_value/rooms/alias.rs rename to src/service/key_value/rooms/alias.rs index b5a976c7..402e59fd 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/service/key_value/rooms/alias.rs @@ -1,8 +1,8 @@ use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::alias::Data for KeyValueDatabase { +impl crate::rooms::alias::Data for KeyValueDatabase { fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.alias_roomid .insert(alias.alias().as_bytes(), room_id.as_bytes())?; diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/service/key_value/rooms/auth_chain.rs similarity index 91% rename from src/database/key_value/rooms/auth_chain.rs rename to src/service/key_value/rooms/auth_chain.rs index 435a1c03..f01ff4aa 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/service/key_value/rooms/auth_chain.rs @@ -1,8 +1,8 @@ use std::{mem::size_of, sync::Arc}; -use crate::{database::KeyValueDatabase, service, utils, Result}; +use crate::{utils, KeyValueDatabase, Result}; -impl service::rooms::auth_chain::Data for KeyValueDatabase { +impl crate::rooms::auth_chain::Data for KeyValueDatabase { fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { diff --git a/src/database/key_value/rooms/directory.rs b/src/service/key_value/rooms/directory.rs similarity index 85% rename from src/database/key_value/rooms/directory.rs rename to src/service/key_value/rooms/directory.rs index 20ccfb55..9265d2c8 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/service/key_value/rooms/directory.rs @@ -1,8 +1,8 @@ use ruma::{OwnedRoomId, RoomId}; -use crate::{database::KeyValueDatabase, service, utils, Error, Result}; +use crate::{utils, Error, KeyValueDatabase, Result}; -impl service::rooms::directory::Data for KeyValueDatabase { +impl crate::rooms::directory::Data for KeyValueDatabase { fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/service/key_value/rooms/lazy_load.rs similarity index 92% rename from src/database/key_value/rooms/lazy_load.rs rename to src/service/key_value/rooms/lazy_load.rs index 080eb4b8..700505cd 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/service/key_value/rooms/lazy_load.rs @@ -1,8 +1,8 @@ use ruma::{DeviceId, RoomId, UserId}; -use crate::{database::KeyValueDatabase, service, Result}; +use crate::{KeyValueDatabase, Result}; -impl service::rooms::lazy_loading::Data for KeyValueDatabase { +impl crate::rooms::lazy_loading::Data for KeyValueDatabase { fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result { diff --git a/src/database/key_value/rooms/metadata.rs b/src/service/key_value/rooms/metadata.rs similarity index 93% rename from src/database/key_value/rooms/metadata.rs rename to src/service/key_value/rooms/metadata.rs index 9528da1e..ab8c1a78 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/service/key_value/rooms/metadata.rs @@ -1,9 +1,9 @@ use ruma::{OwnedRoomId, RoomId}; use tracing::error; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::metadata::Data for KeyValueDatabase { +impl crate::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), diff --git a/src/database/key_value/rooms/mod.rs b/src/service/key_value/rooms/mod.rs similarity index 71% rename from src/database/key_value/rooms/mod.rs rename to src/service/key_value/rooms/mod.rs index 087f2711..d69cf141 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/service/key_value/rooms/mod.rs @@ -16,6 +16,6 @@ mod threads; mod timeline; mod user; -use crate::{database::KeyValueDatabase, service}; +use crate::KeyValueDatabase; -impl service::rooms::Data for KeyValueDatabase {} +impl crate::rooms::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/service/key_value/rooms/outlier.rs similarity index 85% rename from src/database/key_value/rooms/outlier.rs rename to src/service/key_value/rooms/outlier.rs index 933660e8..701e4cb2 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/service/key_value/rooms/outlier.rs @@ -1,8 +1,8 @@ use ruma::{CanonicalJsonObject, EventId}; -use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; +use crate::{Error, KeyValueDatabase, PduEvent, Result}; -impl service::rooms::outlier::Data for KeyValueDatabase { +impl crate::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/service/key_value/rooms/pdu_metadata.rs similarity index 92% rename from src/database/key_value/rooms/pdu_metadata.rs rename to src/service/key_value/rooms/pdu_metadata.rs index 7e69788d..225ed1cc 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/service/key_value/rooms/pdu_metadata.rs @@ -2,13 +2,9 @@ use std::{mem, sync::Arc}; use ruma::{EventId, RoomId, UserId}; -use crate::{ - database::KeyValueDatabase, - service::{self, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, -}; +use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; -impl service::rooms::pdu_metadata::Data for KeyValueDatabase { +impl crate::rooms::pdu_metadata::Data for KeyValueDatabase { fn add_relation(&self, from: u64, to: u64) -> Result<()> { let mut key = to.to_be_bytes().to_vec(); key.extend_from_slice(&from.to_be_bytes()); diff --git a/src/database/key_value/rooms/read_receipt.rs b/src/service/key_value/rooms/read_receipt.rs similarity index 96% rename from src/database/key_value/rooms/read_receipt.rs rename to src/service/key_value/rooms/read_receipt.rs index e3f01a75..6cd913e7 100644 --- a/src/database/key_value/rooms/read_receipt.rs +++ b/src/service/key_value/rooms/read_receipt.rs @@ -2,9 +2,9 @@ use std::mem; use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::read_receipt::Data for KeyValueDatabase { +impl crate::rooms::read_receipt::Data for KeyValueDatabase { fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/database/key_value/rooms/search.rs b/src/service/key_value/rooms/search.rs similarity index 93% rename from src/database/key_value/rooms/search.rs rename to src/service/key_value/rooms/search.rs index 6c5d1bc2..ab826172 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/service/key_value/rooms/search.rs @@ -1,10 +1,10 @@ use ruma::RoomId; -use crate::{database::KeyValueDatabase, service, services, utils, Result}; +use crate::{services, utils, KeyValueDatabase, Result}; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; -impl service::rooms::search::Data for KeyValueDatabase { +impl crate::rooms::search::Data for KeyValueDatabase { fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/service/key_value/rooms/short.rs similarity index 97% rename from src/database/key_value/rooms/short.rs rename to src/service/key_value/rooms/short.rs index e0c3daac..69d85da4 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/service/key_value/rooms/short.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use ruma::{events::StateEventType, EventId, RoomId}; use tracing::warn; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::short::Data for KeyValueDatabase { +impl crate::rooms::short::Data for KeyValueDatabase { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? diff --git a/src/database/key_value/rooms/state.rs b/src/service/key_value/rooms/state.rs similarity index 94% rename from src/database/key_value/rooms/state.rs rename to src/service/key_value/rooms/state.rs index 79ef202f..f7637c57 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/service/key_value/rooms/state.rs @@ -3,9 +3,9 @@ use std::{collections::HashSet, sync::Arc}; use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; -use crate::{database::KeyValueDatabase, service, utils, Error, Result}; +use crate::{utils, Error, KeyValueDatabase, Result}; -impl service::rooms::state::Data for KeyValueDatabase { +impl crate::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/service/key_value/rooms/state_accessor.rs similarity index 96% rename from src/database/key_value/rooms/state_accessor.rs rename to src/service/key_value/rooms/state_accessor.rs index 5b3a71d0..c36fd1cf 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/service/key_value/rooms/state_accessor.rs @@ -3,10 +3,10 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; #[async_trait] -impl service::rooms::state_accessor::Data for KeyValueDatabase { +impl crate::rooms::state_accessor::Data for KeyValueDatabase { #[allow(unused_qualifications)] // async traits async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = services() diff --git a/src/database/key_value/rooms/state_cache.rs b/src/service/key_value/rooms/state_cache.rs similarity index 98% rename from src/database/key_value/rooms/state_cache.rs rename to src/service/key_value/rooms/state_cache.rs index 1ca29ebd..795da576 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/service/key_value/rooms/state_cache.rs @@ -9,18 +9,17 @@ use ruma::{ use tracing::error; use crate::{ - database::KeyValueDatabase, - service::{self, appservice::RegistrationInfo}, - services, - utils::{self, user_id::user_is_local}, - Error, Result, + appservice::RegistrationInfo, + services, user_is_local, + utils::{self}, + Error, KeyValueDatabase, Result, }; type StrippedStateEventIter<'a> = Box>)>> + 'a>; type AnySyncStateEventIter<'a> = Box>)>> + 'a>; -impl service::rooms::state_cache::Data for KeyValueDatabase { +impl crate::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/service/key_value/rooms/state_compressor.rs similarity index 88% rename from src/database/key_value/rooms/state_compressor.rs rename to src/service/key_value/rooms/state_compressor.rs index 0043a1ba..bc0a2c33 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/service/key_value/rooms/state_compressor.rs @@ -1,12 +1,8 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use crate::{ - database::KeyValueDatabase, - service::{self, rooms::state_compressor::data::StateDiff}, - utils, Error, Result, -}; +use crate::{rooms::state_compressor::data::StateDiff, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::state_compressor::Data for KeyValueDatabase { +impl crate::rooms::state_compressor::Data for KeyValueDatabase { fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff diff --git a/src/database/key_value/rooms/threads.rs b/src/service/key_value/rooms/threads.rs similarity index 93% rename from src/database/key_value/rooms/threads.rs rename to src/service/key_value/rooms/threads.rs index 4cb2591b..9f0aad3a 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/service/key_value/rooms/threads.rs @@ -2,11 +2,11 @@ use std::mem; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; type PduEventIterResult<'a> = Result> + 'a>>; -impl service::rooms::threads::Data for KeyValueDatabase { +impl crate::rooms::threads::Data for KeyValueDatabase { fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, ) -> PduEventIterResult<'a> { diff --git a/src/database/key_value/rooms/timeline.rs b/src/service/key_value/rooms/timeline.rs similarity index 97% rename from src/database/key_value/rooms/timeline.rs rename to src/service/key_value/rooms/timeline.rs index d583c7ec..7f22354c 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/service/key_value/rooms/timeline.rs @@ -1,12 +1,11 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; -use service::rooms::timeline::PduCount; use tracing::error; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; -impl service::rooms::timeline::Data for KeyValueDatabase { +impl crate::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache diff --git a/src/database/key_value/rooms/user.rs b/src/service/key_value/rooms/user.rs similarity index 96% rename from src/database/key_value/rooms/user.rs rename to src/service/key_value/rooms/user.rs index cc031747..a49dc815 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/service/key_value/rooms/user.rs @@ -1,8 +1,8 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -impl service::rooms::user::Data for KeyValueDatabase { +impl crate::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); diff --git a/src/database/key_value/sending.rs b/src/service/key_value/sending.rs similarity index 96% rename from src/database/key_value/sending.rs rename to src/service/key_value/sending.rs index ae662f3b..544d37c9 100644 --- a/src/database/key_value/sending.rs +++ b/src/service/key_value/sending.rs @@ -1,15 +1,11 @@ use ruma::{ServerName, UserId}; use crate::{ - database::KeyValueDatabase, - service::{ - self, - sending::{Destination, SendingEvent}, - }, - services, utils, Error, Result, + sending::{Destination, SendingEvent}, + services, utils, Error, KeyValueDatabase, Result, }; -impl service::sending::Data for KeyValueDatabase { +impl crate::sending::Data for KeyValueDatabase { fn active_requests<'a>(&'a self) -> Box, Destination, SendingEvent)>> + 'a> { Box::new( self.servercurrentevent_data diff --git a/src/database/key_value/transaction_ids.rs b/src/service/key_value/transaction_ids.rs similarity index 88% rename from src/database/key_value/transaction_ids.rs rename to src/service/key_value/transaction_ids.rs index f88ae69f..2dfcdfb1 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/service/key_value/transaction_ids.rs @@ -1,8 +1,8 @@ use ruma::{DeviceId, TransactionId, UserId}; -use crate::{database::KeyValueDatabase, service, Result}; +use crate::{KeyValueDatabase, Result}; -impl service::transaction_ids::Data for KeyValueDatabase { +impl crate::transaction_ids::Data for KeyValueDatabase { fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()> { diff --git a/src/database/key_value/uiaa.rs b/src/service/key_value/uiaa.rs similarity index 94% rename from src/database/key_value/uiaa.rs rename to src/service/key_value/uiaa.rs index a8047656..801f08c1 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/service/key_value/uiaa.rs @@ -3,9 +3,9 @@ use ruma::{ CanonicalJsonValue, DeviceId, UserId, }; -use crate::{database::KeyValueDatabase, service, Error, Result}; +use crate::{Error, KeyValueDatabase, Result}; -impl service::uiaa::Data for KeyValueDatabase { +impl crate::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, ) -> Result<()> { diff --git a/src/database/key_value/users.rs b/src/service/key_value/users.rs similarity index 98% rename from src/database/key_value/users.rs rename to src/service/key_value/users.rs index 9b10f2a5..c8adb39f 100644 --- a/src/database/key_value/users.rs +++ b/src/service/key_value/users.rs @@ -1,5 +1,6 @@ use std::{collections::BTreeMap, mem::size_of}; +use argon2::{password_hash::SaltString, PasswordHasher}; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -10,13 +11,9 @@ use ruma::{ }; use tracing::warn; -use crate::{ - database::KeyValueDatabase, - service::{self, users::clean_signatures}, - services, utils, Error, Result, -}; +use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result}; -impl service::users::Data for KeyValueDatabase { +impl crate::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } @@ -95,7 +92,7 @@ impl service::users::Data for KeyValueDatabase { /// Hash and set the user's password to the Argon2 hash fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { - if let Ok(hash) = utils::calculate_password_hash(password) { + if let Ok(hash) = calculate_password_hash(password) { self.userid_password .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) @@ -871,8 +868,6 @@ impl service::users::Data for KeyValueDatabase { } } -impl KeyValueDatabase {} - /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. /// If `utils::string_from_bytes`(...) returns an error that username will be @@ -891,3 +886,13 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option< } } } + +/// Calculate a new hash for the given password +fn calculate_password_hash(password: &str) -> Result { + let salt = SaltString::generate(rand::thread_rng()); + services() + .globals + .argon + .hash_password(password.as_bytes(), &salt) + .map(|it| it.to_string()) +} diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 7a1710b6..b20f8773 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,6 +1,6 @@ use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn create_file_metadata( &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 5e39085d..7494e9b3 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,7 +1,7 @@ mod data; use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime}; -pub(crate) use data::Data; +pub use data::Data; use image::imageops::FilterType; use ruma::{OwnedMxcUri, OwnedUserId}; use serde::Serialize; @@ -15,37 +15,37 @@ use tracing::{debug, error}; use crate::{services, utils, Error, Result}; #[derive(Debug)] -pub(crate) struct FileMeta { +pub struct FileMeta { #[allow(dead_code)] - pub(crate) content_disposition: Option, - pub(crate) content_type: Option, - pub(crate) file: Vec, + pub content_disposition: Option, + pub content_type: Option, + pub file: Vec, } #[derive(Serialize, Default)] -pub(crate) struct UrlPreviewData { +pub struct UrlPreviewData { #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))] - pub(crate) title: Option, + pub title: Option, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))] - pub(crate) description: Option, + pub description: Option, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))] - pub(crate) image: Option, + pub image: Option, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))] - pub(crate) image_size: Option, + pub image_size: Option, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))] - pub(crate) image_width: Option, + pub image_width: Option, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))] - pub(crate) image_height: Option, + pub image_height: Option, } -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, - pub(crate) url_preview_mutex: RwLock>>>, +pub struct Service { + pub db: Arc, + pub url_preview_mutex: RwLock>>>, } impl Service { /// Uploads a file. - pub(crate) async fn create( + pub async fn create( &self, sender_user: Option, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8], ) -> Result<()> { @@ -79,7 +79,7 @@ impl Service { } /// Deletes a file in the database and from the media directory via an MXC - pub(crate) async fn delete(&self, mxc: String) -> Result<()> { + pub async fn delete(&self, mxc: String) -> Result<()> { if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) { for key in keys { let file_path; @@ -116,7 +116,7 @@ impl Service { /// Uploads or replaces a file thumbnail. #[allow(clippy::too_many_arguments)] - pub(crate) async fn upload_thumbnail( + pub async fn upload_thumbnail( &self, sender_user: Option, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32, file: &[u8], ) -> Result<()> { @@ -149,7 +149,7 @@ impl Service { } /// Downloads a file. - pub(crate) async fn get(&self, mxc: String) -> Result> { + pub async fn get(&self, mxc: String) -> Result> { if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { let path; @@ -182,7 +182,7 @@ impl Service { /// Deletes all remote only media files in the given at or after /// time/duration. Returns a u32 with the amount of media files deleted. - pub(crate) async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result { + pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result { if let Ok(all_keys) = self.db.get_all_media_keys() { let user_duration: SystemTime = match cyborgtime::parse_duration(&time) { Ok(duration) => { @@ -296,7 +296,7 @@ impl Service { /// Returns width, height of the thumbnail and whether it should be cropped. /// Returns None when the server should send the original file. - pub(crate) fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { + pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { match (width, height) { (0..=32, 0..=32) => Some((32, 32, true)), (0..=96, 0..=96) => Some((96, 96, true)), @@ -320,7 +320,7 @@ impl Service { /// /// For width,height <= 96 the server uses another thumbnailing algorithm /// which crops the image afterwards. - pub(crate) async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result> { + pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result> { let (width, height, crop) = self .thumbnail_properties(width, height) .unwrap_or((0, 0, false)); // 0, 0 because that's the original file @@ -467,16 +467,16 @@ impl Service { } } - pub(crate) async fn get_url_preview(&self, url: &str) -> Option { self.db.get_url_preview(url) } + pub async fn get_url_preview(&self, url: &str) -> Option { self.db.get_url_preview(url) } /// TODO: use this? #[allow(dead_code)] - pub(crate) async fn remove_url_preview(&self, url: &str) -> Result<()> { + pub async fn remove_url_preview(&self, url: &str) -> Result<()> { // TODO: also remove the downloaded image self.db.remove_url_preview(url) } - pub(crate) async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { + pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("valid system time"); @@ -548,9 +548,9 @@ mod tests { fn get_url_preview(&self, _url: &str) -> Option { todo!() } } - static DB: MockedKVDatabase = MockedKVDatabase; + let db: Arc = Arc::new(MockedKVDatabase); let media = Service { - db: &DB, + db, url_preview_mutex: RwLock::new(HashMap::new()), }; diff --git a/src/service/mod.rs b/src/service/mod.rs index 3fbc435a..386e8662 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,298 +1,50 @@ -use std::{ - collections::{BTreeMap, HashMap}, - sync::{Arc, Mutex as StdMutex}, -}; +pub(crate) mod key_value; +pub mod pdu; +pub mod services; -use lru_cache::LruCache; -use tokio::sync::{broadcast, Mutex, RwLock}; +pub mod account_data; +pub mod admin; +pub mod appservice; +pub mod globals; +pub mod key_backups; +pub mod media; +pub mod presence; +pub mod pusher; +pub mod rooms; +pub mod sending; +pub mod transaction_ids; +pub mod uiaa; +pub mod users; -use crate::{Config, LogLevelReloadHandles, Result}; +extern crate conduit_core as conduit; +extern crate conduit_database as database; +use std::sync::RwLock; -pub(crate) mod account_data; -pub(crate) mod admin; -pub(crate) mod appservice; -pub(crate) mod globals; -pub(crate) mod key_backups; -pub(crate) mod media; -pub(crate) mod pdu; -pub(crate) mod presence; -pub(crate) mod pusher; -pub(crate) mod rooms; -pub(crate) mod sending; -pub(crate) mod transaction_ids; -pub(crate) mod uiaa; -pub(crate) mod users; +pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result}; +pub(crate) use database::KeyValueDatabase; +pub use globals::{server_is_ours, user_is_local}; +pub use pdu::PduEvent; +pub use services::Services; -pub(crate) struct Services<'a> { - pub(crate) appservice: appservice::Service, - pub(crate) pusher: pusher::Service, - pub(crate) rooms: rooms::Service, - pub(crate) transaction_ids: transaction_ids::Service, - pub(crate) uiaa: uiaa::Service, - pub(crate) users: users::Service, - pub(crate) account_data: account_data::Service, - pub(crate) presence: Arc, - pub(crate) admin: Arc, - pub(crate) globals: globals::Service<'a>, - pub(crate) key_backups: key_backups::Service, - pub(crate) media: media::Service, - pub(crate) sending: Arc, +pub(crate) use crate as service; + +conduit::mod_ctor! {} +conduit::mod_dtor! {} + +pub static SERVICES: RwLock> = RwLock::new(None); + +#[must_use] +pub fn services() -> &'static Services { + SERVICES + .read() + .expect("SERVICES locked for reading") + .expect("SERVICES initialized with Services instance") } -impl Services<'_> { - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - pub(crate) fn build< - D: appservice::Data - + pusher::Data - + rooms::Data - + transaction_ids::Data - + uiaa::Data - + users::Data - + account_data::Data - + presence::Data - + globals::Data - + key_backups::Data - + media::Data - + sending::Data - + 'static, - >( - db: &'static D, config: &Config, tracing_reload_handle: LogLevelReloadHandles, - ) -> Result { - Ok(Self { - appservice: appservice::Service::build(db)?, - pusher: pusher::Service { - db, - }, - rooms: rooms::Service { - alias: rooms::alias::Service { - db, - }, - auth_chain: rooms::auth_chain::Service { - db, - }, - directory: rooms::directory::Service { - db, - }, - event_handler: rooms::event_handler::Service, - lazy_loading: rooms::lazy_loading::Service { - db, - lazy_load_waiting: Mutex::new(HashMap::new()), - }, - metadata: rooms::metadata::Service { - db, - }, - outlier: rooms::outlier::Service { - db, - }, - pdu_metadata: rooms::pdu_metadata::Service { - db, - }, - read_receipt: rooms::read_receipt::Service { - db, - }, - search: rooms::search::Service { - db, - }, - short: rooms::short::Service { - db, - }, - state: rooms::state::Service { - db, - }, - state_accessor: rooms::state_accessor::Service { - db, - server_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), - user_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), - }, - state_cache: rooms::state_cache::Service { - db, - }, - state_compressor: rooms::state_compressor::Service { - db, - stateinfo_cache: StdMutex::new(LruCache::new( - (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - }, - timeline: rooms::timeline::Service { - db, - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }, - threads: rooms::threads::Service { - db, - }, - typing: rooms::typing::Service { - typing: RwLock::new(BTreeMap::new()), - last_typing_update: RwLock::new(BTreeMap::new()), - typing_update_sender: broadcast::channel(100).0, - }, - spaces: rooms::spaces::Service { - roomid_spacehierarchy_cache: Mutex::new(LruCache::new( - (f64::from(config.roomid_spacehierarchy_cache_capacity) - * config.conduit_cache_capacity_modifier) as usize, - )), - }, - user: rooms::user::Service { - db, - }, - }, - transaction_ids: transaction_ids::Service { - db, - }, - uiaa: uiaa::Service { - db, - }, - users: users::Service { - db, - connections: StdMutex::new(BTreeMap::new()), - }, - account_data: account_data::Service { - db, - }, - presence: presence::Service::build(db, config), - admin: admin::Service::build(), - key_backups: key_backups::Service { - db, - }, - media: media::Service { - db, - url_preview_mutex: RwLock::new(HashMap::new()), - }, - sending: sending::Service::build(db, config), - - globals: globals::Service::load(db, config, tracing_reload_handle)?, - }) - } - - async fn memory_usage(&self) -> String { - let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); - let server_visibility_cache = self - .rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .len(); - let user_visibility_cache = self - .rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .len(); - let stateinfo_cache = self - .rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .len(); - let lasttimelinecount_cache = self - .rooms - .timeline - .lasttimelinecount_cache - .lock() - .await - .len(); - let roomid_spacehierarchy_cache = self - .rooms - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .len(); - let resolver_overrides_cache = self.globals.resolver.overrides.read().unwrap().len(); - let resolver_destinations_cache = self.globals.resolver.destinations.read().await.len(); - let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len(); - let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len(); - let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len(); - - format!( - "\ -lazy_load_waiting: {lazy_load_waiting} -server_visibility_cache: {server_visibility_cache} -user_visibility_cache: {user_visibility_cache} -stateinfo_cache: {stateinfo_cache} -lasttimelinecount_cache: {lasttimelinecount_cache} -roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache} -resolver_overrides_cache: {resolver_overrides_cache} -resolver_destinations_cache: {resolver_destinations_cache} -bad_event_ratelimiter: {bad_event_ratelimiter} -bad_query_ratelimiter: {bad_query_ratelimiter} -bad_signature_ratelimiter: {bad_signature_ratelimiter} -" - ) - } - - async fn clear_caches(&self, amount: u32) { - if amount > 0 { - self.rooms - .lazy_loading - .lazy_load_waiting - .lock() - .await - .clear(); - } - if amount > 1 { - self.rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 2 { - self.rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 3 { - self.rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .clear(); - } - if amount > 4 { - self.rooms - .timeline - .lasttimelinecount_cache - .lock() - .await - .clear(); - } - if amount > 5 { - self.rooms - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .clear(); - } - if amount > 6 { - self.globals.resolver.overrides.write().unwrap().clear(); - self.globals.resolver.destinations.write().await.clear(); - } - if amount > 7 { - self.globals.resolver.resolver.clear_cache(); - } - if amount > 8 { - self.globals.bad_event_ratelimiter.write().await.clear(); - } - if amount > 9 { - self.globals.bad_query_ratelimiter.write().await.clear(); - } - if amount > 10 { - self.globals.bad_signature_ratelimiter.write().await.clear(); - } - } +#[inline] +pub fn available() -> bool { + SERVICES + .read() + .expect("SERVICES locked for reading") + .is_some() } diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 51f276d4..4912f1c4 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -23,40 +23,40 @@ use crate::{services, Error}; /// Content hashes of a PDU. #[derive(Clone, Debug, Deserialize, Serialize)] -pub(crate) struct EventHash { +pub struct EventHash { /// The SHA-256 hash. - pub(crate) sha256: String, + pub sha256: String, } #[derive(Clone, Deserialize, Serialize, Debug)] -pub(crate) struct PduEvent { - pub(crate) event_id: Arc, - pub(crate) room_id: OwnedRoomId, - pub(crate) sender: OwnedUserId, +pub struct PduEvent { + pub event_id: Arc, + pub room_id: OwnedRoomId, + pub sender: OwnedUserId, #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) origin: Option, - pub(crate) origin_server_ts: UInt, + pub origin: Option, + pub origin_server_ts: UInt, #[serde(rename = "type")] - pub(crate) kind: TimelineEventType, - pub(crate) content: Box, + pub kind: TimelineEventType, + pub content: Box, #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) state_key: Option, - pub(crate) prev_events: Vec>, - pub(crate) depth: UInt, - pub(crate) auth_events: Vec>, + pub state_key: Option, + pub prev_events: Vec>, + pub depth: UInt, + pub auth_events: Vec>, #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) redacts: Option>, + pub redacts: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) unsigned: Option>, - pub(crate) hashes: EventHash, + pub unsigned: Option>, + pub hashes: EventHash, #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) signatures: Option>, /* BTreeMap, BTreeMap> */ + pub signatures: Option>, /* BTreeMap, BTreeMap> */ } impl PduEvent { #[tracing::instrument(skip(self))] - pub(crate) fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> { + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) @@ -76,7 +76,7 @@ impl PduEvent { Ok(()) } - pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> { + pub fn remove_transaction_id(&mut self) -> crate::Result<()> { if let Some(unsigned) = &self.unsigned { let mut unsigned: BTreeMap> = serde_json::from_str(unsigned.get()) .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; @@ -87,7 +87,7 @@ impl PduEvent { Ok(()) } - pub(crate) fn add_age(&mut self) -> crate::Result<()> { + pub fn add_age(&mut self) -> crate::Result<()> { let mut unsigned: BTreeMap> = self .unsigned .as_ref() @@ -118,7 +118,7 @@ impl PduEvent { /// > serving /// > such events over the Client-Server API. #[must_use] - pub(crate) fn copy_redacts(&self) -> (Option>, Box) { + pub fn copy_redacts(&self) -> (Option>, Box) { if self.kind == TimelineEventType::RoomRedaction { if let Ok(mut content) = serde_json::from_str::(self.content.get()) { if let Some(redacts) = content.redacts { @@ -137,7 +137,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_sync_room_event(&self) -> Raw { + pub fn to_sync_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ "content": content, @@ -162,7 +162,7 @@ impl PduEvent { /// This only works for events that are also AnyRoomEvents. #[tracing::instrument(skip(self))] - pub(crate) fn to_any_event(&self) -> Raw { + pub fn to_any_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ "content": content, @@ -187,7 +187,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_room_event(&self) -> Raw { + pub fn to_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ "content": content, @@ -212,7 +212,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_message_like_event(&self) -> Raw { + pub fn to_message_like_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ "content": content, @@ -237,7 +237,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_state_event(&self) -> Raw { + pub fn to_state_event(&self) -> Raw { let mut json = json!({ "content": self.content, "type": self.kind, @@ -256,7 +256,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_sync_state_event(&self) -> Raw { + pub fn to_sync_state_event(&self) -> Raw { let mut json = json!({ "content": self.content, "type": self.kind, @@ -274,7 +274,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_stripped_state_event(&self) -> Raw { + pub fn to_stripped_state_event(&self) -> Raw { let json = json!({ "content": self.content, "type": self.kind, @@ -286,7 +286,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_stripped_spacechild_state_event(&self) -> Raw { + pub fn to_stripped_spacechild_state_event(&self) -> Raw { let json = json!({ "content": self.content, "type": self.kind, @@ -299,7 +299,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub(crate) fn to_member_event(&self) -> Raw> { + pub fn to_member_event(&self) -> Raw> { let mut json = json!({ "content": self.content, "type": self.kind, @@ -320,7 +320,7 @@ impl PduEvent { /// This does not return a full `Pdu` it is only to satisfy ruma's types. #[tracing::instrument] - pub(crate) fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { + pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { if let Some(unsigned) = pdu_json .get_mut("unsigned") .and_then(|val| val.as_object_mut()) @@ -357,7 +357,7 @@ impl PduEvent { to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") } - pub(crate) fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { + pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) @@ -405,7 +405,7 @@ impl Ord for PduEvent { /// /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. -pub(crate) fn gen_event_id_canonical_json( +pub fn gen_event_id_canonical_json( pdu: &RawJsonValue, room_version_id: &RoomVersionId, ) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { @@ -426,11 +426,11 @@ pub(crate) fn gen_event_id_canonical_json( /// Build the start of a PDU in order to add it to the Database. #[derive(Debug, Deserialize)] -pub(crate) struct PduBuilder { +pub struct PduBuilder { #[serde(rename = "type")] - pub(crate) event_type: TimelineEventType, - pub(crate) content: Box, - pub(crate) unsigned: Option>, - pub(crate) state_key: Option, - pub(crate) redacts: Option>, + pub event_type: TimelineEventType, + pub content: Box, + pub unsigned: Option>, + pub state_key: Option, + pub redacts: Option>, } diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 6cd1e822..6f0f58f8 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -2,7 +2,7 @@ use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Returns the latest presence event for the given user. fn get_presence(&self, user_id: &UserId) -> Result>; diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index d2d1bb39..58be2ddf 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{sync::Arc, time::Duration}; -pub(crate) use data::Data; +pub use data::Data; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -10,19 +10,19 @@ use ruma::{ OwnedUserId, UInt, UserId, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::Mutex, time::sleep}; +use tokio::{sync::Mutex, task::JoinHandle, time::sleep}; use tracing::{debug, error}; use crate::{ - services, - utils::{self, user_id::user_is_local}, + services, user_is_local, + utils::{self}, Config, Error, Result, }; /// Represents data required to be kept in order to implement the presence /// specification. #[derive(Serialize, Deserialize, Debug, Clone)] -pub(crate) struct Presence { +pub struct Presence { state: PresenceState, currently_active: bool, last_active_ts: u64, @@ -30,9 +30,8 @@ pub(crate) struct Presence { } impl Presence { - pub(crate) fn new( - state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option, - ) -> Self { + #[must_use] + pub fn new(state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option) -> Self { Self { state, currently_active, @@ -41,21 +40,21 @@ impl Presence { } } - pub(crate) fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result { + pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result { let presence = Self::from_json_bytes(bytes)?; presence.to_presence_event(user_id) } - pub(crate) fn from_json_bytes(bytes: &[u8]) -> Result { + pub fn from_json_bytes(bytes: &[u8]) -> Result { serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) } - pub(crate) fn to_json_bytes(&self) -> Result> { + pub fn to_json_bytes(&self) -> Result> { serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON")) } /// Creates a PresenceEvent from available data. - pub(crate) fn to_presence_event(&self, user_id: &UserId) -> Result { + pub fn to_presence_event(&self, user_id: &UserId) -> Result { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -77,37 +76,55 @@ impl Presence { } } -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, - pub(crate) timer_sender: loole::Sender<(OwnedUserId, Duration)>, +pub struct Service { + pub db: Arc, + pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex>, + handler_join: Mutex>>, timeout_remote_users: bool, } impl Service { - pub(crate) fn build(db: &'static dyn Data, config: &Config) -> Arc { + pub fn build(db: Arc, config: &Config) -> Arc { let (timer_sender, timer_receiver) = loole::unbounded(); - Arc::new(Self { db, timer_sender, timer_receiver: Mutex::new(timer_receiver), + handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, }) } - pub(crate) fn start_handler(self: &Arc) { + pub async fn start_handler(self: &Arc) { let self_ = Arc::clone(self); - tokio::spawn(async move { + let handle = services().server.runtime().spawn(async move { self_ .handler() .await .expect("Failed to start presence handler"); }); + + _ = self.handler_join.lock().await.insert(handle); + } + + pub async fn close(&self) { + self.interrupt(); + if let Some(handler_join) = self.handler_join.lock().await.take() { + if let Err(e) = handler_join.await { + error!("Failed to shutdown: {e:?}"); + } + } + } + + pub fn interrupt(&self) { + if !self.timer_sender.is_closed() { + self.timer_sender.close(); + } } /// Returns the latest presence event for the given user. - pub(crate) fn get_presence(&self, user_id: &UserId) -> Result> { + pub fn get_presence(&self, user_id: &UserId) -> Result> { if let Some((_, presence)) = self.db.get_presence(user_id)? { Ok(Some(presence)) } else { @@ -117,7 +134,7 @@ impl Service { /// Pings the presence of the given user in the given room, setting the /// specified state. - pub(crate) fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { + pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000; let last_presence = self.db.get_presence(user_id)?; @@ -146,7 +163,7 @@ impl Service { } /// Adds a presence event which will be saved until a new event replaces it. - pub(crate) fn set_presence( + pub fn set_presence( &self, user_id: &UserId, state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { @@ -179,11 +196,11 @@ impl Service { /// /// TODO: Why is this not used? #[allow(dead_code)] - pub(crate) fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } + pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } /// Returns the most recent presence updates that happened after the event /// with id `since`. - pub(crate) fn presence_since(&self, since: u64) -> Box)>> { + pub fn presence_since(&self, since: u64) -> Box)> + '_> { self.db.presence_since(since) } @@ -191,24 +208,16 @@ impl Service { let mut presence_timers = FuturesUnordered::new(); let receiver = self.timer_receiver.lock().await; loop { + debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - event = receiver.recv_async() => { - - match event { - Ok((user_id, timeout)) => { - debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); - presence_timers.push(presence_timer(user_id, timeout)); - } - Err(e) => { - // generally shouldn't happen - error!("Failed to receive presence timer through channel: {e}"); - } - } - } - - Some(user_id) = presence_timers.next() => { - process_presence_timer(&user_id)?; - } + Some(user_id) = presence_timers.next() => process_presence_timer(&user_id)?, + event = receiver.recv_async() => match event { + Err(_e) => return Ok(()), + Ok((user_id, timeout)) => { + debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); + presence_timers.push(presence_timer(user_id, timeout)); + }, + }, } } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 0a5e2eb9..b58cd3fc 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -5,7 +5,7 @@ use ruma::{ use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 3f7c5d80..19a570c4 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,8 +1,8 @@ mod data; -use std::{fmt::Debug, mem}; +use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -pub(crate) use data::Data; +pub use data::Data; use ipaddress::IPAddress; use ruma::{ api::{ @@ -24,27 +24,28 @@ use tracing::{info, trace, warn}; use crate::{debug_info, services, Error, PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { self.db.set_pusher(sender, pusher) } - pub(crate) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { self.db.get_pusher(sender, pushkey) } - pub(crate) fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } + pub fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } - pub(crate) fn get_pushkeys(&self, sender: &UserId) -> Box>> { + #[must_use] + pub fn get_pushkeys(&self, sender: &UserId) -> Box> + '_> { self.db.get_pushkeys(sender) } #[tracing::instrument(skip(self, dest, request))] - pub(crate) async fn send_request(&self, dest: &str, request: T) -> Result + pub async fn send_request(&self, dest: &str, request: T) -> Result where T: OutgoingRequest + Debug, { @@ -131,7 +132,7 @@ impl Service { } #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] - pub(crate) async fn send_push_notice( + pub async fn send_push_notice( &self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent, ) -> Result<()> { let mut notify = None; @@ -176,7 +177,7 @@ impl Service { } #[tracing::instrument(skip(self, user, ruleset, pdu))] - pub(crate) fn get_actions<'a>( + pub fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, ) -> Result<&'a [Action]> { diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index acb2ecd6..095d6e66 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -2,7 +2,7 @@ use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Creates or updates the alias to the given room id. fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 0c95a8a9..6e8b386a 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,37 +1,37 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { - self.db.set_alias(alias, room_id) - } + pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) } + pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) } #[tracing::instrument(skip(self))] - pub(crate) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { self.db.resolve_local_alias(alias) } #[tracing::instrument(skip(self))] - pub(crate) fn local_aliases_for_room<'a>( + pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { self.db.local_aliases_for_room(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { + pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { self.db.all_local_aliases() } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index baa63ebd..f77d2d90 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>; fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc<[u64]>) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 4a59d989..4c9152b0 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -4,18 +4,18 @@ use std::{ sync::Arc, }; -pub(crate) use data::Data; +pub use data::Data; use ruma::{api::client::error::ErrorKind, EventId, RoomId}; use tracing::{debug, error, trace, warn}; use crate::{services, utils::debug_slice_truncated, Error, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) async fn event_ids_iter<'a>( + pub async fn event_ids_iter<'a>( &self, room_id: &RoomId, starting_events_: Vec>, ) -> Result> + 'a> { let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); @@ -31,7 +31,7 @@ impl Service { } #[tracing::instrument(skip(self), fields(starting_events = debug_slice_truncated(starting_events, 5)))] - pub(crate) async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { + pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); @@ -172,18 +172,18 @@ impl Service { Ok(found) } - pub(crate) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { self.db.get_cached_eventid_authchain(key) } #[tracing::instrument(skip(self))] - pub(crate) fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) } #[tracing::instrument(skip(self))] - pub(crate) fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { + pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) } diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 6efb1af4..691b8604 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -2,7 +2,7 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Adds the room to the public room directory fn set_public(&self, room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 63a0abbf..ab69d003 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,24 +1,26 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{OwnedRoomId, RoomId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } + pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } + pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } + pub fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn public_rooms(&self) -> impl Iterator> + '_ { self.db.public_rooms() } + pub fn public_rooms(&self) -> impl Iterator> + '_ { self.db.public_rooms() } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 1a965c0e..499d1d63 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,11 +1,17 @@ +mod parse_incoming_pdu; +mod signing_keys; +pub struct Service; + use std::{ cmp, - collections::{hash_map, HashSet}, + collections::{hash_map, BTreeMap, HashMap, HashSet}, pin::Pin, + sync::Arc, time::{Duration, Instant}, }; use futures_util::Future; +pub use parse_incoming_pdu::parse_incoming_pdu; use ruma::{ api::{ client::error::ErrorKind, @@ -24,14 +30,7 @@ use tokio::sync::RwLock; use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; -use crate::{ - debug_error, debug_info, - service::{pdu, Arc, BTreeMap, HashMap, Result}, - services, Error, PduEvent, -}; - -mod signing_keys; -pub(crate) struct Service; +use crate::{debug_error, debug_info, pdu, services, Error, PduEvent, Result}; // We use some AsyncRecursiveType hacks here so we can call async funtion // recursively. @@ -70,7 +69,7 @@ impl Service { /// 14. Check if the event passes auth based on the "current state" of the /// room, if not soft fail it #[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")] - pub(crate) async fn handle_incoming_pdu<'a>( + pub async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, value: BTreeMap, is_timeline_event: bool, pub_key_map: &'a RwLock>>, @@ -207,7 +206,7 @@ impl Service { skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room), name = "prev" )] - pub(crate) async fn handle_prev_pdu<'a>( + pub async fn handle_prev_pdu<'a>( &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, pub_key_map: &'a RwLock>>, eventid_info: &mut HashMap, (Arc, BTreeMap)>, @@ -427,7 +426,7 @@ impl Service { }) } - pub(crate) async fn upgrade_outlier_to_timeline_pdu( + pub async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>> { @@ -748,7 +747,7 @@ impl Service { // TODO: if we know the prev_events of the incoming event we can avoid the // request and build the state from a known point and resolve if > 1 prev_event #[tracing::instrument(skip_all, name = "state")] - pub(crate) async fn state_at_incoming_degree_one( + pub async fn state_at_incoming_degree_one( &self, incoming_pdu: &Arc, ) -> Result>>> { let prev_event = &*incoming_pdu.prev_events[0]; @@ -796,7 +795,7 @@ impl Service { } #[tracing::instrument(skip_all, name = "state")] - pub(crate) async fn state_at_incoming_resolved( + pub async fn state_at_incoming_resolved( &self, incoming_pdu: &Arc, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result>>> { debug!("Calculating state at event using state res"); @@ -988,7 +987,7 @@ impl Service { /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? - pub(crate) fn fetch_and_handle_outliers<'a>( + pub fn fetch_and_handle_outliers<'a>( &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveCanonicalJsonVec<'a> { @@ -1275,7 +1274,7 @@ impl Service { /// Returns Ok if the acl allows the server #[tracing::instrument(skip_all)] - pub(crate) fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { let acl_event = if let Some(acl) = services() .rooms diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs new file mode 100644 index 00000000..133ab66e --- /dev/null +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -0,0 +1,31 @@ +use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; +use serde_json::value::RawValue as RawJsonValue; +use tracing::warn; + +use crate::{service::pdu::gen_event_id_canonical_json, services, Error, Result}; + +pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; + + let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { + return Err(Error::Err(format!("Server is not in room {room_id}"))); + }; + + let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }; + + Ok((event_id, value, room_id)) +} diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index 81986a12..98751034 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashSet, + collections::{BTreeMap, HashMap, HashSet}, time::{Duration, SystemTime}, }; @@ -21,13 +21,10 @@ use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard}; use tracing::{debug, error, info, trace, warn}; -use crate::{ - service::{BTreeMap, HashMap, Result}, - services, Error, -}; +use crate::{services, Error, Result}; impl super::Service { - pub(crate) async fn fetch_required_signing_keys<'a, E>( + pub async fn fetch_required_signing_keys<'a, E>( &'a self, events: E, pub_key_map: &RwLock>>, ) -> Result<()> where @@ -265,7 +262,7 @@ impl super::Service { Ok(()) } - pub(crate) async fn fetch_join_signing_keys( + pub async fn fetch_join_signing_keys( &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, pub_key_map: &RwLock>>, ) -> Result<()> { @@ -342,7 +339,7 @@ impl super::Service { /// Search the DB for the signing keys of the given server, if we don't have /// them fetch them from the server and save to our DB. #[tracing::instrument(skip_all)] - pub(crate) async fn fetch_signing_keys_for_server( + pub async fn fetch_signing_keys_for_server( &self, origin: &ServerName, signature_ids: Vec, ) -> Result> { let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 36db0ac8..890a2f98 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -2,7 +2,7 @@ use ruma::{DeviceId, RoomId, UserId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result; diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 104da8fc..565a186d 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,24 +1,25 @@ mod data; -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; -pub(crate) use data::Data; +pub use data::Data; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; -use super::timeline::PduCount; -use crate::Result; +use crate::{PduCount, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, #[allow(clippy::type_complexity)] - pub(crate) lazy_load_waiting: - Mutex>>, + pub lazy_load_waiting: Mutex>>, } impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn lazy_load_was_sent_before( + pub fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result { self.db @@ -26,7 +27,7 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) async fn lazy_load_mark_sent( + pub async fn lazy_load_mark_sent( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, count: PduCount, ) { @@ -37,7 +38,7 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) async fn lazy_load_confirm_delivery( + pub async fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, ) -> Result<()> { if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&( @@ -56,7 +57,7 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { + pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { self.db.lazy_load_reset(user_id, device_id, room_id) } } diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 58d570d0..d702b203 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -2,7 +2,7 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; fn iter_ids<'a>(&'a self) -> Box> + 'a>; fn is_disabled(&self, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index dd70745a..e14d539d 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,32 +1,36 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{OwnedRoomId, RoomId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Checks if a room exists. #[tracing::instrument(skip(self))] - pub(crate) fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } + pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } - pub(crate) fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } + #[must_use] + pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } - pub(crate) fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } + pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } - pub(crate) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { self.db.disable_room(room_id, disabled) } - pub(crate) fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } + pub fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } - pub(crate) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } + pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } - pub(crate) fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { + #[must_use] + pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { self.db.list_banned_rooms() } } diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 1b78b1e2..baf3f7b5 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -1,25 +1,25 @@ -pub(crate) mod alias; -pub(crate) mod auth_chain; -pub(crate) mod directory; -pub(crate) mod event_handler; -pub(crate) mod lazy_loading; -pub(crate) mod metadata; -pub(crate) mod outlier; -pub(crate) mod pdu_metadata; -pub(crate) mod read_receipt; -pub(crate) mod search; -pub(crate) mod short; -pub(crate) mod spaces; -pub(crate) mod state; -pub(crate) mod state_accessor; -pub(crate) mod state_cache; -pub(crate) mod state_compressor; -pub(crate) mod threads; -pub(crate) mod timeline; -pub(crate) mod typing; -pub(crate) mod user; +pub mod alias; +pub mod auth_chain; +pub mod directory; +pub mod event_handler; +pub mod lazy_loading; +pub mod metadata; +pub mod outlier; +pub mod pdu_metadata; +pub mod read_receipt; +pub mod search; +pub mod short; +pub mod spaces; +pub mod state; +pub mod state_accessor; +pub mod state_cache; +pub mod state_compressor; +pub mod threads; +pub mod timeline; +pub mod typing; +pub mod user; -pub(crate) trait Data: +pub trait Data: alias::Data + auth_chain::Data + directory::Data @@ -40,25 +40,25 @@ pub(crate) trait Data: { } -pub(crate) struct Service { - pub(crate) alias: alias::Service, - pub(crate) auth_chain: auth_chain::Service, - pub(crate) directory: directory::Service, - pub(crate) event_handler: event_handler::Service, - pub(crate) lazy_loading: lazy_loading::Service, - pub(crate) metadata: metadata::Service, - pub(crate) outlier: outlier::Service, - pub(crate) pdu_metadata: pdu_metadata::Service, - pub(crate) read_receipt: read_receipt::Service, - pub(crate) search: search::Service, - pub(crate) short: short::Service, - pub(crate) state: state::Service, - pub(crate) state_accessor: state_accessor::Service, - pub(crate) state_cache: state_cache::Service, - pub(crate) state_compressor: state_compressor::Service, - pub(crate) timeline: timeline::Service, - pub(crate) threads: threads::Service, - pub(crate) typing: typing::Service, - pub(crate) spaces: spaces::Service, - pub(crate) user: user::Service, +pub struct Service { + pub alias: alias::Service, + pub auth_chain: auth_chain::Service, + pub directory: directory::Service, + pub event_handler: event_handler::Service, + pub lazy_loading: lazy_loading::Service, + pub metadata: metadata::Service, + pub outlier: outlier::Service, + pub pdu_metadata: pdu_metadata::Service, + pub read_receipt: read_receipt::Service, + pub search: search::Service, + pub short: short::Service, + pub state: state::Service, + pub state_accessor: state_accessor::Service, + pub state_cache: state_cache::Service, + pub state_compressor: state_compressor::Service, + pub timeline: timeline::Service, + pub threads: threads::Service, + pub typing: typing::Service, + pub spaces: spaces::Service, + pub user: user::Service, } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index a5202a0f..18eb3190 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -2,7 +2,7 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 0de6d366..9ec4010c 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,17 +1,19 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Returns the pdu from the outlier tree. - pub(crate) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu_json(event_id) } @@ -19,13 +21,11 @@ impl Service { /// /// TODO: use this? #[allow(dead_code)] - pub(crate) fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { - self.db.get_outlier_pdu(event_id) - } + pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu(event_id) } /// Append the PDU as an outlier. #[tracing::instrument(skip(self, pdu))] - pub(crate) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.db.add_pdu_outlier(event_id, pdu) } } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index d4026f55..ccc14edd 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use ruma::{EventId, RoomId, UserId}; -use crate::{service::rooms::timeline::PduCount, PduEvent, Result}; +use crate::{PduCount, PduEvent, Result}; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn add_relation(&self, from: u64, to: u64) -> Result<()>; #[allow(clippy::type_complexity)] fn relations_until<'a>( diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 0908c573..7e0da835 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,7 +1,8 @@ mod data; + use std::sync::Arc; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -9,11 +10,10 @@ use ruma::{ }; use serde::Deserialize; -use super::timeline::PduCount; -use crate::{services, PduEvent, Result}; +use crate::{services, PduCount, PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } #[derive(Clone, Debug, Deserialize)] @@ -28,7 +28,7 @@ struct ExtractRelatesToEventId { impl Service { #[tracing::instrument(skip(self, from, to))] - pub(crate) fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), _ => { @@ -40,7 +40,7 @@ impl Service { } #[allow(clippy::too_many_arguments)] - pub(crate) fn paginate_relations_with_filter( + pub fn paginate_relations_with_filter( &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option, filter_rel_type: &Option, from: &Option, to: &Option, limit: &Option, recurse: bool, dir: Direction, @@ -174,7 +174,7 @@ impl Service { } } - pub(crate) fn relations_until<'a>( + pub fn relations_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, ) -> Result> { let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; @@ -216,22 +216,18 @@ impl Service { } #[tracing::instrument(skip(self, room_id, event_ids))] - pub(crate) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { self.db.mark_as_referenced(room_id, event_ids) } #[tracing::instrument(skip(self))] - pub(crate) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.db.is_event_referenced(room_id, event_id) } #[tracing::instrument(skip(self))] - pub(crate) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.db.mark_event_soft_failed(event_id) - } + pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.db.is_event_soft_failed(event_id) - } + pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 006b5df8..4fe7be59 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -9,7 +9,7 @@ use crate::Result; type AnySyncEphemeralRoomEventIter<'a> = Box)>> + 'a>; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Replaces the previous read receipt. fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>; diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index 006204ab..a5b9c325 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -1,17 +1,19 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; use crate::{services, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Replaces the previous read receipt. - pub(crate) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { + pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { self.db.readreceipt_update(user_id, room_id, event)?; services().sending.flush_room(room_id)?; @@ -21,7 +23,7 @@ impl Service { /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. #[tracing::instrument(skip(self))] - pub(crate) fn readreceipts_since<'a>( + pub fn readreceipts_since<'a>( &'a self, room_id: &RoomId, since: u64, ) -> impl Iterator)>> + 'a { self.db.readreceipts_since(room_id, since) @@ -29,18 +31,18 @@ impl Service { /// Sets a private read marker at `count`. #[tracing::instrument(skip(self))] - pub(crate) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { self.db.private_read_set(room_id, user_id, count) } /// Returns the private read marker. #[tracing::instrument(skip(self))] - pub(crate) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.private_read_get(room_id, user_id) } /// Returns the count of the last typing update in this room. - pub(crate) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.last_privateread_update(user_id, room_id) } } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 48edc42b..96439adf 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -4,7 +4,7 @@ use crate::Result; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 50843ae4..569761a3 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,22 +1,24 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::RoomId; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { self.db.index_pdu(shortroomid, pdu_id, message_body) } #[tracing::instrument(skip(self))] - pub(crate) fn search_pdus<'a>( + pub fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, ) -> Result> + 'a, Vec)>> { self.db.search_pdus(room_id, search_string) diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 4bbed670..d0e2085f 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -4,7 +4,7 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; fn multi_get_or_create_shorteventid(&self, event_id: &[&EventId]) -> Result>; diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index ed2eae97..657de66a 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,48 +1,48 @@ mod data; use std::sync::Arc; -pub(crate) use data::Data; +pub use data::Data; use ruma::{events::StateEventType, EventId, RoomId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { self.db.get_or_create_shorteventid(event_id) } - pub(crate) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { + pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { self.db.multi_get_or_create_shorteventid(event_ids) } - pub(crate) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { + pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { self.db.get_shortstatekey(event_type, state_key) } - pub(crate) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { self.db.get_or_create_shortstatekey(event_type, state_key) } - pub(crate) fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { self.db.get_eventid_from_short(shorteventid) } - pub(crate) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { self.db.get_statekey_from_short(shortstatekey) } /// Returns (shortstatehash, already_existed) - pub(crate) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { self.db.get_or_create_shortstatehash(state_hash) } - pub(crate) fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } - pub(crate) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { self.db.get_or_create_shortroomid(room_id) } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index d2624197..219b8c39 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -32,9 +32,9 @@ use ruma::{ use tokio::sync::Mutex; use tracing::{debug, error, warn}; -use crate::{debug_info, services, utils::server_name::server_is_ours, Error, Result}; +use crate::{debug_info, server_is_ours, services, Error, Result}; -pub(crate) struct CachedSpaceHierarchySummary { +pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, } @@ -235,11 +235,11 @@ impl Arena { // Note: perhaps use some better form of token rather than just room count #[derive(Debug, PartialEq)] -pub(crate) struct PagnationToken { - pub(crate) skip: UInt, - pub(crate) limit: UInt, - pub(crate) max_depth: UInt, - pub(crate) suggested_only: bool, +pub struct PagnationToken { + pub skip: UInt, + pub limit: UInt, + pub max_depth: UInt, + pub suggested_only: bool, } impl FromStr for PagnationToken { @@ -294,8 +294,8 @@ enum Identifier<'a> { None, } -pub(crate) struct Service { - pub(crate) roomid_spacehierarchy_cache: Mutex>>, +pub struct Service { + pub roomid_spacehierarchy_cache: Mutex>>, } // Here because cannot implement `From` across ruma-federation-api and @@ -338,7 +338,7 @@ impl Service { /// ///Panics if the room does not exist, so a check if the room exists should /// be done - pub(crate) async fn get_federation_hierarchy( + pub async fn get_federation_hierarchy( &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, ) -> Result { match self @@ -624,7 +624,7 @@ impl Service { } // TODO: make this a lot less messy - pub(crate) async fn get_client_hierarchy( + pub async fn get_client_hierarchy( &self, sender_user: &UserId, room_id: &RoomId, limit: usize, skip: usize, max_depth: usize, suggested_only: bool, ) -> Result { diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 88d9fe92..f486f1f8 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -5,7 +5,7 @@ use tokio::sync::MutexGuard; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index bbc561f2..8031d566 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::client::error::ErrorKind, events::{ @@ -21,13 +21,13 @@ use tracing::warn; use super::state_compressor::CompressedStateEvent; use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Set the room to the given statehash and update caches. - pub(crate) async fn force_state( + pub async fn force_state( &self, room_id: &RoomId, shortstatehash: u64, @@ -104,7 +104,7 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed))] - pub(crate) fn set_event_state( + pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { let shorteventid = services() @@ -172,7 +172,7 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu))] - pub(crate) fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = services() .rooms .short @@ -244,7 +244,7 @@ impl Service { } #[tracing::instrument(skip(self, invite_event))] - pub(crate) fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { + pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events if let Some(e) = @@ -300,7 +300,7 @@ impl Service { /// Set the state hash to a new version, but does not update state_cache. #[tracing::instrument(skip(self))] - pub(crate) fn set_room_state( + pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, @@ -311,7 +311,7 @@ impl Service { /// Returns the room's version. #[tracing::instrument(skip(self))] - pub(crate) fn get_room_version(&self, room_id: &RoomId) -> Result { + pub fn get_room_version(&self, room_id: &RoomId) -> Result { let create_event = services() .rooms .state_accessor @@ -331,15 +331,15 @@ impl Service { Ok(create_event_content.room_version) } - pub(crate) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.db.get_room_shortstatehash(room_id) } - pub(crate) fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { self.db.get_forward_extremities(room_id) } - pub(crate) fn set_forward_extremities( + pub fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, @@ -351,7 +351,7 @@ impl Service { /// This fetches auth events from the current state. #[tracing::instrument(skip(self))] - pub(crate) fn get_auth_events( + pub fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 9bc2667f..5fd58864 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -6,7 +6,7 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::{PduEvent, Result}; #[async_trait] -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[allow(unused_qualifications)] // async traits diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 43cbb3bd..d2e51361 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, Mutex}, }; -pub(crate) use data::Data; +pub use data::Data; use lru_cache::LruCache; use ruma::{ events::{ @@ -24,30 +24,28 @@ use tracing::{error, warn}; use crate::{service::pdu::PduBuilder, services, Error, PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, - pub(crate) server_visibility_cache: Mutex>, - pub(crate) user_visibility_cache: Mutex>, +pub struct Service { + pub db: Arc, + pub server_visibility_cache: Mutex>, + pub user_visibility_cache: Mutex>, } impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] - pub(crate) async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { self.db.state_full_ids(shortstatehash).await } - pub(crate) async fn state_full( - &self, shortstatehash: u64, - ) -> Result>> { + pub async fn state_full(&self, shortstatehash: u64) -> Result>> { self.db.state_full(shortstatehash).await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self))] - pub(crate) fn state_get_id( + pub fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.db.state_get_id(shortstatehash, event_type, state_key) @@ -55,7 +53,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - pub(crate) fn state_get( + pub fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.db.state_get(shortstatehash, event_type, state_key) @@ -90,9 +88,7 @@ impl Service { /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub(crate) fn server_can_see_event( - &self, origin: &ServerName, room_id: &RoomId, event_id: &EventId, - ) -> Result { + pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result { let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { return Ok(true); }; @@ -155,7 +151,7 @@ impl Service { /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id, event_id))] - pub(crate) fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { + pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { return Ok(true); }; @@ -214,7 +210,7 @@ impl Service { /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] - pub(crate) fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; let history_visibility = self @@ -236,22 +232,18 @@ impl Service { } /// Returns the state hash for this pdu. - pub(crate) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.db.pdu_shortstatehash(event_id) - } + pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } /// Returns the full room state. #[tracing::instrument(skip(self))] - pub(crate) async fn room_state_full( - &self, room_id: &RoomId, - ) -> Result>> { + pub async fn room_state_full(&self, room_id: &RoomId) -> Result>> { self.db.room_state_full(room_id).await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self))] - pub(crate) fn room_state_get_id( + pub fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.db.room_state_get_id(room_id, event_type, state_key) @@ -260,13 +252,13 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self))] - pub(crate) fn room_state_get( + pub fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.db.room_state_get(room_id, event_type, state_key) } - pub(crate) fn get_name(&self, room_id: &RoomId) -> Result> { + pub fn get_name(&self, room_id: &RoomId) -> Result> { services() .rooms .state_accessor @@ -276,7 +268,7 @@ impl Service { }) } - pub(crate) fn get_avatar(&self, room_id: &RoomId) -> Result> { + pub fn get_avatar(&self, room_id: &RoomId) -> Result> { services() .rooms .state_accessor @@ -287,7 +279,7 @@ impl Service { }) } - pub(crate) fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { services() .rooms .state_accessor @@ -298,7 +290,7 @@ impl Service { }) } - pub(crate) async fn user_can_invite( + pub async fn user_can_invite( &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &MutexGuard<'_, ()>, ) -> Result { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 828d4284..70fcd6d1 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -12,7 +12,7 @@ type StrippedStateEventIter<'a> = Box = Box>)>> + 'a>; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_invited( diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 9f2ccdf4..976d858b 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use itertools::Itertools; use ruma::{ events::{ @@ -19,19 +19,19 @@ use ruma::{ }; use tracing::{error, warn}; -use crate::{service::appservice::RegistrationInfo, services, utils::user_id::user_is_local, Error, Result}; +use crate::{service::appservice::RegistrationInfo, services, user_is_local, Error, Result}; mod data; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] - pub(crate) fn update_membership( + pub fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, last_state: Option>>, invite_via: Option>, update_joined_count: bool, @@ -210,43 +210,43 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - pub(crate) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } #[tracing::instrument(skip(self, room_id))] - pub(crate) fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { + pub fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { self.db.get_our_real_users(room_id) } #[tracing::instrument(skip(self, room_id, appservice))] - pub(crate) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { + pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { self.db.appservice_in_room(room_id, appservice) } /// Makes a user forget a room. #[tracing::instrument(skip(self))] - pub(crate) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } /// Returns an iterator of all servers participating in this room. #[tracing::instrument(skip(self))] - pub(crate) fn room_servers<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + pub fn room_servers<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { self.db.room_servers(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { self.db.server_in_room(server, room_id) } /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self))] - pub(crate) fn server_rooms<'a>(&'a self, server: &ServerName) -> impl Iterator> + 'a { + pub fn server_rooms<'a>(&'a self, server: &ServerName) -> impl Iterator> + 'a { self.db.server_rooms(server) } /// Returns true if server can see user by sharing at least one room. #[tracing::instrument(skip(self))] - pub(crate) fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { + pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { Ok(self .server_rooms(server) .filter_map(Result::ok) @@ -255,7 +255,7 @@ impl Service { /// Returns true if user_a and user_b share at least one room. #[tracing::instrument(skip(self))] - pub(crate) fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { + pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { // Minimize number of point-queries by iterating user with least nr rooms let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { (user_a, user_b) @@ -271,104 +271,88 @@ impl Service { /// Returns an iterator over all joined members of a room. #[tracing::instrument(skip(self))] - pub(crate) fn room_members<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + pub fn room_members<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { self.db.room_members(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.db.room_joined_count(room_id) - } + pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.db.room_invited_count(room_id) - } + pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self))] - pub(crate) fn room_useroncejoined<'a>( - &'a self, room_id: &RoomId, - ) -> impl Iterator> + 'a { + pub fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { self.db.room_useroncejoined(room_id) } /// Returns an iterator over all invited members of a room. #[tracing::instrument(skip(self))] - pub(crate) fn room_members_invited<'a>( - &'a self, room_id: &RoomId, - ) -> impl Iterator> + 'a { + pub fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { self.db.room_members_invited(room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_invite_count(room_id, user_id) } #[tracing::instrument(skip(self))] - pub(crate) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_left_count(room_id, user_id) } /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] - pub(crate) fn rooms_joined<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + pub fn rooms_joined<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { self.db.rooms_joined(user_id) } /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self))] - pub(crate) fn rooms_invited<'a>( + pub fn rooms_invited<'a>( &'a self, user_id: &UserId, ) -> impl Iterator>)>> + 'a { self.db.rooms_invited(user_id) } #[tracing::instrument(skip(self))] - pub(crate) fn invite_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { + pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { self.db.invite_state(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn left_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { + pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { self.db.left_state(user_id, room_id) } /// Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self))] - pub(crate) fn rooms_left<'a>( + pub fn rooms_left<'a>( &'a self, user_id: &UserId, ) -> impl Iterator>)>> + 'a { self.db.rooms_left(user_id) } #[tracing::instrument(skip(self))] - pub(crate) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.once_joined(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_joined(user_id, room_id) - } + pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_invited(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_left(user_id, room_id) - } + pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } #[tracing::instrument(skip(self))] - pub(crate) fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + pub fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { self.db.servers_invite_via(room_id) } @@ -377,7 +361,7 @@ impl Service { /// /// See #[tracing::instrument(skip(self))] - pub(crate) fn servers_route_via(&self, room_id: &RoomId) -> Result> { + pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { let most_powerful_user_server = services() .rooms .state_accessor diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 32f902a2..eddc8716 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -3,13 +3,13 @@ use std::{collections::HashSet, sync::Arc}; use super::CompressedStateEvent; use crate::Result; -pub(crate) struct StateDiff { - pub(crate) parent: Option, - pub(crate) added: Arc>, - pub(crate) removed: Arc>, +pub struct StateDiff { + pub parent: Option, + pub added: Arc>, + pub removed: Arc>, } -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn get_statediff(&self, shortstatehash: u64) -> Result; fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 8187d8cc..a3622b7b 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,11 +1,11 @@ -pub(crate) mod data; +pub mod data; use std::{ collections::HashSet, mem::size_of, sync::{Arc, Mutex}, }; -pub(crate) use data::Data; +pub use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; @@ -42,19 +42,19 @@ type ParentStatesVec = Vec<( type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, - pub(crate) stateinfo_cache: StateInfoLruCache, + pub stateinfo_cache: StateInfoLruCache, } -pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. #[tracing::instrument(skip(self))] - pub(crate) fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { if let Some(r) = self .stateinfo_cache .lock() @@ -97,7 +97,7 @@ impl Service { } } - pub(crate) fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { + pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( &services() @@ -110,9 +110,7 @@ impl Service { } /// Returns shortstatekey, event id - pub(crate) fn parse_compressed_state_event( - &self, compressed_event: &CompressedStateEvent, - ) -> Result<(u64, Arc)> { + pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { Ok(( utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), services().rooms.short.get_eventid_from_short( @@ -140,7 +138,7 @@ impl Service { /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states))] - pub(crate) fn save_state_from_diff( + pub fn save_state_from_diff( &self, shortstatehash: u64, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, @@ -252,7 +250,7 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous /// room state - pub(crate) fn save_state( + pub fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, ) -> HashSetCompressStateEvent { let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 41a4bdc1..b18f4b79 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -4,7 +4,7 @@ use crate::{PduEvent, Result}; type PduEventIterResult<'a> = Result> + 'a>>; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, ) -> PduEventIterResult<'a>; diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 578b21d3..05833a91 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -1,8 +1,8 @@ mod data; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, events::relation::BundledThread, @@ -12,18 +12,18 @@ use serde_json::json; use crate::{services, Error, PduEvent, Result}; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn threads_until<'a>( + pub fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, ) -> Result> + 'a> { self.db.threads_until(user_id, room_id, until, include) } - pub(crate) fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { + pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = &services() .rooms .timeline diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 48b3088e..a036b455 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -5,7 +5,7 @@ use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; use super::PduCount; use crate::{PduEvent, Result}; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; /// Returns the `count` of this pdu's id. diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index ea9f1613..82266437 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,12 +1,11 @@ -pub(crate) mod data; +pub mod data; use std::{ - cmp::Ordering, collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; -pub(crate) use data::Data; +pub use data::Data; use rand::prelude::SliceRandom; use ruma::{ api::{client::error::ErrorKind, federation}, @@ -35,60 +34,22 @@ use tracing::{debug, error, info, warn}; use super::state_compressor::CompressedStateEvent; use crate::{ - api::server_server, + server_is_ours, + //api::server_server, service::{ self, appservice::NamespaceRegex, pdu::{EventHash, PduBuilder}, + rooms::event_handler::parse_incoming_pdu, }, services, - utils::{self, server_name::server_is_ours}, - Error, PduEvent, Result, + utils::{self}, + Error, + PduCount, + PduEvent, + Result, }; -#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] -pub(crate) enum PduCount { - Backfilled(u64), - Normal(u64), -} - -impl PduCount { - pub(crate) fn min() -> Self { Self::Backfilled(u64::MAX) } - - pub(crate) fn max() -> Self { Self::Normal(u64::MAX) } - - pub(crate) fn try_from_string(token: &str) -> Result { - if let Some(stripped_token) = token.strip_prefix('-') { - stripped_token.parse().map(PduCount::Backfilled) - } else { - token.parse().map(PduCount::Normal) - } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) - } - - pub(crate) fn stringify(&self) -> String { - match self { - PduCount::Backfilled(x) => format!("-{x}"), - PduCount::Normal(x) => x.to_string(), - } - } -} - -impl PartialOrd for PduCount { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } -} - -impl Ord for PduCount { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), - (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), - (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, - (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, - } - } -} - // Update Relationships #[derive(Deserialize)] struct ExtractRelatesTo { @@ -106,15 +67,15 @@ struct ExtractRelatesToEventId { relates_to: ExtractEventId, } -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, - pub(crate) lasttimelinecount_cache: Mutex>, + pub lasttimelinecount_cache: Mutex>, } impl Service { #[tracing::instrument(skip(self))] - pub(crate) fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? .next() .map(|o| o.map(|(_, p)| Arc::new(p))) @@ -122,19 +83,17 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { self.db.last_timeline_count(sender_user, room_id) } /// Returns the `count` of this pdu's id. - pub(crate) fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_count(event_id) - } + pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.db.get_pdu_count(event_id) } // TODO Is this the same as the function above? /* #[tracing::instrument(skip(self))] - pub(crate) fn latest_pdu_count(&self, room_id: &RoomId) -> Result { + pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { let prefix = self .get_shortroomid(room_id)? .expect("room exists") @@ -158,7 +117,7 @@ impl Service { /// /// TODO: use this? #[allow(dead_code)] - pub(crate) fn get_room_version(&self, room_id: &RoomId) -> Result> { + pub fn get_room_version(&self, room_id: &RoomId) -> Result> { let create_event = services() .rooms .state_accessor @@ -178,17 +137,17 @@ impl Service { } /// Returns the json of a pdu. - pub(crate) fn get_pdu_json(&self, event_id: &EventId) -> Result> { + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_pdu_json(event_id) } /// Returns the json of a pdu. - pub(crate) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_non_outlier_pdu_json(event_id) } /// Returns the pdu's id. - pub(crate) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.db.get_pdu_id(event_id) } + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.db.get_pdu_id(event_id) } /// Returns the pdu. /// @@ -196,28 +155,28 @@ impl Service { /// /// TODO: use this? #[allow(dead_code)] - pub(crate) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.db.get_non_outlier_pdu(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(crate) fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(crate) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } /// Returns the pdu as a `BTreeMap`. - pub(crate) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_json_from_id(pdu_id) } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self))] - pub(crate) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { self.db.replace_pdu(pdu_id, pdu_json, pdu) } @@ -228,7 +187,7 @@ impl Service { /// /// Returns pdu id #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] - pub(crate) async fn append_pdu( + pub async fn append_pdu( &self, pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, @@ -513,7 +472,6 @@ impl Service { // This will evaluate to false if the emergency password is set up so that // the administrator can execute commands as conduit let from_conduit = pdu.sender == server_user && services().globals.emergency_password().is_none(); - if let Some(admin_room) = service::admin::Service::get_admin_room().await? { if to_conduit && !from_conduit && admin_room == pdu.room_id { services() @@ -628,7 +586,7 @@ impl Service { Ok(pdu_id) } - pub(crate) fn create_hash_and_sign_event( + pub fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -811,7 +769,7 @@ impl Service { /// takes a roomid_mutex_state, meaning that only this function is able to /// mutate the room state. #[tracing::instrument(skip(self, state_lock))] - pub(crate) async fn build_and_append_pdu( + pub async fn build_and_append_pdu( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -819,7 +777,6 @@ impl Service { state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = service::admin::Service::get_admin_room().await? { if admin_room == room_id { match pdu.event_type() { @@ -951,7 +908,7 @@ impl Service { /// Append the incoming event setting the state snapshot to the state from /// the server that sent the event. #[tracing::instrument(skip_all)] - pub(crate) async fn append_incoming_pdu( + pub async fn append_incoming_pdu( &self, pdu: &PduEvent, pdu_json: CanonicalJsonObject, @@ -990,7 +947,7 @@ impl Service { } /// Returns an iterator over all PDUs in a room. - pub(crate) fn all_pdus<'a>( + pub fn all_pdus<'a>( &'a self, user_id: &UserId, room_id: &RoomId, ) -> Result> + 'a> { self.pdus_after(user_id, room_id, PduCount::min()) @@ -1000,7 +957,7 @@ impl Service { /// happened before the event with id `until` in reverse-chronological /// order. #[tracing::instrument(skip(self))] - pub(crate) fn pdus_until<'a>( + pub fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, ) -> Result> + 'a> { self.db.pdus_until(user_id, room_id, until) @@ -1009,7 +966,7 @@ impl Service { /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. #[tracing::instrument(skip(self))] - pub(crate) fn pdus_after<'a>( + pub fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, ) -> Result> + 'a> { self.db.pdus_after(user_id, room_id, from) @@ -1017,7 +974,7 @@ impl Service { /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub(crate) fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { // TODO: Don't reserialize, keep original json if let Some(pdu_id) = self.get_pdu_id(event_id)? { let mut pdu = self @@ -1039,7 +996,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - pub(crate) async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { + pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { let first_pdu = self .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? .next() @@ -1154,11 +1111,11 @@ impl Service { } #[tracing::instrument(skip(self, pdu, pub_key_map))] - pub(crate) async fn backfill_pdu( + pub async fn backfill_pdu( &self, origin: &ServerName, pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?; // Lock so we cannot backfill the same pdu twice at the same time let mutex = Arc::clone( diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 61a226fc..dab9d8d6 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -8,23 +8,23 @@ use ruma::{ use tokio::sync::{broadcast, RwLock}; use crate::{ - debug_info, services, - utils::{self, user_id::user_is_local}, + debug_info, services, user_is_local, + utils::{self}, Result, }; -pub(crate) struct Service { - pub(crate) typing: RwLock>>, // u64 is unix timestamp of timeout - pub(crate) last_typing_update: RwLock>, /* timestamp of the last change to - * typing - * users */ - pub(crate) typing_update_sender: broadcast::Sender, +pub struct Service { + pub typing: RwLock>>, // u64 is unix timestamp of timeout + pub last_typing_update: RwLock>, /* timestamp of the last change to + * typing + * users */ + pub typing_update_sender: broadcast::Sender, } impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. - pub(crate) async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { + pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout); // update clients self.typing @@ -48,7 +48,7 @@ impl Service { } /// Removes a user from typing before the timeout is reached. - pub(crate) async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { debug_info!("typing stopped {:?} in {:?}", user_id, room_id); // update clients self.typing @@ -71,7 +71,7 @@ impl Service { Ok(()) } - pub(crate) async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { + pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { let mut receiver = self.typing_update_sender.subscribe(); while let Ok(next) = receiver.recv().await { if next == room_id { @@ -128,7 +128,7 @@ impl Service { } /// Returns the count of the last typing update in this room. - pub(crate) async fn last_typing_update(&self, room_id: &RoomId) -> Result { + pub async fn last_typing_update(&self, room_id: &RoomId) -> Result { self.typings_maintain(room_id).await?; Ok(self .last_typing_update @@ -140,7 +140,7 @@ impl Service { } /// Returns a new typing EDU. - pub(crate) async fn typings_all( + pub async fn typings_all( &self, room_id: &RoomId, ) -> Result> { Ok(SyncEphemeralRoomEvent { diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index d10eae21..2fd1c29e 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -2,7 +2,7 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 58d6ea55..5f4d4708 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,45 +1,43 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.reset_notification_counts(user_id, room_id) } - pub(crate) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.notification_count(user_id, room_id) } - pub(crate) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.highlight_count(user_id, room_id) } - pub(crate) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.last_notification_read(user_id, room_id) } - pub(crate) fn associate_token_shortstatehash( - &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { + pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { self.db .associate_token_shortstatehash(room_id, token, shortstatehash) } - pub(crate) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { self.db.get_token_shortstatehash(room_id, token) } - pub(crate) fn get_shared_rooms( - &self, users: Vec, - ) -> Result>> { + pub fn get_shared_rooms(&self, users: Vec) -> Result> + '_> { self.db.get_shared_rooms(users) } } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 04dfd5da..41479021 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -6,7 +6,7 @@ use crate::Result; type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn active_requests(&self) -> OutgoingSendingIter<'_>; fn active_requests_for(&self, destination: &Destination) -> SendingEventIter<'_>; fn delete_active_request(&self, key: Vec) -> Result<()>; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index ed01e797..b4a6fdeb 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,27 +1,28 @@ use std::{fmt::Debug, sync::Arc}; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use tokio::sync::Mutex; -use tracing::warn; +use tokio::{sync::Mutex, task::JoinHandle}; +use tracing::{error, warn}; -use crate::{services, utils::server_name::server_is_ours, Config, Error, Result}; +use crate::{server_is_ours, services, Config, Error, Result}; mod appservice; mod data; -pub(crate) mod send; -pub(crate) mod sender; -pub(crate) use send::FedDest; +pub mod send; +pub mod sender; +pub use send::FedDest; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, /// The state for a given state hash. sender: loole::Sender, receiver: Mutex>, + handler_join: Mutex>>, startup_netburst: bool, startup_netburst_keep: i64, } @@ -34,7 +35,7 @@ struct Msg { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub(crate) enum Destination { +pub enum Destination { Appservice(String), Push(OwnedUserId, String), // user and pushkey Normal(OwnedServerName), @@ -42,26 +43,42 @@ pub(crate) enum Destination { #[allow(clippy::module_name_repetitions)] #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub(crate) enum SendingEvent { +pub enum SendingEvent { Pdu(Vec), // pduid Edu(Vec), // pdu json Flush, // none } impl Service { - pub(crate) fn build(db: &'static dyn Data, config: &Config) -> Arc { + pub fn build(db: Arc, config: &Config) -> Arc { let (sender, receiver) = loole::unbounded(); Arc::new(Self { db, sender, receiver: Mutex::new(receiver), + handler_join: Mutex::new(None), startup_netburst: config.startup_netburst, startup_netburst_keep: config.startup_netburst_keep, }) } + pub async fn close(&self) { + self.interrupt(); + if let Some(handler_join) = self.handler_join.lock().await.take() { + if let Err(e) = handler_join.await { + error!("Failed to shutdown: {e:?}"); + } + } + } + + pub fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } + } + #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub(crate) fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); let _cork = services().globals.db.cork()?; @@ -74,7 +91,7 @@ impl Service { } #[tracing::instrument(skip(self))] - pub(crate) fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = services().globals.db.cork()?; @@ -87,7 +104,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id, pdu_id))] - pub(crate) fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = services() .rooms .state_cache @@ -99,9 +116,7 @@ impl Service { } #[tracing::instrument(skip(self, servers, pdu_id))] - pub(crate) fn send_pdu_servers>( - &self, servers: I, pdu_id: &[u8], - ) -> Result<()> { + pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { let requests = servers .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) @@ -125,7 +140,7 @@ impl Service { } #[tracing::instrument(skip(self, server, serialized))] - pub(crate) fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { + pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = services().globals.db.cork()?; @@ -138,7 +153,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id, serialized))] - pub(crate) fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { let servers = services() .rooms .state_cache @@ -150,9 +165,7 @@ impl Service { } #[tracing::instrument(skip(self, servers, serialized))] - pub(crate) fn send_edu_servers>( - &self, servers: I, serialized: Vec, - ) -> Result<()> { + pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { let requests = servers .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) @@ -177,7 +190,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - pub(crate) fn flush_room(&self, room_id: &RoomId) -> Result<()> { + pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = services() .rooms .state_cache @@ -189,7 +202,7 @@ impl Service { } #[tracing::instrument(skip(self, servers))] - pub(crate) fn flush_servers>(&self, servers: I) -> Result<()> { + pub fn flush_servers>(&self, servers: I) -> Result<()> { let requests = servers.into_iter().map(Destination::Normal); for dest in requests { self.dispatch(Msg { @@ -203,7 +216,7 @@ impl Service { } #[tracing::instrument(skip(self, request), name = "request")] - pub(crate) async fn send_federation_request(&self, dest: &ServerName, request: T) -> Result + pub async fn send_federation_request(&self, dest: &ServerName, request: T) -> Result where T: OutgoingRequest + Debug, { @@ -215,7 +228,7 @@ impl Service { /// /// Only returns None if there is no url specified in the appservice /// registration file - pub(crate) async fn send_appservice_request( + pub async fn send_appservice_request( &self, registration: Registration, request: T, ) -> Result> where @@ -227,7 +240,7 @@ impl Service { /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self))] - pub(crate) fn cleanup_events(&self, appservice_id: String) -> Result<()> { + pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { self.db .delete_all_requests_for(&Destination::Appservice(appservice_id))?; @@ -243,7 +256,7 @@ impl Service { impl Destination { #[tracing::instrument(skip(self))] - pub(crate) fn get_prefix(&self) -> Vec { + pub fn get_prefix(&self) -> Vec { let mut prefix = match self { Destination::Appservice(server) => { let mut p = b"+".to_vec(); diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 84da476d..a835e438 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -13,7 +13,6 @@ use ruma::{ client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, }, - events::room::message::RoomMessageEventContent, OwnedServerName, ServerName, }; use tracing::{debug, error, trace}; @@ -28,7 +27,7 @@ use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; /// /// # Examples: /// ```rust -/// # use conduit::api::server_server::FedDest; +/// # use conduit_service::sending::FedDest; /// # fn main() -> Result<(), std::net::AddrParseError> { /// FedDest::Literal("198.51.100.3:8448".parse()?); /// FedDest::Literal("[2001:db8::4:5]:443".parse()?); @@ -39,7 +38,7 @@ use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; /// # } /// ``` #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) enum FedDest { +pub enum FedDest { Literal(SocketAddr), Named(String, String), } @@ -52,7 +51,7 @@ struct ActualDest { } #[tracing::instrument(skip_all, name = "send")] -pub(crate) async fn send(client: &Client, dest: &ServerName, req: T) -> Result +pub async fn send(client: &Client, dest: &ServerName, req: T) -> Result where T: OutgoingRequest + Debug, { @@ -195,7 +194,7 @@ async fn get_actual_dest(server_name: &ServerName) -> Result { } else { cached = false; validate_dest(server_name)?; - resolve_actual_dest(server_name, false, false).await? + resolve_actual_dest(server_name, false).await? }; let string = dest.clone().into_https_string(); @@ -211,162 +210,49 @@ async fn get_actual_dest(server_name: &ServerName) -> Result { /// Implemented according to the specification at /// Numbers in comments below refer to bullet points in linked section of /// specification -pub(crate) async fn resolve_actual_dest( - dest: &'_ ServerName, no_cache_dest: bool, admin_room_caller: bool, -) -> Result<(FedDest, String)> { +pub async fn resolve_actual_dest(dest: &ServerName, no_cache_dest: bool) -> Result<(FedDest, String)> { trace!("Finding actual destination for {dest}"); let dest_str = dest.as_str().to_owned(); let mut hostname = dest_str.clone(); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "Checking for 1: IP literal with provided or default port", - )) - .await; - } - #[allow(clippy::single_match_else)] let actual_dest = match get_ip_with_port(&dest_str) { Some(host_port) => { debug!("1: IP literal with provided or default port"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "1: IP literal with provided or default port\n\nHost and Port: {host_port:?}" - ))) - .await; - } - host_port }, None => { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "Checking for 2: Hostname with included port", - )) - .await; - } - if let Some(pos) = dest_str.find(':') { debug!("2: Hostname with included port"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain("2: Hostname with included port")) - .await; - } - let (host, port) = dest_str.split_at(pos); if !no_cache_dest { query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; } - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!("Host: {host} | Port: {port}"))) - .await; - } - FedDest::Named(host.to_owned(), port.to_owned()) } else { trace!("Requesting well known for {dest}"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Checking for 3: A .well-known file is available. Requesting well-known for {dest}" - ))) - .await; - } - if let Some(delegated_hostname) = request_well_known(dest.as_str()).await? { debug!("3: A .well-known file is available"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain("3: A .well-known file is available")) - .await; - } - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); match get_ip_with_port(&delegated_hostname) { Some(host_and_port) => { debug!("3.1: IP literal in .well-known file"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "3.1: IP literal in .well-known file\n\nHost and Port: {host_and_port:?}" - ))) - .await; - } - host_and_port }, None => { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "Checking for 3.2: Hostname with port in .well-known file", - )) - .await; - } - if let Some(pos) = delegated_hostname.find(':') { debug!("3.2: Hostname with port in .well-known file"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "3.2: Hostname with port in .well-known file", - )) - .await; - } - let (host, port) = delegated_hostname.split_at(pos); if !no_cache_dest { query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; } - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {host} | Port: {port}" - ))) - .await; - } - FedDest::Named(host.to_owned(), port.to_owned()) } else { trace!("Delegated hostname has no port in this branch"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "Delegated hostname has no port specified", - )) - .await; - } - if let Some(hostname_override) = query_srv_record(&delegated_hostname).await? { debug!("3.3: SRV lookup successful"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "3.3: SRV lookup successful", - )) - .await; - } - let force_port = hostname_override.port(); if !no_cache_dest { query_and_cache_override( @@ -378,53 +264,17 @@ pub(crate) async fn resolve_actual_dest( } if let Some(port) = force_port { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {delegated_hostname} | Port: {port}" - ))) - .await; - } - FedDest::Named(delegated_hostname, format!(":{port}")) } else { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {delegated_hostname} | Port: 8448" - ))) - .await; - } - add_port_to_hostname(&delegated_hostname) } } else { debug!("3.4: No SRV records, just use the hostname from .well-known"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "3.4: No SRV records, just use the hostname from .well-known", - )) - .await; - } - if !no_cache_dest { query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448) .await?; } - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {delegated_hostname} | Port: 8448" - ))) - .await; - } - add_port_to_hostname(&delegated_hostname) } } @@ -432,26 +282,8 @@ pub(crate) async fn resolve_actual_dest( } } else { trace!("4: No .well-known or an error occured"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "4: No .well-known or an error occured", - )) - .await; - } - if let Some(hostname_override) = query_srv_record(&dest_str).await? { debug!("4: No .well-known; SRV record found"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "4: No .well-known; SRV record found", - )) - .await; - } - let force_port = hostname_override.port(); if !no_cache_dest { @@ -464,52 +296,16 @@ pub(crate) async fn resolve_actual_dest( } if let Some(port) = force_port { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {hostname} | Port: {port}" - ))) - .await; - } - FedDest::Named(hostname.clone(), format!(":{port}")) } else { - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {hostname} | Port: 8448" - ))) - .await; - } - add_port_to_hostname(&hostname) } } else { debug!("4: No .well-known; 5: No SRV record found"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain( - "4: No .well-known; 5: No SRV record found", - )) - .await; - } - if !no_cache_dest { query_and_cache_override(&dest_str, &dest_str, 8448).await?; } - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Host: {dest_str} | Port: 8448" - ))) - .await; - } - add_port_to_hostname(&dest_str) } } @@ -531,14 +327,6 @@ pub(crate) async fn resolve_actual_dest( }; debug!("Actual destination: {actual_dest:?} hostname: {hostname:?}"); - if admin_room_caller { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Actual destination: {actual_dest:?} | Hostname: {hostname:?}" - ))) - .await; - } Ok((actual_dest, hostname.into_uri_string())) } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index e7d00234..3d1d89da 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -23,12 +23,7 @@ use ruma::{ use tracing::{debug, error, warn}; use super::{appservice, send, Destination, Msg, SendingEvent, Service}; -use crate::{ - service::presence::Presence, - services, - utils::{calculate_hash, user_id::user_is_local}, - Error, PduEvent, Result, -}; +use crate::{service::presence::Presence, services, user_is_local, utils::calculate_hash, Error, PduEvent, Result}; #[derive(Debug)] enum TransactionStatus { @@ -47,25 +42,31 @@ const DEQUEUE_LIMIT: usize = 48; const SELECT_EDU_LIMIT: usize = 16; impl Service { - pub(crate) fn start_handler(self: &Arc) { - let self2 = Arc::clone(self); - tokio::spawn(async move { - self2.handler().await; + pub async fn start_handler(self: &Arc) { + let self_ = Arc::clone(self); + let handle = services().server.runtime().spawn(async move { + self_ + .handler() + .await + .expect("Failed to start sending handler"); }); + + _ = self.handler_join.lock().await.insert(handle); } #[tracing::instrument(skip_all, name = "sender")] - async fn handler(&self) { + async fn handler(&self) -> Result<()> { let receiver = self.receiver.lock().await; - debug_assert!(!receiver.is_closed(), "channel error"); - let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); + self.initial_transactions(&mut futures, &mut statuses); loop { + debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - Ok(request) = receiver.recv_async() => { - self.handle_request(request, &mut futures, &mut statuses); + request = receiver.recv_async() => match request { + Ok(request) => self.handle_request(request, &mut futures, &mut statuses), + Err(_) => return Ok(()), }, Some(response) = futures.next() => { self.handle_response(response, &mut futures, &mut statuses); @@ -396,7 +397,7 @@ fn select_edus_receipts( } async fn send_events(dest: Destination, events: Vec) -> SendingResult { - debug_assert!(!events.is_empty(), "sending empty transaction"); + //debug_assert!(!events.is_empty(), "sending empty transaction"); match dest { Destination::Normal(ref server) => send_events_dest_normal(&dest, server, events).await, Destination::Appservice(ref id) => send_events_dest_appservice(&dest, id, events).await, @@ -433,7 +434,7 @@ async fn send_events_dest_appservice(dest: &Destination, id: &String, events: Ve } } - debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); match appservice::send_request( services() .appservice @@ -584,7 +585,8 @@ async fn send_events_dest_normal( } let client = &services().globals.client.sender; - debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty transaction"); + //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty + // transaction"); send::send( client, server_name, diff --git a/src/service/services.rs b/src/service/services.rs new file mode 100644 index 00000000..291cba62 --- /dev/null +++ b/src/service/services.rs @@ -0,0 +1,342 @@ +use std::{ + collections::{BTreeMap, HashMap}, + sync::{atomic, Arc, Mutex as StdMutex}, +}; + +use conduit::{debug_info, Result, Server}; +use database::KeyValueDatabase; +use lru_cache::LruCache; +use tokio::{ + fs, + sync::{broadcast, Mutex, RwLock}, +}; +use tracing::{debug, info, trace}; + +use crate::{ + account_data, admin, appservice, globals, key_backups, media, presence, pusher, rooms, sending, transaction_ids, + uiaa, users, +}; + +pub struct Services { + pub appservice: appservice::Service, + pub pusher: pusher::Service, + pub rooms: rooms::Service, + pub transaction_ids: transaction_ids::Service, + pub uiaa: uiaa::Service, + pub users: users::Service, + pub account_data: account_data::Service, + pub presence: Arc, + pub admin: Arc, + pub globals: globals::Service, + pub key_backups: key_backups::Service, + pub media: media::Service, + pub sending: Arc, + pub server: Arc, + pub db: Arc, +} + +impl Services { + pub async fn build(server: Arc, db: Arc) -> Result { + let config = &server.config; + Ok(Self { + appservice: appservice::Service::build(db.clone())?, + pusher: pusher::Service { + db: db.clone(), + }, + rooms: rooms::Service { + alias: rooms::alias::Service { + db: db.clone(), + }, + auth_chain: rooms::auth_chain::Service { + db: db.clone(), + }, + directory: rooms::directory::Service { + db: db.clone(), + }, + event_handler: rooms::event_handler::Service, + lazy_loading: rooms::lazy_loading::Service { + db: db.clone(), + lazy_load_waiting: Mutex::new(HashMap::new()), + }, + metadata: rooms::metadata::Service { + db: db.clone(), + }, + outlier: rooms::outlier::Service { + db: db.clone(), + }, + pdu_metadata: rooms::pdu_metadata::Service { + db: db.clone(), + }, + read_receipt: rooms::read_receipt::Service { + db: db.clone(), + }, + search: rooms::search::Service { + db: db.clone(), + }, + short: rooms::short::Service { + db: db.clone(), + }, + state: rooms::state::Service { + db: db.clone(), + }, + state_accessor: rooms::state_accessor::Service { + db: db.clone(), + server_visibility_cache: StdMutex::new(LruCache::new( + (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) + as usize, + )), + user_visibility_cache: StdMutex::new(LruCache::new( + (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) + as usize, + )), + }, + state_cache: rooms::state_cache::Service { + db: db.clone(), + }, + state_compressor: rooms::state_compressor::Service { + db: db.clone(), + stateinfo_cache: StdMutex::new(LruCache::new( + (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, + )), + }, + timeline: rooms::timeline::Service { + db: db.clone(), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }, + threads: rooms::threads::Service { + db: db.clone(), + }, + typing: rooms::typing::Service { + typing: RwLock::new(BTreeMap::new()), + last_typing_update: RwLock::new(BTreeMap::new()), + typing_update_sender: broadcast::channel(100).0, + }, + spaces: rooms::spaces::Service { + roomid_spacehierarchy_cache: Mutex::new(LruCache::new( + (f64::from(config.roomid_spacehierarchy_cache_capacity) + * config.conduit_cache_capacity_modifier) as usize, + )), + }, + user: rooms::user::Service { + db: db.clone(), + }, + }, + transaction_ids: transaction_ids::Service { + db: db.clone(), + }, + uiaa: uiaa::Service { + db: db.clone(), + }, + users: users::Service { + db: db.clone(), + connections: StdMutex::new(BTreeMap::new()), + }, + account_data: account_data::Service { + db: db.clone(), + }, + presence: presence::Service::build(db.clone(), config), + admin: admin::Service::build(), + key_backups: key_backups::Service { + db: db.clone(), + }, + media: media::Service { + db: db.clone(), + url_preview_mutex: RwLock::new(HashMap::new()), + }, + sending: sending::Service::build(db.clone(), config), + globals: globals::Service::load(db.clone(), config)?, + server, + db, + }) + } + + pub async fn memory_usage(&self) -> String { + let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); + let server_visibility_cache = self + .rooms + .state_accessor + .server_visibility_cache + .lock() + .unwrap() + .len(); + let user_visibility_cache = self + .rooms + .state_accessor + .user_visibility_cache + .lock() + .unwrap() + .len(); + let stateinfo_cache = self + .rooms + .state_compressor + .stateinfo_cache + .lock() + .unwrap() + .len(); + let lasttimelinecount_cache = self + .rooms + .timeline + .lasttimelinecount_cache + .lock() + .await + .len(); + let roomid_spacehierarchy_cache = self + .rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .len(); + let resolver_overrides_cache = self.globals.resolver.overrides.read().unwrap().len(); + let resolver_destinations_cache = self.globals.resolver.destinations.read().await.len(); + let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len(); + let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len(); + let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len(); + + format!( + "\ +lazy_load_waiting: {lazy_load_waiting} +server_visibility_cache: {server_visibility_cache} +user_visibility_cache: {user_visibility_cache} +stateinfo_cache: {stateinfo_cache} +lasttimelinecount_cache: {lasttimelinecount_cache} +roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache} +resolver_overrides_cache: {resolver_overrides_cache} +resolver_destinations_cache: {resolver_destinations_cache} +bad_event_ratelimiter: {bad_event_ratelimiter} +bad_query_ratelimiter: {bad_query_ratelimiter} +bad_signature_ratelimiter: {bad_signature_ratelimiter} +" + ) + } + + pub async fn clear_caches(&self, amount: u32) { + if amount > 0 { + self.rooms + .lazy_loading + .lazy_load_waiting + .lock() + .await + .clear(); + } + if amount > 1 { + self.rooms + .state_accessor + .server_visibility_cache + .lock() + .unwrap() + .clear(); + } + if amount > 2 { + self.rooms + .state_accessor + .user_visibility_cache + .lock() + .unwrap() + .clear(); + } + if amount > 3 { + self.rooms + .state_compressor + .stateinfo_cache + .lock() + .unwrap() + .clear(); + } + if amount > 4 { + self.rooms + .timeline + .lasttimelinecount_cache + .lock() + .await + .clear(); + } + if amount > 5 { + self.rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .clear(); + } + if amount > 6 { + self.globals.resolver.overrides.write().unwrap().clear(); + self.globals.resolver.destinations.write().await.clear(); + } + if amount > 7 { + self.globals.resolver.resolver.clear_cache(); + } + if amount > 8 { + self.globals.bad_event_ratelimiter.write().await.clear(); + } + if amount > 9 { + self.globals.bad_query_ratelimiter.write().await.clear(); + } + if amount > 10 { + self.globals.bad_signature_ratelimiter.write().await.clear(); + } + } + + pub async fn start(&self) -> Result<()> { + debug_info!("Starting services"); + globals::migrations::migrations(&self.db, &self.globals.config).await?; + + self.admin.start_handler().await; + + globals::emerg_access::init_emergency_access().await; + + self.sending.start_handler().await; + + if self.globals.config.allow_local_presence { + self.presence.start_handler().await; + } + + if self.globals.allow_check_for_updates() { + let handle = globals::updates::start_check_for_updates_task().await?; + _ = self.globals.updates_handle.lock().await.insert(handle); + } + + debug_info!("Services startup complete."); + Ok(()) + } + + pub async fn interrupt(&self) { + trace!("Interrupting services..."); + self.server.interrupt.store(true, atomic::Ordering::Release); + + self.globals.rotate.fire(); + self.sending.interrupt(); + self.presence.interrupt(); + self.admin.interrupt(); + + trace!("Services interrupt complete."); + } + + #[tracing::instrument(skip_all)] + pub async fn shutdown(&self) { + info!("Shutting down services"); + self.interrupt().await; + + debug!("Removing unix socket file."); + if let Some(path) = self.globals.unix_socket_path().as_ref() { + _ = fs::remove_file(path).await; + } + + debug!("Waiting for update worker..."); + if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { + updates_handle.abort(); + _ = updates_handle.await; + } + + debug!("Waiting for admin worker..."); + self.admin.close().await; + + debug!("Waiting for presence worker..."); + self.presence.close().await; + + debug!("Waiting for sender..."); + self.sending.close().await; + + debug_info!("Services shutdown complete."); + } +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 1a0abb62..2aed1981 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -2,7 +2,7 @@ use ruma::{DeviceId, TransactionId, UserId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()>; diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 0bdaf925..ba9869e7 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,22 +1,24 @@ mod data; -pub(crate) use data::Data; +use std::sync::Arc; + +pub use data::Data; use ruma::{DeviceId, TransactionId, UserId}; use crate::Result; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub struct Service { + pub db: Arc, } impl Service { - pub(crate) fn add_txnid( + pub fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()> { self.db.add_txnid(user_id, device_id, txn_id, data) } - pub(crate) fn existing_txnid( + pub fn existing_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, ) -> Result>> { self.db.existing_txnid(user_id, device_id, txn_id) diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index 1b43ff8f..3a157068 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -2,7 +2,7 @@ use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId}; use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, ) -> Result<()>; diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 6e38bbcc..cd131c52 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,7 +1,10 @@ mod data; +use std::sync::Arc; + use argon2::{PasswordHash, PasswordVerifier}; -pub(crate) use data::Data; +use conduit::{utils, Error, Result}; +pub use data::Data; use ruma::{ api::client::{ error::ErrorKind, @@ -11,15 +14,17 @@ use ruma::{ }; use tracing::error; -use crate::{api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result}; +use crate::services; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, +pub const SESSION_ID_LENGTH: usize = 32; + +pub struct Service { + pub db: Arc, } impl Service { /// Creates a new Uiaa session. Make sure the session token is unique. - pub(crate) fn create( + pub fn create( &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, ) -> Result<()> { self.db.set_uiaa_request( @@ -37,7 +42,7 @@ impl Service { ) } - pub(crate) fn try_auth( + pub fn try_auth( &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, ) -> Result<(bool, UiaaInfo)> { let mut uiaainfo = auth.session().map_or_else( @@ -135,7 +140,8 @@ impl Service { Ok((true, uiaainfo)) } - pub(crate) fn get_uiaa_request( + #[must_use] + pub fn get_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, ) -> Option { self.db.get_uiaa_request(user_id, device_id, session) diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 9ce7ecdc..04074e85 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -10,7 +10,7 @@ use ruma::{ use crate::Result; -pub(crate) trait Data: Send + Sync { +pub trait Data: Send + Sync { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 3ae9931d..fde2ed89 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -5,7 +5,7 @@ use std::{ sync::{Arc, Mutex}, }; -pub(crate) use data::Data; +pub use data::Data; use ruma::{ api::client::{ device::Device, @@ -25,7 +25,7 @@ use ruma::{ use crate::{services, Error, Result}; -pub(crate) struct SlidingSyncCache { +pub struct SlidingSyncCache { lists: BTreeMap, subscriptions: BTreeMap, known_rooms: BTreeMap>, // For every room, the roomsince number @@ -34,25 +34,23 @@ pub(crate) struct SlidingSyncCache { type DbConnections = Mutex>>>; -pub(crate) struct Service { - pub(crate) db: &'static dyn Data, - pub(crate) connections: DbConnections, +pub struct Service { + pub db: Arc, + pub connections: DbConnections, } impl Service { /// Check if a user has an account on this homeserver. - pub(crate) fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } + pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } - pub(crate) fn forget_sync_request_connection( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, - ) { + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { self.connections .lock() .unwrap() .remove(&(user_id, device_id, conn_id)); } - pub(crate) fn update_sync_request_with_cache( + pub fn update_sync_request_with_cache( &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, ) -> BTreeMap> { let Some(conn_id) = request.conn_id.clone() else { @@ -172,7 +170,7 @@ impl Service { cached.known_rooms.clone() } - pub(crate) fn update_sync_subscriptions( + pub fn update_sync_subscriptions( &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, subscriptions: BTreeMap, ) { @@ -195,7 +193,7 @@ impl Service { cached.subscriptions = subscriptions; } - pub(crate) fn update_sync_known_rooms( + pub fn update_sync_known_rooms( &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, new_cached_rooms: BTreeSet, globalsince: u64, ) { @@ -232,10 +230,10 @@ impl Service { } /// Check if account is deactivated - pub(crate) fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } + pub fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } /// Check if a user is an admin - pub(crate) fn is_admin(&self, user_id: &UserId) -> Result { + pub fn is_admin(&self, user_id: &UserId) -> Result { let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; let admin_room_id = services() @@ -251,63 +249,63 @@ impl Service { } /// Create a new user account on this homeserver. - pub(crate) fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.db.set_password(user_id, password)?; Ok(()) } /// Returns the number of users registered on this server. - pub(crate) fn count(&self) -> Result { self.db.count() } + pub fn count(&self) -> Result { self.db.count() } /// Find out which user an access token belongs to. - pub(crate) fn find_from_token(&self, token: &str) -> Result> { + pub fn find_from_token(&self, token: &str) -> Result> { self.db.find_from_token(token) } /// Returns an iterator over all users on this homeserver. - pub(crate) fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } + pub fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is /// greater then zero. - pub(crate) fn list_local_users(&self) -> Result> { self.db.list_local_users() } + pub fn list_local_users(&self) -> Result> { self.db.list_local_users() } /// Returns the password hash for the given user. - pub(crate) fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } + pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } /// Hash and set the user's password to the Argon2 hash - pub(crate) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.db.set_password(user_id, password) } /// Returns the displayname of a user on this homeserver. - pub(crate) fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } + pub fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } /// Sets a new displayname or removes it if displayname is None. You still /// need to nofify all rooms of this change. - pub(crate) async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + pub async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { self.db.set_displayname(user_id, displayname) } /// Get the avatar_url of a user. - pub(crate) fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } + pub fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } /// Sets a new avatar_url or removes it if avatar_url is None. - pub(crate) async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { self.db.set_avatar_url(user_id, avatar_url) } /// Get the blurhash of a user. - pub(crate) fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } + pub fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } /// Sets a new avatar_url or removes it if avatar_url is None. - pub(crate) async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { self.db.set_blurhash(user_id, blurhash) } /// Adds a new device to a user. - pub(crate) fn create_device( + pub fn create_device( &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, ) -> Result<()> { self.db @@ -315,21 +313,21 @@ impl Service { } /// Removes a device from a user. - pub(crate) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { self.db.remove_device(user_id, device_id) } /// Returns an iterator over all device ids of this user. - pub(crate) fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { self.db.all_device_ids(user_id) } /// Replaces the access token of one device. - pub(crate) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { self.db.set_token(user_id, device_id, token) } - pub(crate) fn add_one_time_key( + pub fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, ) -> Result<()> { @@ -339,29 +337,27 @@ impl Service { // TODO: use this ? #[allow(dead_code)] - pub(crate) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.db.last_one_time_keys_update(user_id) } - pub(crate) fn take_one_time_key( + pub fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, ) -> Result)>> { self.db.take_one_time_key(user_id, device_id, key_algorithm) } - pub(crate) fn count_one_time_keys( + pub fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result> { self.db.count_one_time_keys(user_id, device_id) } - pub(crate) fn add_device_keys( - &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - ) -> Result<()> { + pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { self.db.add_device_keys(user_id, device_id, device_keys) } - pub(crate) fn add_cross_signing_keys( + pub fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, notify: bool, ) -> Result<()> { @@ -369,58 +365,56 @@ impl Service { .add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify) } - pub(crate) fn sign_key( + pub fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { self.db.sign_key(target_id, key_id, signature, sender_id) } - pub(crate) fn keys_changed<'a>( + pub fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option, ) -> impl Iterator> + 'a { self.db.keys_changed(user_or_room_id, from, to) } - pub(crate) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - self.db.mark_device_key_update(user_id) - } + pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) } - pub(crate) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { self.db.get_device_keys(user_id, device_id) } - pub(crate) fn parse_master_key( + pub fn parse_master_key( &self, user_id: &UserId, master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { self.db.parse_master_key(user_id, master_key) } - pub(crate) fn get_key( + pub fn get_key( &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.db .get_key(key, sender_user, user_id, allowed_signatures) } - pub(crate) fn get_master_key( + pub fn get_master_key( &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.db .get_master_key(sender_user, user_id, allowed_signatures) } - pub(crate) fn get_self_signing_key( + pub fn get_self_signing_key( &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.db .get_self_signing_key(sender_user, user_id, allowed_signatures) } - pub(crate) fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { self.db.get_user_signing_key(user_id) } - pub(crate) fn add_to_device_event( + pub fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, ) -> Result<()> { @@ -428,35 +422,33 @@ impl Service { .add_to_device_event(sender, target_user_id, target_device_id, event_type, content) } - pub(crate) fn get_to_device_events( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { + pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { self.db.get_to_device_events(user_id, device_id) } - pub(crate) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { + pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { self.db.remove_to_device_events(user_id, device_id, until) } - pub(crate) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { self.db.update_device_metadata(user_id, device_id, device) } /// Get device metadata. - pub(crate) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { self.db.get_device_metadata(user_id, device_id) } - pub(crate) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { self.db.get_devicelist_version(user_id) } - pub(crate) fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { self.db.all_devices_metadata(user_id) } /// Deactivate account - pub(crate) fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Remove all associated devices for device_id in self.all_device_ids(user_id) { self.remove_device(user_id, &device_id?)?; @@ -473,17 +465,17 @@ impl Service { } /// Creates a new sync filter. Returns the filter id. - pub(crate) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { self.db.create_filter(user_id, filter) } - pub(crate) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { self.db.get_filter(user_id, filter_id) } } /// Ensure that a user only sees signatures from themselves and the target user -pub(crate) fn clean_signatures bool>( +pub fn clean_signatures bool>( cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, ) -> Result<(), Error> { if let Some(signatures) = cross_signing_key diff --git a/src/utils/server_name.rs b/src/utils/server_name.rs deleted file mode 100644 index 11303f9a..00000000 --- a/src/utils/server_name.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! utilities for doing/checking things with ServerName's/server_name's - -use ruma::ServerName; - -use crate::services; - -/// checks if `server_name` is ours -pub(crate) fn server_is_ours(server_name: &ServerName) -> bool { server_name == services().globals.config.server_name } diff --git a/src/utils/user_id.rs b/src/utils/user_id.rs deleted file mode 100644 index ae312792..00000000 --- a/src/utils/user_id.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! utilities for doing things with UserId's / usernames - -use ruma::UserId; - -use crate::services; - -/// checks if `user_id` is local to us via server_name comparison -pub(crate) fn user_is_local(user_id: &UserId) -> bool { user_id.server_name() == services().globals.config.server_name }