From 946ca364e032be8ca2529099b415990262c977fd Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Thu, 8 Aug 2024 17:18:30 +0000
Subject: [PATCH] Database Refactor

combine service/users data w/ mod unit
split sliding sync related out of service/users
instrument database entry points
remove increment crap from database interface
de-wrap all database get() calls
de-wrap all database insert() calls
de-wrap all database remove() calls
refactor database interface for async streaming
add query key serializer for database
implement Debug for result handle
add query deserializer for database
add deserialization trait for option handle
start a stream utils suite
de-wrap/asyncify/type-query count_one_time_keys()
de-wrap/asyncify users count
add admin query users command suite
de-wrap/asyncify users exists
de-wrap/partially asyncify user filter related
asyncify/de-wrap users device/keys related
asyncify/de-wrap user auth/misc related
asyncify/de-wrap users blurhash
asyncify/de-wrap account_data get; merge Data into Service
partial asyncify/de-wrap uiaa; merge Data into Service
partially asyncify/de-wrap transaction_ids get; merge Data into Service
partially asyncify/de-wrap key_backups; merge Data into Service
asyncify/de-wrap pusher service getters; merge Data into Service
asyncify/de-wrap rooms alias getters/some iterators
asyncify/de-wrap rooms directory getters/iterator
partially asyncify/de-wrap rooms lazy-loading
partially asyncify/de-wrap rooms metadata
asyncify/dewrap rooms outlier
asyncify/dewrap rooms pdu_metadata
dewrap/partially asyncify rooms read receipt
de-wrap rooms search service
de-wrap/partially asyncify rooms user service
partial de-wrap rooms state_compressor
de-wrap rooms state_cache
de-wrap room state et al
de-wrap rooms timeline service
additional users device/keys related
de-wrap/asyncify sender
asyncify services
refactor database to TryFuture/TryStream
refactor services for TryFuture/TryStream
asyncify api handlers
additional asyncification for admin module
abstract stream related; support reverse streams
additional stream conversions
asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
---
(An illustrative sketch of the new call-site conventions follows the diffstat below.)

 Cargo.lock | 53 +-
 Cargo.toml | 7 +-
 clippy.toml | 2 +-
 src/admin/Cargo.toml | 3 +-
 src/admin/check/commands.rs | 9 +-
 src/admin/debug/commands.rs | 100 +-
 src/admin/federation/commands.rs | 13 +-
 src/admin/media/commands.rs | 4 +-
 src/admin/processor.rs | 2 +-
 src/admin/query/account_data.rs | 6 +-
 src/admin/query/appservice.rs | 6 +-
 src/admin/query/globals.rs | 6 +-
 src/admin/query/presence.rs | 5 +-
 src/admin/query/pusher.rs | 2 +-
 src/admin/query/room_alias.rs | 13 +-
 src/admin/query/room_state_cache.rs | 93 +-
 src/admin/query/sending.rs | 9 +-
 src/admin/query/users.rs | 351 +++-
 src/admin/room/alias.rs | 118 +-
 src/admin/room/commands.rs | 47 +-
 src/admin/room/directory.rs | 22 +-
 src/admin/room/info.rs | 48 +-
 src/admin/room/mod.rs | 6 +
 src/admin/room/moderation.rs | 343 ++--
 src/admin/user/commands.rs | 192 ++-
 src/admin/utils.rs | 22 +-
 src/api/Cargo.toml | 2 +-
 src/api/client/account.rs | 182 ++-
 src/api/client/alias.rs | 28 +-
 src/api/client/backup.rs | 230 +--
 src/api/client/config.rs | 40 +-
 src/api/client/context.rs | 108 +-
 src/api/client/device.rs | 46 +-
 src/api/client/directory.rs | 188 +--
 src/api/client/filter.rs | 25 +-
 src/api/client/keys.rs | 186 ++-
 src/api/client/membership.rs | 486 +++---
 src/api/client/message.rs | 196 +--
 src/api/client/presence.rs | 14 +-
 src/api/client/profile.rs | 200 ++-
 src/api/client/push.rs | 182 ++-
 src/api/client/read_marker.rs | 90 +-
 src/api/client/relations.rs | 90 +-
 src/api/client/report.rs | 20 +-
 src/api/client/room.rs | 105 +-
 src/api/client/search.rs | 108 +-
 src/api/client/session.rs | 75 +-
 src/api/client/state.rs | 63 +-
 src/api/client/sync.rs | 1091 +++++++------
 src/api/client/tag.rs | 45 +-
 src/api/client/threads.rs | 19 +-
 src/api/client/to_device.rs | 52 +-
 src/api/client/typing.rs | 3 +-
 src/api/client/unstable.rs | 168 +-
 src/api/client/unversioned.rs | 3 +-
 src/api/client/user_directory.rs | 47 +-
 src/api/router.rs | 322 ++--
 src/api/router/args.rs | 26 +-
 src/api/router/auth.rs | 10 +-
 src/api/router/handler.rs | 38 +-
 src/api/router/response.rs | 9 +-
 src/api/server/backfill.rs | 84 +-
 src/api/server/event.rs | 39 +-
 src/api/server/event_auth.rs | 33 +-
 src/api/server/get_missing_events.rs | 31 +-
 src/api/server/hierarchy.rs | 2 +-
 src/api/server/invite.rs | 38 +-
 src/api/server/make_join.rs | 87 +-
 src/api/server/make_leave.rs | 37 +-
 src/api/server/openid.rs | 5 +-
 src/api/server/query.rs | 36 +-
 src/api/server/send.rs | 196 +--
 src/api/server/send_join.rs | 71 +-
 src/api/server/send_leave.rs | 20 +-
 src/api/server/state.rs | 65 +-
 src/api/server/state_ids.rs | 37 +-
 src/api/server/user.rs | 51 +-
 src/core/Cargo.toml | 1 +
 src/core/error/mod.rs | 4 +-
 src/core/pdu/mod.rs | 43 +-
 src/core/result/log_debug_err.rs | 18 +-
 src/core/result/log_err.rs | 20 +-
 src/core/utils/algorithm.rs | 25 -
 src/core/utils/mod.rs | 32 +-
 src/core/utils/set.rs | 47 +
 src/core/utils/stream/cloned.rs | 20 +
 src/core/utils/stream/expect.rs | 17 +
 src/core/utils/stream/ignore.rs | 21 +
 src/core/utils/stream/iter_stream.rs | 27 +
 src/core/utils/stream/mod.rs | 13 +
 src/core/utils/stream/ready.rs | 109 ++
 src/core/utils/stream/try_ready.rs | 35 +
 src/core/utils/tests.rs | 130 ++
 src/database/Cargo.toml | 3 +
 src/database/database.rs | 2 +-
 src/database/de.rs | 261 +++
 src/database/deserialized.rs | 34 +
 src/database/engine.rs | 2 +-
 src/database/handle.rs | 89 +-
 src/database/iter.rs | 110 --
 src/database/keyval.rs | 83 +
 src/database/map.rs | 304 ++--
 src/database/map/count.rs | 36 +
 src/database/map/keys.rs | 21 +
 src/database/map/keys_from.rs | 49 +
 src/database/map/keys_prefix.rs | 54 +
 src/database/map/rev_keys.rs | 21 +
 src/database/map/rev_keys_from.rs | 49 +
 src/database/map/rev_keys_prefix.rs | 54 +
 src/database/map/rev_stream.rs | 29 +
 src/database/map/rev_stream_from.rs | 68 +
 src/database/map/rev_stream_prefix.rs | 74 +
 src/database/map/stream.rs | 28 +
 src/database/map/stream_from.rs | 68 +
 src/database/map/stream_prefix.rs | 74 +
 src/database/mod.rs | 28 +-
 src/database/ser.rs | 315 ++++
 src/database/slice.rs | 57 -
 src/database/stream.rs | 122 ++
 src/database/stream/items.rs | 44 +
 src/database/stream/items_rev.rs | 44 +
 src/database/stream/keys.rs | 44 +
 src/database/stream/keys_rev.rs | 44 +
 src/database/util.rs | 12 +
 src/service/Cargo.toml | 2 +-
 src/service/account_data/data.rs | 152 --
 src/service/account_data/mod.rs | 164 +-
 src/service/admin/console.rs | 2 +-
 src/service/admin/create.rs | 2 +-
 src/service/admin/grant.rs | 216 +--
 src/service/admin/mod.rs | 104 +-
 src/service/appservice/data.rs | 28 +-
 src/service/appservice/mod.rs | 49 +-
 src/service/emergency/mod.rs | 30 +-
 src/service/globals/data.rs | 121 +-
 src/service/globals/migrations.rs | 741 ++------
 src/service/globals/mod.rs | 8 +-
 src/service/key_backups/data.rs | 346 ----
 src/service/key_backups/mod.rs | 360 ++++-
 src/service/manager.rs | 2 +-
 src/service/media/data.rs | 100 +-
 src/service/media/migrations.rs | 33 +-
 src/service/media/mod.rs | 15 +-
 src/service/media/preview.rs | 8 +-
 src/service/media/thumbnail.rs | 4 +-
 src/service/mod.rs | 1 +
 src/service/presence/data.rs | 111 +-
 src/service/presence/mod.rs | 63 +-
 src/service/presence/presence.rs | 12 +-
 src/service/pusher/data.rs | 77 -
 src/service/pusher/mod.rs | 124 +-
 src/service/resolver/actual.rs | 6 +-
 src/service/rooms/alias/data.rs | 125 --
 src/service/rooms/alias/mod.rs | 147 +-
 src/service/rooms/auth_chain/data.rs | 21 +-
 src/service/rooms/auth_chain/mod.rs | 45 +-
 src/service/rooms/directory/data.rs | 39 -
 src/service/rooms/directory/mod.rs | 40 +-
 src/service/rooms/event_handler/mod.rs | 1220 +++++++-------
 .../rooms/event_handler/parse_incoming_pdu.rs | 6 +-
 src/service/rooms/lazy_loading/data.rs | 65 -
 src/service/rooms/lazy_loading/mod.rs | 112 +-
 src/service/rooms/metadata/data.rs | 110 --
 src/service/rooms/metadata/mod.rs | 99 +-
 src/service/rooms/outlier/data.rs | 42 -
 src/service/rooms/outlier/mod.rs | 59 +-
 src/service/rooms/pdu_metadata/data.rs | 76 +-
 src/service/rooms/pdu_metadata/mod.rs | 171 +-
 src/service/rooms/read_receipt/data.rs | 148 +-
 src/service/rooms/read_receipt/mod.rs | 49 +-
 src/service/rooms/search/data.rs | 73 +-
 src/service/rooms/search/mod.rs | 17 +-
 src/service/rooms/short/data.rs | 212 ++-
 src/service/rooms/short/mod.rs | 36 +-
 src/service/rooms/spaces/mod.rs | 174 +-
 src/service/rooms/state/data.rs | 71 +-
 src/service/rooms/state/mod.rs | 276 ++--
 src/service/rooms/state_accessor/data.rs | 156 +-
 src/service/rooms/state_accessor/mod.rs | 356 +++--
 src/service/rooms/state_cache/data.rs | 646 +------
 src/service/rooms/state_cache/mod.rs | 475 ++++--
 src/service/rooms/state_compressor/data.rs | 20 +-
 src/service/rooms/state_compressor/mod.rs | 85 +-
 src/service/rooms/threads/data.rs | 78 +-
 src/service/rooms/threads/mod.rs | 38 +-
 src/service/rooms/timeline/data.rs | 329 ++--
 src/service/rooms/timeline/mod.rs | 656 ++++----
 src/service/rooms/typing/mod.rs | 33 +-
 src/service/rooms/user/data.rs | 146 +-
 src/service/rooms/user/mod.rs | 46 +-
 src/service/sending/data.rs | 194 +--
 src/service/sending/mod.rs | 113 +-
 src/service/sending/sender.rs | 278 ++--
 src/service/server_keys/mod.rs | 26 +-
 src/service/services.rs | 4 +-
 src/service/sync/mod.rs | 233 +++
 src/service/transaction_ids/data.rs | 44 -
 src/service/transaction_ids/mod.rs | 44 +-
 src/service/uiaa/data.rs | 87 -
 src/service/uiaa/mod.rs | 313 ++--
 src/service/updates/mod.rs | 90 +-
 src/service/users/data.rs | 1098 -------------
 src/service/users/mod.rs | 1413 +++++++++++------
 203 files changed, 12202 insertions(+), 10709 deletions(-)
 delete mode 100644 src/core/utils/algorithm.rs
 create mode 100644 src/core/utils/set.rs
 create mode 100644 src/core/utils/stream/cloned.rs
 create mode 100644 src/core/utils/stream/expect.rs
 create mode 100644 src/core/utils/stream/ignore.rs
 create mode 100644 src/core/utils/stream/iter_stream.rs
 create mode 100644 src/core/utils/stream/mod.rs
 create mode 100644 src/core/utils/stream/ready.rs
 create mode 100644 src/core/utils/stream/try_ready.rs
 create mode 100644 src/database/de.rs
 create mode 100644 src/database/deserialized.rs
 delete mode 100644 src/database/iter.rs
 create mode 100644 src/database/keyval.rs
 create mode 100644 src/database/map/count.rs
 create mode 100644 src/database/map/keys.rs
 create mode 100644 src/database/map/keys_from.rs
 create mode 100644 src/database/map/keys_prefix.rs
 create mode 100644 src/database/map/rev_keys.rs
 create mode 100644 src/database/map/rev_keys_from.rs
 create mode 100644 src/database/map/rev_keys_prefix.rs
 create mode 100644 src/database/map/rev_stream.rs
 create mode 100644 src/database/map/rev_stream_from.rs
 create mode 100644 src/database/map/rev_stream_prefix.rs
 create mode 100644 src/database/map/stream.rs
 create mode 100644 src/database/map/stream_from.rs
 create mode 100644 src/database/map/stream_prefix.rs
 create mode 100644 src/database/ser.rs
 delete mode 100644 src/database/slice.rs
 create mode 100644 src/database/stream.rs
 create mode 100644 src/database/stream/items.rs
 create mode 100644 src/database/stream/items_rev.rs
 create mode 100644 src/database/stream/keys.rs
 create mode 100644 src/database/stream/keys_rev.rs
 delete mode 100644 src/service/account_data/data.rs
 delete mode 100644 src/service/key_backups/data.rs
 delete mode 100644 src/service/pusher/data.rs
 delete mode 100644 src/service/rooms/alias/data.rs
 delete mode 100644 src/service/rooms/directory/data.rs
 delete mode 100644 src/service/rooms/lazy_loading/data.rs
 delete mode 100644 src/service/rooms/metadata/data.rs
 delete mode 100644 src/service/rooms/outlier/data.rs
 create mode 100644 src/service/sync/mod.rs
 delete mode 100644 src/service/transaction_ids/data.rs
 delete mode 100644 src/service/uiaa/data.rs
 delete mode 100644 src/service/users/data.rs
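Note (illustrative, not part of the patch): the mechanical change repeated across the
diffs below is a calling-convention shift. Point getters that returned Result<Option<T>>
from a synchronous interface become async and return Result<T>, with Err(_) standing in
for not-found; iterators that yielded Result items become futures Streams of plain items.
A minimal sketch of the resulting call-site shape, assuming only the `futures` crate this
patch adds to the workspace (the identifiers below are hypothetical, not conduwuit APIs):

	use futures::{stream, StreamExt};

	async fn demo() {
		// Iterator<Item = Result<T>> becomes Stream<Item = T>: the per-item
		// de-wrap (`filter_map(Result::ok)`) disappears, and a trailing `.await`
		// appears on the collect, matching the `collect::<Vec<_>>().await`
		// pattern seen throughout the admin and API diffs.
		let users: Vec<String> = stream::iter(["@a:example.org", "@b:example.org"])
			.map(ToOwned::to_owned)
			.collect()
			.await;
		assert_eq!(users.len(), 2);
	}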
diff --git a/Cargo.lock b/Cargo.lock
index 6386f968..08e0498a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -626,10 +626,11 @@ dependencies = [
  "clap",
  "conduit_api",
  "conduit_core",
+ "conduit_database",
  "conduit_macros",
  "conduit_service",
  "const-str",
- "futures-util",
+ "futures",
  "log",
  "ruma",
  "serde_json",
@@ -652,7 +653,7 @@ dependencies = [
  "conduit_database",
  "conduit_service",
  "const-str",
- "futures-util",
+ "futures",
  "hmac",
  "http",
  "http-body-util",
@@ -689,6 +690,7 @@ dependencies = [
  "cyborgtime",
  "either",
  "figment",
+ "futures",
  "hardened_malloc-rs",
  "http",
  "http-body-util",
@@ -726,8 +728,11 @@ version = "0.4.7"
 dependencies = [
  "conduit_core",
  "const-str",
+ "futures",
  "log",
  "rust-rocksdb-uwu",
+ "serde",
+ "serde_json",
  "tokio",
  "tracing",
 ]
@@ -784,7 +789,7 @@ dependencies = [
  "conduit_core",
  "conduit_database",
  "const-str",
- "futures-util",
+ "futures",
  "hickory-resolver",
  "http",
  "image",
@@ -1283,6 +1288,20 @@ dependencies = [
  "new_debug_unreachable",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -1345,6 +1364,7 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
  "futures-io",
  "futures-macro",
@@ -2953,7 +2973,7 @@ dependencies = [
 [[package]]
 name = "ruma"
 version = "0.10.1"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "assign",
  "js_int",
@@ -2975,7 +2995,7 @@ dependencies = [
 [[package]]
 name = "ruma-appservice-api"
 version = "0.10.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "js_int",
  "ruma-common",
@@ -2987,7 +3007,7 @@ dependencies = [
 [[package]]
 name = "ruma-client-api"
 version = "0.18.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "as_variant",
  "assign",
@@ -3010,7 +3030,7 @@ dependencies = [
 [[package]]
 name = "ruma-common"
 version = "0.13.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "as_variant",
  "base64 0.22.1",
@@ -3040,7 +3060,7 @@ dependencies = [
 [[package]]
 name = "ruma-events"
 version = "0.28.1"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "as_variant",
  "indexmap 2.6.0",
@@ -3064,7 +3084,7 @@ dependencies = [
 [[package]]
 name = "ruma-federation-api"
 version = "0.9.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "bytes",
  "http",
@@ -3082,7 +3102,7 @@ dependencies = [
 [[package]]
 name = "ruma-identifiers-validation"
 version = "0.9.5"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "js_int",
  "thiserror",
@@ -3091,7 +3111,7 @@ dependencies = [
 [[package]]
 name = "ruma-identity-service-api"
 version = "0.9.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "js_int",
  "ruma-common",
@@ -3101,7 +3121,7 @@ dependencies = [
 [[package]]
 name = "ruma-macros"
 version = "0.13.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "cfg-if",
  "once_cell",
@@ -3117,7 +3137,7 @@ dependencies = [
 [[package]]
 name = "ruma-push-gateway-api"
 version = "0.9.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "js_int",
  "ruma-common",
@@ -3129,7 +3149,7 @@ dependencies = [
 [[package]]
 name = "ruma-server-util"
 version = "0.3.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "headers",
  "http",
@@ -3142,7 +3162,7 @@ dependencies = [
 [[package]]
 name = "ruma-signatures"
 version = "0.15.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
  "base64 0.22.1",
  "ed25519-dalek",
@@ -3158,8 +3178,9 @@ dependencies = [
 [[package]]
 name = "ruma-state-res"
 version = "0.11.0"
-source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
+source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
 dependencies = [
+ "futures-util",
  "itertools 0.12.1",
  "js_int",
  "ruma-common",
diff --git a/Cargo.toml b/Cargo.toml
index b75c4975..3bfb3bc8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -210,9 +210,10 @@ features = [
 	"string",
 ]
 
-[workspace.dependencies.futures-util]
+[workspace.dependencies.futures]
 version = "0.3.30"
 default-features = false
+features = ["std"]
 
 [workspace.dependencies.tokio]
 version = "1.40.0"
@@ -314,7 +315,7 @@ version = "0.1.2"
 [workspace.dependencies.ruma]
 git = "https://github.com/girlbossceo/ruwuma"
 #branch = "conduwuit-changes"
-rev = "9900d0676564883cfade556d6e8da2a2c9061efd"
+rev = "e7db44989d68406393270d3a91815597385d3acb"
 features = [
 	"compat",
 	"rand",
@@ -463,7 +464,6 @@ version = "1.0.36"
 
 [workspace.dependencies.proc-macro2]
 version = "1.0.89"
-
 #
 # Patches
 #
@@ -828,6 +828,7 @@ missing_panics_doc = { level = "allow", priority = 1 }
 module_name_repetitions = { level = "allow", priority = 1 }
 no_effect_underscore_binding = { level = "allow", priority = 1 }
 similar_names = { level = "allow", priority = 1 }
+single_match_else = { level = "allow", priority = 1 }
 struct_field_names = { level = "allow", priority = 1 }
 unnecessary_wraps = { level = "allow", priority = 1 }
 unused_async = { level = "allow", priority = 1 }
diff --git a/clippy.toml b/clippy.toml
index c942b93c..08641fcc 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -2,6 +2,6 @@ array-size-threshold = 4096
 cognitive-complexity-threshold = 94 # TODO reduce me ALARA
 excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5
 future-size-threshold = 7745 # TODO reduce me ALARA
-stack-size-threshold = 144000 # reduce me ALARA
+stack-size-threshold = 196608 # reduce me ALARA
 too-many-lines-threshold = 700 # TODO reduce me to <= 100
 type-complexity-threshold = 250 # reduce me to ~200
diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml
index d756b3cb..f5cab449 100644
--- a/src/admin/Cargo.toml
+++ b/src/admin/Cargo.toml
@@ -29,10 +29,11 @@ release_max_log_level = [
 clap.workspace = true
 conduit-api.workspace = true
 conduit-core.workspace = true
+conduit-database.workspace = true
 conduit-macros.workspace = true
 conduit-service.workspace = true
const-str.workspace = true -futures-util.workspace = true +futures.workspace = true log.workspace = true ruma.workspace = true serde_json.workspace = true diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index 0a983046..88fca462 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -1,5 +1,6 @@ use conduit::Result; use conduit_macros::implement; +use futures::StreamExt; use ruma::events::room::message::RoomMessageEventContent; use crate::Command; @@ -10,14 +11,12 @@ use crate::Command; #[implement(Command, params = "<'_>")] pub(super) async fn check_all_users(&self) -> Result { let timer = tokio::time::Instant::now(); - let results = self.services.users.db.iter(); + let users = self.services.users.iter().collect::>().await; let query_time = timer.elapsed(); - let users = results.collect::>(); - let total = users.len(); - let err_count = users.iter().filter(|user| user.is_err()).count(); - let ok_count = users.iter().filter(|user| user.is_ok()).count(); + let err_count = users.iter().filter(|_user| false).count(); + let ok_count = users.iter().filter(|_user| true).count(); let message = format!( "Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \ diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 2d967006..65c9bc71 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -7,6 +7,7 @@ use std::{ use api::client::validate_and_add_event_id; use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, @@ -27,7 +28,7 @@ pub(super) async fn echo(&self, message: Vec) -> Result) -> Result { let event_id = Arc::::from(event_id); - if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? 
{ + if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await { let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) @@ -43,7 +44,8 @@ pub(super) async fn get_auth_chain(&self, event_id: Box) -> Result) -> Result { + Ok(json) => { let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); Ok(RoomMessageEventContent::notice_markdown(format!( "{}\n```json\n{}\n```", @@ -109,7 +114,7 @@ pub(super) async fn get_pdu(&self, event_id: Box) -> Result Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), + Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), } } @@ -157,7 +162,8 @@ pub(super) async fn get_remote_pdu_list( .send_message(RoomMessageEventContent::text_plain(format!( "Failed to get remote PDU, ignoring error: {e}" ))) - .await; + .await + .ok(); warn!("Failed to get remote PDU, ignoring error: {e}"); } else { success_count = success_count.saturating_add(1); @@ -215,7 +221,9 @@ pub(super) async fn get_remote_pdu( .services .rooms .event_handler - .parse_incoming_pdu(&response.pdu); + .parse_incoming_pdu(&response.pdu) + .await; + let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -333,9 +341,12 @@ pub(super) async fn ping(&self, server: Box) -> Result Result { // Force E2EE device list updates for all users - for user_id in self.services.users.iter().filter_map(Result::ok) { - self.services.users.mark_device_key_update(&user_id)?; - } + self.services + .users + .stream() + .for_each(|user_id| self.services.users.mark_device_key_update(user_id)) + .await; + Ok(RoomMessageEventContent::text_plain( "Marked all devices for all users as having new keys to update", )) @@ -470,7 +481,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box) -> Result) -> Result) -> Result) -> Result> = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); @@ -554,13 +571,21 @@ pub(super) async fn force_set_room_state_from_server( let mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { - Ok(t) => t, - Err(e) => { - warn!("Could not parse PDU, ignoring: {e}"); - continue; + events.push( + match self + .services + .rooms + .event_handler + .parse_incoming_pdu(&pdu) + .await + { + Ok(t) => t, + Err(e) => { + warn!("Could not parse PDU, ignoring: {e}"); + continue; + }, }, - }); + ); } info!("Fetching required signing keys for all the state events we got"); @@ -587,13 +612,16 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { let shortstatekey = self .services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; + state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -611,7 +639,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); } let new_room_state = self @@ -626,7 +654,8 @@ pub(super) async fn force_set_room_state_from_server( .services .rooms .state_compressor - .save_state(room_id.clone().as_ref(), new_room_state)?; + .save_state(room_id.clone().as_ref(), new_room_state) + .await?; let state_lock = 
self.services.rooms.state.mutex.lock(&room_id).await; self.services @@ -642,7 +671,8 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state_cache - .update_joined_count(&room_id)?; + .update_joined_count(&room_id) + .await; drop(state_lock); @@ -656,7 +686,7 @@ pub(super) async fn get_signing_keys( &self, server_name: Option>, _cached: bool, ) -> Result { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let signing_keys = self.services.globals.signing_keys_for(&server_name)?; + let signing_keys = self.services.globals.signing_keys_for(&server_name).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" @@ -674,7 +704,7 @@ pub(super) async fn get_verify_keys( if cached { writeln!(out, "| Key ID | VerifyKey |")?; writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? { + for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? { writeln!(out, "| {key_id} | {verify_key:?} |")?; } diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 8917a46b..ce95ac01 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,19 +1,20 @@ use std::fmt::Write; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; use crate::{admin_command, escape_html, get_room_info}; #[admin_command] pub(super) async fn disable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); Ok(RoomMessageEventContent::text_plain("Room disabled.")) } #[admin_command] pub(super) async fn enable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); Ok(RoomMessageEventContent::text_plain("Room enabled.")) } @@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< )); } - if !self.services.users.exists(&user_id)? { + if !self.services.users.exists(&user_id).await { return Ok(RoomMessageEventContent::text_plain( "Remote user does not exist in our database.", )); @@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .map(|room_id| get_room_info(self.services, &room_id)) - .collect(); + .then(|room_id| get_room_info(self.services, room_id)) + .collect() + .await; if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); diff --git a/src/admin/media/commands.rs b/src/admin/media/commands.rs index 3c4bf2ef..82ac162e 100644 --- a/src/admin/media/commands.rs +++ b/src/admin/media/commands.rs @@ -36,7 +36,7 @@ pub(super) async fn delete( let mut mxc_urls = Vec::with_capacity(4); // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? 
{ + if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await { if let Some(content_key) = event_json.get("content") { debug!("Event ID has \"content\"."); let content_obj = content_key.as_object(); @@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server( #[admin_command] pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result { let mxc: Mxc<'_> = mxc.as_str().try_into()?; - let metadata = self.services.media.get_metadata(&mxc); + let metadata = self.services.media.get_metadata(&mxc).await; Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```"))) } diff --git a/src/admin/processor.rs b/src/admin/processor.rs index 4f60f56e..3c1895ff 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -17,7 +17,7 @@ use conduit::{ utils::string::{collect_stream, common_prefix}, warn, Error, Result, }; -use futures_util::future::FutureExt; +use futures::future::FutureExt; use ruma::{ events::{ relation::InReplyTo, diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index e18c298a..896bf95c 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -44,7 +44,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .changes_since(room_id.as_deref(), &user_id, since)?; + .changes_since(room_id.as_deref(), &user_id, since) + .await?; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -59,7 +60,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .get(room_id.as_deref(), &user_id, kind)?; + .get(room_id.as_deref(), &user_id, kind) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 683c228f..4b97ef4e 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -29,7 +29,9 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> let results = services .appservice .db - .get_registration(appservice_id.as_ref()); + .get_registration(appservice_id.as_ref()) + .await; + let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -38,7 +40,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> }, AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services.appservice.all(); + let results = services.appservice.all().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 5f271c2c..150a213c 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -29,7 +29,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - match subcommand { GlobalsCommand::DatabaseVersion => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.database_version(); + let results = services.globals.db.database_version().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -47,7 +47,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - }, GlobalsCommand::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = 
services.updates.last_check_for_updates_id(); + let results = services.updates.last_check_for_updates_id().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -67,7 +67,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - origin, } => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.verify_keys_for(&origin); + let results = services.globals.db.verify_keys_for(&origin).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index 145ecd9b..6189270c 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, UserId}; use crate::Command; @@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.get_presence(&user_id)?; + let results = services.presence.db.get_presence(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -42,7 +43,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) } => { let timer = tokio::time::Instant::now(); let results = services.presence.db.presence_since(since); - let presence_since: Vec<(_, _, _)> = results.collect(); + let presence_since: Vec<(_, _, _)> = results.collect().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/pusher.rs b/src/admin/query/pusher.rs index 637c57b6..a1bd32f9 100644 --- a/src/admin/query/pusher.rs +++ b/src/admin/query/pusher.rs @@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) -> user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.pusher.get_pushers(&user_id)?; + let results = services.pusher.get_pushers(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index 1809e26a..05fac42c 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use crate::Command; @@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) alias, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.resolve_local_alias(&alias); + let results = services.rooms.alias.resolve_local_alias(&alias).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -43,7 +44,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) } => { let timer = tokio::time::Instant::now(); let results = services.rooms.alias.local_aliases_for_room(&room_id); - let aliases: Vec<_> = results.collect(); + let aliases: Vec<_> = results.collect().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -52,8 +53,12 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) }, RoomAliasCommand::AllLocalAliases => { let timer = tokio::time::Instant::now(); - let 
results = services.rooms.alias.all_local_aliases(); - let aliases: Vec<_> = results.collect(); + let aliases = services + .rooms + .alias + .all_local_aliases() + .collect::>() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index 4215cf8d..e32517fb 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; use crate::Command; @@ -86,7 +87,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let result = services.rooms.state_cache.server_in_room(&server, &room_id); + let result = services + .rooms + .state_cache + .server_in_room(&server, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -97,7 +102,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_servers(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_servers(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -108,7 +119,13 @@ pub(super) async fn process( server, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.server_rooms(&server).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .server_rooms(&server) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -119,7 +136,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_members(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_members(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -134,7 +157,9 @@ pub(super) async fn process( .rooms .state_cache .local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -149,7 +174,9 @@ pub(super) async fn process( .rooms .state_cache .active_local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -160,7 +187,7 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_joined_count(&room_id); + let results = services.rooms.state_cache.room_joined_count(&room_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -171,7 +198,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_invited_count(&room_id); + let results = services + .rooms + .state_cache + .room_invited_count(&room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -182,11 +213,13 @@ pub(super) async fn process( room_id, } => { let timer = 
tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_useroncejoined(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -197,11 +230,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_members_invited(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -216,7 +251,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_invite_count(&room_id, &user_id); + .get_invite_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -231,7 +267,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_left_count(&room_id, &user_id); + .get_left_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -242,7 +279,13 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_joined(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_joined(&user_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -253,7 +296,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_invited(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_invited(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -264,7 +312,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_left(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_left(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -276,7 +329,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.invite_state(&user_id, &room_id); + let results = services + .rooms + .state_cache + .invite_state(&user_id, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs index 6d54bddf..eaab1f5e 100644 --- a/src/admin/query/sending.rs +++ b/src/admin/query/sending.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; use service::sending::Destination; @@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - SendingCommand::ActiveRequests => { let timer = tokio::time::Instant::now(); let results = services.sending.db.active_requests(); - let active_requests: Result> = results.collect(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -133,7 
+134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let queued_requests = results.collect::>>(); + let queued_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let active_requests = results.collect::>>(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - server_name, } => { let timer = tokio::time::Instant::now(); - let results = services.sending.db.get_latest_educount(&server_name); + let results = services.sending.db.get_latest_educount(&server_name).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/users.rs b/src/admin/query/users.rs index fee12fbf..0792e484 100644 --- a/src/admin/query/users.rs +++ b/src/admin/query/users.rs @@ -1,29 +1,344 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::stream::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId}; -use crate::Command; +use crate::{admin_command, admin_command_dispatch}; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/users.rs pub(crate) enum UsersCommand { - Iter, + CountUsers, + + IterUsers, + + PasswordHash { + user_id: OwnedUserId, + }, + + ListDevices { + user_id: OwnedUserId, + }, + + ListDevicesMetadata { + user_id: OwnedUserId, + }, + + GetDeviceMetadata { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDevicesVersion { + user_id: OwnedUserId, + }, + + CountOneTimeKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDeviceKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetUserSigningKey { + user_id: OwnedUserId, + }, + + GetMasterKey { + user_id: OwnedUserId, + }, + + GetToDeviceEvents { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetLatestBackup { + user_id: OwnedUserId, + }, + + GetLatestBackupVersion { + user_id: OwnedUserId, + }, + + GetBackupAlgorithm { + user_id: OwnedUserId, + version: String, + }, + + GetAllBackups { + user_id: OwnedUserId, + version: String, + }, + + GetRoomBackups { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + }, + + GetBackupSession { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + session_id: String, + }, } -/// All the getters and iterators in key_value/users.rs -pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result { - let services = context.services; +#[admin_command] +async fn get_backup_session( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_session(&user_id, &version, &room_id, &session_id) + .await; + let query_time = timer.elapsed(); - match subcommand { - UsersCommand::Iter => { - let timer = tokio::time::Instant::now(); - let results = services.users.db.iter(); - let users = results.collect::>(); - let query_time = timer.elapsed(); - - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in 
{query_time:?}:\n\n```rs\n{users:#?}\n```" - ))) - }, - } + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_room_backups( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_room(&user_id, &version, &room_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_all(&user_id, &version).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_backup(&user_id, &version) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_latest_backup_version(&user_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_latest_backup(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn iter_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result: Vec = self.services.users.stream().map(Into::into).collect().await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn count_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.count().await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn password_hash(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.password_hash(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_device_ids(&user_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + + let query_time = timer.elapsed(); + + 
Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_devices_metadata(&user_id) + .collect::>() + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self + .services + .users + .get_device_metadata(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn get_devices_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self.services.users.get_devicelist_version(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .count_one_time_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_device_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.get_user_signing_key(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_master_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_master_key(None, &user_id, &|_| true) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_to_device_events( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_to_device_events(&user_id, &device_id) + .collect::>() + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } diff --git a/src/admin/room/alias.rs b/src/admin/room/alias.rs index 415b8a08..34b6c42e 100644 --- a/src/admin/room/alias.rs +++ b/src/admin/room/alias.rs @@ 
-2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::{escape_html, Command}; @@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> force, room_id, .. - } => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services + } => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) { + (true, Ok(id)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> ))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( + (false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!( "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" ))), - (_, Ok(None)) => match services + (_, Err(_)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), }, RoomAliasCommand::Remove { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => match services + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => match services .rooms .alias .remove_alias(&room_alias, server_user) @@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::Which { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::List { .. 
@@ -125,63 +123,59 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> .rooms .alias .local_aliases_for_room(&room_id) - .collect::<Result<Vec<_>, _>>(); - match aliases { - Ok(aliases) => { - let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "- {alias}").expect("should be able to write to string buffer"); - output - }); + .map(Into::into) + .collect::<Vec<OwnedRoomAliasId>>() + .await; - let html_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "<li>{}</li>", escape_html(alias.as_ref())) - .expect("should be able to write to string buffer"); - output - }); + let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "- {alias}").expect("should be able to write to string buffer"); + output + }); - let plain = format!("Aliases for {room_id}:\n{plain_list}"); - let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>"); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))), - } + let html_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "<li>{}</li>", escape_html(alias.as_ref())) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases for {room_id}:\n{plain_list}"); + let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>"); + Ok(RoomMessageEventContent::text_html(plain, html)) } else { let aliases = services .rooms .alias .all_local_aliases() - .collect::<Result<Vec<_>, _>>(); - match aliases { - Ok(aliases) => { - let server_name = services.globals.server_name(); - let plain_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!(output, "- `{alias}` -> #{id}:{server_name}") - .expect("should be able to write to string buffer"); - output - }); + .map(|(room_id, localpart)| (room_id.into(), localpart.into())) + .collect::<Vec<(OwnedRoomId, String)>>() + .await; - let html_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!( - output, - "<li>{} -> #{}:{}</li>", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .expect("should be able to write to string buffer"); - output - }); + let server_name = services.globals.server_name(); + let plain_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{alias}` -> #{id}:{server_name}") + .expect("should be able to write to string buffer"); + output + }); - let plain = format!("Aliases:\n{plain_list}"); - let html = format!("Aliases:\n<ul>{html_list}</ul>"); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))), - } + let html_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "<li>{} -> #{}:{}</li>", + escape_html(alias.as_ref()), + escape_html(id), + server_name + ) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases:\n{plain_list}"); + let html = format!("Aliases:\n<ul>{html_list}</ul>"); + Ok(RoomMessageEventContent::text_html(plain, html)) } }, } diff --git a/src/admin/room/commands.rs b/src/admin/room/commands.rs index 2adfa7d7..1c90a998 100644 --- a/src/admin/room/commands.rs +++ b/src/admin/room/commands.rs @@ -1,11 +1,12 @@ use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use crate::{admin_command, get_room_info, PAGE_SIZE}; #[admin_command] pub(super) async fn list_rooms( - &self, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, no_details: bool, + &self, page: Option<usize>, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool, ) -> Result<RoomMessageEventContent> { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); @@ -14,37 +15,12 @@ pub(super) async fn list_rooms( .rooms .metadata .iter_ids() - .filter_map(|room_id| { - room_id - .ok() - .filter(|room_id| { - if exclude_disabled - && self - .services - .rooms - .metadata - .is_disabled(room_id) - .unwrap_or(false) - { - return false; - } + //.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await }) + //.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await }) + .then(|room_id| get_room_info(self.services, room_id)) + .collect::<Vec<_>>() + .await; - if exclude_banned - && self - .services - .rooms - .metadata - .is_banned(room_id) - .unwrap_or(false) - { - return false; - } - - true - }) - .map(|room_id| get_room_info(self.services, &room_id)) - }) - .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); @@ -74,3 +50,10 @@ pub(super) async fn list_rooms( Ok(RoomMessageEventContent::notice_markdown(output_plain)) } + +#[admin_command] +pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result<RoomMessageEventContent> { + let result = self.services.rooms.metadata.exists(&room_id).await; + + Ok(RoomMessageEventContent::notice_markdown(format!("{result}"))) +} diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 7bba2eb7..7ccdea6f 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; @@ -31,15 +32,15 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> match command { RoomDirectoryCommand::Publish { room_id, - } => match services.rooms.directory.set_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room published")) }, RoomDirectoryCommand::Unpublish { room_id, - } => match services.rooms.directory.set_not_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_not_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room unpublished")) }, RoomDirectoryCommand::List { page, @@ -50,9 +51,10 @@ pub(super) async fn process(command: RoomDirectoryCommand, context:
&Command<'_> .rooms .directory .public_rooms() - .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(services, &id)) - .collect::<Vec<_>>(); + .then(|room_id| get_room_info(services, room_id)) + .collect::<Vec<_>>() + .await; + rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/room/info.rs b/src/admin/room/info.rs index d17a2924..fc0619e3 100644 --- a/src/admin/room/info.rs +++ b/src/admin/room/info.rs @@ -1,5 +1,6 @@ use clap::Subcommand; -use conduit::Result; +use conduit::{utils::ReadyExt, Result}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{admin_command, admin_command_dispatch}; @@ -32,46 +33,42 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R .rooms .state_accessor .get_name(&room_id) - .ok() - .flatten() - .unwrap_or_else(|| room_id.to_string()); + .await + .unwrap_or_else(|_| room_id.to_string()); - let members = self + let member_info: Vec<_> = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|member| { + .ready_filter(|user_id| { if local_only { - member - .ok() - .filter(|user| self.services.globals.user_is_local(user)) + self.services.globals.user_is_local(user_id) } else { - member.ok() + true } - }); - - let member_info = members - .into_iter() - .map(|user_id| { - ( - user_id.clone(), + }) + .filter_map(|user_id| async move { + let user_id = user_id.to_owned(); + Some(( self.services .users .displayname(&user_id) - .unwrap_or(None) - .unwrap_or_else(|| user_id.to_string()), - ) + .await + .unwrap_or_else(|_| user_id.to_string()), + user_id, + )) }) - .collect::<Vec<_>>(); + .collect() + .await; let output_plain = format!( "{} Members in Room \"{}\":\n```\n{}\n```", member_info.len(), room_name, member_info - .iter() - .map(|(mxid, displayname)| format!("{mxid} | {displayname}")) + .into_iter() + .map(|(displayname, mxid)| format!("{mxid} | {displayname}")) .collect::<Vec<_>>() .join("\n") ); @@ -81,11 +78,12 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R #[admin_command] async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let Some(room_topic) = self + let Ok(room_topic) = self .services .rooms .state_accessor - .get_room_topic(&room_id)?
+ .get_room_topic(&room_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); }; diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 64d2af45..8c6cbeaa 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -6,6 +6,7 @@ mod moderation; use clap::Subcommand; use conduit::Result; +use ruma::OwnedRoomId; use self::{ alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, @@ -49,4 +50,9 @@ pub(super) enum RoomCommand { #[command(subcommand)] /// - Manage the room directory Directory(RoomDirectoryCommand), + + /// - Check if we know about a room + Exists { + room_id: OwnedRoomId, + }, } diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 70d8486b..9a772da4 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -1,6 +1,11 @@ use api::client::leave_room; use clap::Subcommand; -use conduit::{debug, error, info, warn, Result}; +use conduit::{ + debug, error, info, + utils::{IterStream, ReadyExt}, + warn, Result, +}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use crate::{admin_command, admin_command_dispatch, get_room_info}; @@ -76,7 +81,7 @@ async fn ban_room( let admin_room_alias = &self.services.globals.admin_alias; - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -95,7 +100,7 @@ async fn ban_room( debug!("Room specified is a room ID, banning room ID"); - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else if room.is_room_alias_id() { @@ -114,7 +119,13 @@ async fn ban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? 
{ + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -138,7 +149,7 @@ async fn ban_room( } }; - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else { @@ -150,56 +161,40 @@ async fn ban_room( debug!("Making all users leave the room {}", &room); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - self.services.globals.user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would - // fail auth check) - && (self.services.globals.user_is_local(local_user) - // since this is a force operation, assume user is an admin - // if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, &room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -214,12 +209,14 @@ async fn ban_room( } // remove any local aliases, ignore errors - for ref local_alias in self + for local_alias in &self .services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await { _ = self .services .rooms .alias .remove_alias(local_alias, &self.services.globals.server_user) .await; } // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); return Ok(RoomMessageEventContent::text_plain( "Room banned, removed all our local users, and disabled
incoming federation with room.", )); @@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; @@ -300,43 +297,48 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu if room_alias_or_id.is_room_alias_id() { match RoomAliasId::parse(room_alias_or_id) { Ok(room_alias) => { - let room_id = - if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room \ - ID over federation" - ); + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { + room_id + } else { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch room ID \ + over federation" + ); - match self - .services - .rooms - .alias - .resolve_alias(&room_alias, None) - .await - { - Ok((room_id, servers)) => { - debug!( - ?room_id, - ?servers, - "Got federation response fetching room ID for {room}", - ); - room_id - }, - Err(e) => { - // don't fail if force blocking - if force { - warn!("Failed to resolve room alias {room} to a room ID: {e}"); - continue; - } + match self + .services + .rooms + .alias + .resolve_alias(&room_alias, None) + .await + { + Ok((room_id, servers)) => { + debug!( + ?room_id, + ?servers, + "Got federation response fetching room ID for {room}", + ); + room_id + }, + Err(e) => { + // don't fail if force blocking + if force { + warn!("Failed to resolve room alias {room} to a room ID: {e}"); + continue; + } - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; room_ids.push(room_id); }, @@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } for room_id in room_ids { - if self - .services - .rooms - .metadata - .ban_room(&room_id, true) - .is_ok() - { - debug!("Banned {room_id} successfully"); - room_ban_count = room_ban_count.saturating_add(1); - } + self.services.rooms.metadata.ban_room(&room_id, true); + + debug!("Banned {room_id} successfully"); + room_ban_count = room_ban_count.saturating_add(1); debug!("Making all users leave the room {}", &room_id); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - // since this is a force operation, assume user is an - // admin if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let 
Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { - debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + + debug!("Attempting leave for user {local_user} in room {room_id}"); + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( - "Error attempting to make local user {} leave room {} during bulk room banning: {}", - &local_user, &room_id, e + "Error attempting to make local user {local_user} leave room {room_id} during bulk room \ banning: {e}", ); + return Ok(RoomMessageEventContent::text_plain(format!( "Error attempting to make local user {} leave room {} during room banning (room is still \ banned but not removing any more users and not banning any more rooms): {}\nIf you would \ @@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } // remove any local aliases, ignore errors - for ref local_alias in self - .services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) - { - _ = self - .services - .rooms - .alias - .remove_alias(local_alias, &self.services.globals.server_user) - .await; - } + self.services .rooms .alias .local_aliases_for_room(&room_id) + .map(ToOwned::to_owned) + .for_each(|local_alias| async move { + self.services + .rooms + .alias + .remove_alias(&local_alias, &self.services.globals.server_user) + .await + .ok(); + }) + .await; // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); } } @@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> debug!("Room specified is a room ID, unbanning room ID"); - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else if room.is_room_alias_id() { @@ -522,7 +502,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)?
{ + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> } }; - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else { @@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> }; if enable_federation { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); } @@ -569,45 +555,42 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> #[admin_command] async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventContent> { - let rooms = self + let room_ids = self .services .rooms .metadata .list_banned_rooms() - .collect::<Result<Vec<_>, _>>(); + .map(Into::into) + .collect::<Vec<OwnedRoomId>>() + .await; - match rooms { - Ok(room_ids) => { - if room_ids.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); - } - - let mut rooms = room_ids - .into_iter() - .map(|room_id| get_room_info(self.services, &room_id)) - .collect::<Vec<_>>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let output_plain = format!( - "Rooms Banned ({}):\n```\n{}\n```", - rooms.len(), - rooms - .iter() - .map(|(id, members, name)| if no_details { - format!("{id}") - } else { - format!("{id}\tMembers: {members}\tName: {name}") - }) - .collect::<Vec<_>>() - .join("\n") - ); - - Ok(RoomMessageEventContent::notice_markdown(output_plain)) - }, - Err(e) => { - error!("Failed to list banned rooms: {e}"); - Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}"))) - }, + if room_ids.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); } + + let mut rooms = room_ids + .iter() + .stream() + .then(|room_id| get_room_info(self.services, room_id)) + .collect::<Vec<_>>() + .await; + + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let output_plain = format!( + "Rooms Banned ({}):\n```\n{}\n```", + rooms.len(), + rooms + .iter() + .map(|(id, members, name)| if no_details { + format!("{id}") + } else { + format!("{id}\tMembers: {members}\tName: {name}") + }) + .collect::<Vec<_>>() + .join("\n") + ); + + Ok(RoomMessageEventContent::notice_markdown(output_plain)) } diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 20691f1a..1b086856 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,9 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; -use conduit::{error, info, utils, warn, PduBuilder, Result}; +use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result}; +use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; +use futures::StreamExt; use ruma::{ events::{ room::{ @@ -25,16 +27,19 @@ const AUTO_GEN_PASSWORD_LENGTH: usize = 25; #[admin_command] pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> { - match self.services.users.list_local_users() { - Ok(users) => { - let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += users.join("\n").as_str(); - plain_msg += "\n```"; + let users = self + .services + .users + .list_local_users() + .map(ToString::to_string)
+ .collect::<Vec<_>>() + .await; - Ok(RoomMessageEventContent::notice_markdown(plain_msg)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())), - } + let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); + plain_msg += users.join("\n").as_str(); + plain_msg += "\n```"; + + Ok(RoomMessageEventContent::notice_markdown(plain_msg)) } #[admin_command] @@ -42,7 +47,7 @@ pub(super) async fn create_user(&self, username: String, password: Option )); } - self.services.users.deactivate_account(&user_id)?; + self.services.users.deactivate_account(&user_id).await?; if !no_leave_rooms { self.services @@ -175,17 +184,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> .send_message(RoomMessageEventContent::text_plain(format!( "Making {user_id} leave all rooms after deactivation..." ))) - .await; + .await + .ok(); let all_joined_rooms: Vec<OwnedRoomId> = self .services .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms).await?; + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -238,15 +252,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(self.services, username) { + match parse_active_local_user_id(self.services, username).await { Ok(user_id) => { - if self.services.users.is_admin(&user_id)?
&& !force { + if self.services.users.is_admin(&user_id).await && !force { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is an admin and --force is not set, skipping over" ))) - .await; + .await + .ok(); admins.push(username); continue; } @@ -258,7 +273,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is the server service account, skipping over" ))) - .await; + .await + .ok(); continue; } @@ -270,7 +286,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is not a valid username, skipping over: {e}" ))) - .await; + .await + .ok(); continue; }, } @@ -279,7 +296,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut deactivation_count: usize = 0; for user_id in user_ids { - match self.services.users.deactivate_account(&user_id) { + match self.services.users.deactivate_account(&user_id).await { Ok(()) => { deactivation_count = deactivation_count.saturating_add(1); if !no_leave_rooms { @@ -289,16 +306,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + .map(Into::into) + .collect() + .await; + + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms) + .await + .ok(); + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms) + .await + .ok(); + leave_all_rooms(self.services, &user_id).await; } }, Err(e) => { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) - .await; + .await + .ok(); }, } } @@ -326,9 +353,9 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result(&room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -417,9 +443,9 @@ pub(super) async fn force_demote( .services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? 
- .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if !user_can_demote_self { return Ok(RoomMessageEventContent::notice_markdown( @@ -473,15 +499,16 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -494,12 +521,15 @@ pub(super) async fn put_room_tag( .tags .insert(tag.clone().into(), TagInfo::new()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id} with tag {tag}" @@ -510,15 +540,16 @@ pub(super) async fn delete_room_tag( &self, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -528,12 +559,15 @@ pub(super) async fn delete_room_tag( tags_event.content.tags.remove(&tag.clone().into()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}" @@ -542,15 +576,16 @@ pub(super) async fn delete_room_tag( #[admin_command] pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -566,11 +601,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) #[admin_command] pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { - let Some(event) = self + let Ok(event) = self .services .rooms
- .timeline - .get_non_outlier_pdu(&event_id)? + .get_non_outlier_pdu(&event_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database.")); }; diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 8d3d15ae..ba98bbea 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String { .replace('>', "&gt;") } -pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { +pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) { ( - id.into(), + room_id.into(), services .rooms .state_cache - .room_joined_count(id) - .ok() - .flatten() + .room_joined_count(room_id) + .await .unwrap_or(0), services .rooms .state_accessor - .get_name(id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), + .get_name(room_id) + .await + .unwrap_or_else(|_| room_id.to_string()), ) } @@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result< } /// Parses user ID that is an active (not guest or deactivated) local user -pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { +pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { let user_id = parse_local_user_id(services, user_id)?; - if !services.users.exists(&user_id)? { + if !services.users.exists(&user_id).await { return Err!("User {user_id:?} does not exist on this server."); } - if services.users.is_deactivated(&user_id)? { + if services.users.is_deactivated(&user_id).await? { return Err!("User {user_id:?} is deactivated."); } diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 2b89c3e8..6e37cb40 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -45,7 +45,7 @@ conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hmac.workspace = true http.workspace = true http-body-util.workspace = true diff --git a/src/api/client/account.rs b/src/api/client/account.rs index cee86f80..63d02f8f 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result}; +use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result}; +use futures::{FutureExt, StreamExt}; use register::RegistrationKind; use ruma::{ api::client::{ @@ -55,7 +56,7 @@ pub(crate) async fn get_register_available_route( .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough - if services.users.exists(&user_id)? { + if services.users.exists(&user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -125,7 +126,7 @@ pub(crate) async fn register_route( // forbid guests from registering if there is not a real admin user yet. give // generic user error. - if is_guest && services.users.count()? < 2 { + if is_guest && services.users.count().await < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ registration.
Guest's initial device name: {:?}", @@ -142,7 +143,7 @@ pub(crate) async fn register_route( .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services.users.exists(&proposed_user_id)? { + if services.users.exists(&proposed_user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -162,7 +163,7 @@ pub(crate) async fn register_route( services.globals.server_name(), ) .unwrap(); - if !services.users.exists(&proposed_user_id)? { + if !services.users.exists(&proposed_user_id).await { break proposed_user_id; } }, @@ -210,12 +211,15 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services.uiaa.try_auth( - &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), - "".into(), - auth, - &uiaainfo, - )?; + let (worked, uiaainfo) = services + .uiaa + .try_auth( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), + "".into(), + auth, + &uiaainfo, + ) + .await?; if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -227,7 +231,7 @@ pub(crate) async fn register_route( "".into(), &uiaainfo, &json, - )?; + ); return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -255,21 +259,23 @@ pub(crate) async fn register_route( services .users - .set_displayname(&user_id, Some(displayname.clone())) - .await?; + .set_displayname(&user_id, Some(displayname.clone())); // Initial account data - services.account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + ) + .await?; // Inhibit login does not work for guests if !is_guest && body.inhibit_login { @@ -294,13 +300,16 @@ pub(crate) async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; debug_info!(%user_id, %device_id, "User account was created"); @@ -318,7 +327,8 @@ pub(crate) async fn register_route( "New user \"{user_id}\" registered on this server from IP {client} and device display name \ \"{device_display_name}\"" ))) - .await; + .await + .ok(); } } else { info!("New user \"{user_id}\" registered on this server."); @@ -329,7 +339,8 @@ pub(crate) async fn register_route( .send_message(RoomMessageEventContent::notice_plain(format!( "New user \"{user_id}\" registered on this server from IP {client}" ))) - .await; + .await + .ok(); } } } @@ -346,7 +357,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with 
device display name \"{device_display_name}\" registered on \ this server from IP {client}" ))) - .await; + .await + .ok(); } } else { #[allow(clippy::collapsible_else_if)] @@ -357,7 +369,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with no device display name registered on this server from IP \ {client}", ))) - .await; + .await + .ok(); } } } @@ -365,10 +378,15 @@ pub(crate) async fn register_route( // If this is the first real user, grant them admin privileges except for guest // users Note: the server user, @conduit:servername, is generated first if !is_guest { - if let Some(admin_room) = services.admin.get_admin_room()? { - if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { + if let Ok(admin_room) = services.admin.get_admin_room().await { + if services + .rooms + .state_cache + .room_joined_count(&admin_room) + .await + .is_ok_and(is_equal_to!(1)) + { services.admin.make_user_admin(&user_id).await?; - warn!("Granting {user_id} admin privileges as the first user"); } } @@ -382,7 +400,8 @@ pub(crate) async fn register_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room)? + .server_in_room(services.globals.server_name(), room) + .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; @@ -398,6 +417,7 @@ pub(crate) async fn register_route( None, &body.appservice_info, ) + .boxed() .await { // don't return this error so we don't fail registrations @@ -461,16 +481,20 @@ pub(crate) async fn change_password_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } - // Success! + + // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -482,14 +506,12 @@ pub(crate) async fn change_password_route( if body.logout_devices { // Logout all devices except the current one - for id in services + services .users .all_device_ids(sender_user) - .filter_map(Result::ok) - .filter(|id| id != sender_device) - { - services.users.remove_device(sender_user, &id)?; - } + .ready_filter(|id| id != sender_device) + .for_each(|id| services.users.remove_device(sender_user, id)) + .await; } info!("User {sender_user} changed their password."); @@ -500,7 +522,8 @@ pub(crate) async fn change_password_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} changed their password." ))) - .await; + .await + .ok(); } Ok(change_password::v3::Response {}) @@ -520,7 +543,7 @@ pub(crate) async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services.users.is_deactivated(sender_user).await? 
&& body.appservice_info.is_none(), }) } @@ -561,7 +584,9 @@ pub(crate) async fn deactivate_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -570,7 +595,8 @@ uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } @@ -581,10 +607,14 @@ .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(&services, sender_user, all_joined_rooms).await?; + super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?; + super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?; + + full_user_deactivate(&services, sender_user, &all_joined_rooms).await?; info!("User {sender_user} deactivated their account."); @@ -594,7 +624,8 @@ .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} deactivated their account." ))) - .await; + .await + .ok(); } Ok(deactivate::v3::Response { @@ -674,34 +705,27 @@ pub(crate) async fn check_registration_token_validity( /// - Removing all profile data /// - Leaving all rooms (and forgets all of them) pub async fn full_user_deactivate( - services: &Services, user_id: &UserId, all_joined_rooms: Vec<OwnedRoomId>, + services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - services.users.deactivate_account(user_id)?; + services.users.deactivate_account(user_id).await?; + super::update_displayname(services, user_id, None, all_joined_rooms).await?; + super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?; - super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?; - super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?; - - let all_profile_keys = services .users .all_profile_keys(user_id) - .filter_map(Result::ok); - - for (profile_key, _profile_value) in all_profile_keys { - if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) { - warn!("Failed removing {user_id} profile key {profile_key}: {e}"); - } - } + services + .users + .all_profile_keys(user_id) + .ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None)) + .await; for room_id in all_joined_rooms { - let state_lock = services.rooms.state.mutex.lock(&room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let room_power_levels = services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? - .as_ref() - .and_then(|event| serde_json::from_str(event.content.get()).ok()?) - .and_then(|content: RoomPowerLevelsEventContent| content.into()); + .room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -710,9 +734,9 @@ pub async fn full_user_deactivate( }) || services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")?
- .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if user_can_demote_self { let mut power_levels_content = room_power_levels.unwrap_or_default(); @@ -732,7 +756,7 @@ pub async fn full_user_deactivate( timestamp: None, }, user_id, - &room_id, + room_id, &state_lock, ) .await diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 12d6352c..2399a355 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,11 +1,9 @@ use axum::extract::State; -use conduit::{debug, Error, Result}; +use conduit::{debug, Err, Result}; +use futures::StreamExt; use rand::seq::SliceRandom; use ruma::{ - api::client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, + api::client::alias::{create_alias, delete_alias, get_alias}, OwnedServerName, RoomAliasId, RoomId, }; use service::Services; @@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route( .forbidden_alias_names() .is_match(body.room_alias.alias()) { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); + return Err!(Request(Forbidden("Room alias is forbidden."))); } if services .rooms .alias - .resolve_local_alias(&body.room_alias)? - .is_some() + .resolve_local_alias(&body.room_alias) + .await + .is_ok() { - return Err(Error::Conflict("Alias already exists.")); + return Err!(Conflict("Alias already exists.")); } services @@ -95,16 +94,16 @@ pub(crate) async fn get_alias_route( .resolve_alias(&room_alias, servers.as_ref()) .await else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); + return Err!(Request(NotFound("Room with alias not found."))); }; - let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); + let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await; debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } -fn room_available_servers( +async fn room_available_servers( services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, ) -> Vec<OwnedServerName> { // find active servers in room state cache to suggest .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; // push any servers we want in the list already (e.g.
responded remote alias // servers, room alias server itself) diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index 4ead8777..d52da80a 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -1,18 +1,16 @@ use axum::extract::State; +use conduit::{err, Err}; use ruma::{ - api::client::{ - backup::{ - add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, - delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, - get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, - get_latest_backup_info, update_backup_version, - }, - error::ErrorKind, + api::client::backup::{ + add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, + delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, + get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, }, UInt, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/room_keys/version` /// @@ -40,7 +38,8 @@ pub(crate) async fn update_backup_version_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + .update_backup(sender_user, &body.version, &body.algorithm) + .await?; Ok(update_backup_version::v3::Response {}) } @@ -55,14 +54,15 @@ pub(crate) async fn get_latest_backup_info_route( let (version, algorithm) = services .key_backups - .get_latest_backup(sender_user)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_latest_backup(sender_user) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) + count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await) .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &version)?, + etag: services.key_backups.get_etag(sender_user, &version).await, version, }) } @@ -76,18 +76,21 @@ pub(crate) async fn get_backup_info_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let algorithm = services .key_backups - .get_backup(sender_user, &body.version)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_backup(sender_user, &body.version) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?; Ok(get_backup_info::v3::Response { algorithm, - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, version: body.version.clone(), }) } @@ -105,7 +108,8 @@ pub(crate) async fn delete_backup_version_route( services .key_backups - .delete_backup(sender_user, &body.version)?; + .delete_backup(sender_user, &body.version) + .await; Ok(delete_backup_version::v3::Response {}) } @@ -123,34 +127,36 @@ pub(crate) async fn add_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { services .key_backups - .add_key(sender_user, &body.version, room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, room_id, session_id, key_data) + .await?; } } Ok(add_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -167,32 +173,34 @@ pub(crate) async fn add_backup_keys_for_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." 
+ ))); } for (session_id, key_data) in &body.sessions { services .key_backups - .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, &body.room_id, session_id, key_data) + .await?; } Ok(add_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -209,30 +217,32 @@ pub(crate) async fn add_backup_keys_for_session_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } services .key_backups - .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; + .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data) + .await?; Ok(add_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -244,7 +254,10 @@ pub(crate) async fn get_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services.key_backups.get_all(sender_user, &body.version)?; + let rooms = services + .key_backups + .get_all(sender_user, &body.version) + .await; Ok(get_backup_keys::v3::Response { rooms, @@ -261,7 +274,8 @@ pub(crate) async fn get_backup_keys_for_room_route( let sessions = services .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + .get_room(sender_user, &body.version, &body.room_id) + .await; Ok(get_backup_keys_for_room::v3::Response { sessions, @@ -278,8 +292,9 @@ pub(crate) async fn get_backup_keys_for_session_route( let key_data = services .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; + .get_session(sender_user, &body.version, &body.room_id, &body.session_id) + .await + .map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?; Ok(get_backup_keys_for_session::v3::Response { key_data, @@ -296,16 +311,19 @@ pub(crate) async fn delete_backup_keys_route( services .key_backups - .delete_all_keys(sender_user, &body.version)?; + .delete_all_keys(sender_user, &body.version) + .await; Ok(delete_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -319,16 +337,19 @@ pub(crate) async fn delete_backup_keys_for_room_route( services .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + .delete_room_keys(sender_user, &body.version, &body.room_id) + .await; Ok(delete_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -342,15 +363,18 @@ pub(crate) async fn delete_backup_keys_for_session_route( services .key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; + .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id) + .await; Ok(delete_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 61cc97ff..33b85136 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::err; use ruma::{ api::client::{ config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, @@ -25,7 +26,8 @@ pub(crate) async fn set_global_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_global_account_data::v3::Response {}) } @@ -42,7 +44,8 @@ pub(crate) async fn set_room_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_room_account_data::v3::Response {}) } @@ -57,8 +60,9 @@ pub(crate) async fn get_global_account_data_route( let event: Box = services .account_data - .get(None, sender_user, body.event_type.to_string().into())? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(None, sender_user, body.event_type.to_string().into()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -79,8 +83,9 @@ pub(crate) async fn get_room_account_data_route( let event: Box = services .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone())? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(Some(&body.room_id), sender_user, body.event_type.clone()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -91,7 +96,7 @@ pub(crate) async fn get_room_account_data_route( }) } -fn set_account_data( +async fn set_account_data( services: &Services, room_id: Option<&RoomId>, sender_user: &Option, event_type: &str, data: &RawJsonValue, ) -> Result<()> { @@ -100,15 +105,18 @@ fn set_account_data( let data: serde_json::Value = serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - services.account_data.update( - room_id, - sender_user, - event_type.into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services + .account_data + .update( + room_id, + sender_user, + event_type.into(), + &json!({ + "type": event_type, + "content": data, + }), + ) + .await?; Ok(()) } diff --git a/src/api/client/context.rs b/src/api/client/context.rs index f223d488..cc49b763 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,13 +1,14 @@ use std::collections::HashSet; use axum::extract::State; +use conduit::{err, error, Err}; +use futures::StreamExt; use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, }; -use tracing::error; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// @@ -35,34 +36,33 @@ pub(crate) async fn get_context_route( let base_token = services .rooms .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event id not found."))))?; let base_event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event not found."))))?; - let room_id = base_event.room_id.clone(); + let room_id = &base_event.room_id; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? + .user_can_see_event(sender_user, room_id, &body.event_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); + return Err!(Request(Forbidden("You don't have permission to view this event."))); } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &base_event.sender, - )? 
|| lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(base_event.sender.as_str().to_owned()); } @@ -75,25 +75,26 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, &room_id, base_token)? + .pdus_until(sender_user, room_id, base_token) + .await? .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_before { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -111,25 +112,26 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, &room_id, base_token)? + .pdus_after(sender_user, room_id, base_token) + .await? .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_after { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -142,12 +144,14 @@ pub(crate) async fn get_context_route( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? + ) + .await .map_or( services .rooms .state - .get_room_shortstatehash(&room_id)? + .get_room_shortstatehash(room_id) + .await .expect("All rooms have state"), |hash| hash, ); @@ -156,7 +160,8 @@ pub(crate) async fn get_context_route( .rooms .state_accessor .state_full_ids(shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State not found: {e}")))?; let end_token = events_after .last() @@ -173,18 +178,19 @@ pub(crate) async fn get_context_route( let (event_type, state_key) = services .rooms .short - .get_statekey_from_short(shortstatekey)?; + .get_statekey_from_short(shortstatekey) + .await?; if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
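// Aside: a self-contained sketch, using plain `futures` types, of the
// iterator-to-Stream conversion in the events_before/events_after hunks
// above (Pdu and the visibility check are invented for the example). The
// sync `.filter_map(Result::ok).filter(..)` chain becomes an async
// `filter_map` whose closure awaits the visibility check, and the final
// `.collect()` is itself awaited.
use futures::{stream, StreamExt};

struct Pdu {
    visible: bool,
}

// Stand-in for `state_accessor.user_can_see_event(..)`.
async fn user_can_see(pdu: &Pdu) -> bool {
    pdu.visible
}

async fn visible_events(pdus: Vec<(u64, Pdu)>, limit: usize) -> Vec<(u64, Pdu)> {
    stream::iter(pdus)
        .take(limit)
        .filter_map(|(count, pdu)| async move {
            // bool::then_some keeps the item only when the check passes.
            user_can_see(&pdu).await.then_some((count, pdu))
        })
        .collect()
        .await
}

fn main() {
    let pdus = vec![(1, Pdu { visible: true }), (2, Pdu { visible: false })];
    let seen = futures::executor::block_on(visible_events(pdus, 10));
    assert_eq!(seen.len(), 1);
}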
else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; diff --git a/src/api/client/device.rs b/src/api/client/device.rs index bad7f284..93eaa393 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::{err, Err}; +use futures::StreamExt; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -19,8 +21,8 @@ pub(crate) async fn get_devices_route( let devices: Vec = services .users .all_devices_metadata(sender_user) - .filter_map(Result::ok) // Filter out buggy devices - .collect(); + .collect() + .await; Ok(get_devices::v3::Response { devices, @@ -37,8 +39,9 @@ pub(crate) async fn get_device_route( let device = services .users - .get_device_metadata(sender_user, &body.body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; Ok(get_device::v3::Response { device, @@ -55,14 +58,16 @@ pub(crate) async fn update_device_route( let mut device = services .users - .get_device_metadata(sender_user, &body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; device.display_name.clone_from(&body.display_name); services .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + .update_device_metadata(sender_user, &body.device_id, &device) + .await?; Ok(update_device::v3::Response {}) } @@ -97,22 +102,28 @@ pub(crate) async fn delete_device_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { - return Err(Error::Uiaa(uiaainfo)); + return Err!(Uiaa(uiaainfo)); } // Success! 
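// Aside: the simplest shape of this conversion, sketched with stand-in
// types (all_devices_metadata here is not the real users service): a getter
// that used to yield Iterator<Item = Result<T>> now yields Stream<Item = T>,
// so the "filter out buggy entries" step disappears and the route just
// collects asynchronously.
use futures::{stream, Stream, StreamExt};

fn all_devices_metadata() -> impl Stream<Item = String> {
    // Stand-in for a database-backed stream of device records.
    stream::iter(vec!["DEVICEA".to_owned(), "DEVICEB".to_owned()])
}

async fn get_devices_route() -> Vec<String> {
    all_devices_metadata().collect().await
}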
} else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); + .create(sender_user, sender_device, &uiaainfo, &json); + + return Err!(Uiaa(uiaainfo)); } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + return Err!(Request(NotJson("Not json."))); } - services.users.remove_device(sender_user, &body.device_id)?; + services + .users + .remove_device(sender_user, &body.device_id) + .await; Ok(delete_device::v3::Response {}) } @@ -149,7 +160,9 @@ pub(crate) async fn delete_devices_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -158,14 +171,15 @@ pub(crate) async fn delete_devices_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } for device_id in &body.devices { - services.users.remove_device(sender_user, device_id)?; + services.users.remove_device(sender_user, device_id).await; } Ok(delete_devices::v3::Response {}) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 602f876a..ea499545 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,6 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{err, info, warn, Err, Error, Result}; +use conduit::{info, warn, Err, Error, Result}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ api::{ client::{ @@ -18,7 +19,7 @@ use ruma::{ }, StateEventType, }, - uint, RoomId, ServerName, UInt, UserId, + uint, OwnedRoomId, RoomId, ServerName, UInt, UserId, }; use service::Services; @@ -119,16 +120,22 @@ pub(crate) async fn set_room_visibility_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() { + if services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && body.appservice_info.is_none() + { return Err!(Request(Forbidden("Guests cannot publish to room directories"))); } - if !user_can_publish_room(&services, sender_user, &body.room_id)? { + if !user_can_publish_room(&services, sender_user, &body.room_id).await? { return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", @@ -138,7 +145,7 @@ pub(crate) async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { if services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? 
+ && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -164,7 +171,7 @@ pub(crate) async fn set_room_visibility_route( )); } - services.rooms.directory.set_public(&body.room_id)?; + services.rooms.directory.set_public(&body.room_id); if services.globals.config.admin_room_notices { services @@ -174,7 +181,7 @@ pub(crate) async fn set_room_visibility_route( } info!("{sender_user} made {0} public to the room directory", body.room_id); }, - room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id), _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -192,13 +199,13 @@ pub(crate) async fn set_room_visibility_route( pub(crate) async fn get_room_visibility_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } Ok(get_room_visibility::v3::Response { - visibility: if services.rooms.directory.is_public_room(&body.room_id)? { + visibility: if services.rooms.directory.is_public_room(&body.room_id).await { room::Visibility::Public } else { room::Visibility::Private @@ -257,101 +264,41 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = services + let mut all_rooms: Vec = services .rooms .directory .public_rooms() - .map(|room_id| { - let room_id = room_id?; - - let chunk = PublicRoomsChunk { - canonical_alias: services - .rooms - .state_accessor - .get_canonical_alias(&room_id)?, - name: services.rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services - .rooms - .state_cache - .room_joined_count(&room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - topic: services - .rooms - .state_accessor - .get_room_topic(&room_id) - .unwrap_or(None), - world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?, - guest_can_join: services - .rooms - .state_accessor - .guest_can_join(&room_id)?, - avatar_url: services - .rooms - .state_accessor - .get_avatar(&room_id)? - .into_option() - .unwrap_or_default() - .url, - join_rule: services - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, - }) - .map_err(|e| { - err!(Database(error!("Invalid room join rule event in database: {e}"))) - }) - }) - .transpose()? 
- .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services - .rooms - .state_accessor - .get_room_type(&room_id)?, - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { + .map(ToOwned::to_owned) + .then(|room_id| public_rooms_chunk(services, room_id)) + .filter_map(|chunk| async move { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(name) = &chunk.name { if name.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(topic) = &chunk.topic { if topic.to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(canonical_alias) = &chunk.canonical_alias { if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } - false - } else { - // No search term - true + return None; } + + // No search term + Some(chunk) }) // We need to collect all, so we can sort by member count - .collect(); + .collect() + .await; all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); @@ -394,22 +341,23 @@ pub(crate) async fn get_public_rooms_filtered_helper( /// Check whether the user can publish to the room directory via power levels of /// room history visibility event or room creator -fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { - if let Some(event) = services +async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { + if let Ok(event) = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) .map(|content: RoomPowerLevelsEventContent| { RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) }) - } else if let Some(event) = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? 
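// Aside: a compile-checked sketch of the restructuring applied to the
// public-rooms listing (Chunk and the builder are stand-ins for
// PublicRoomsChunk and public_rooms_chunk). The one giant fallible `map`
// closure becomes a dedicated async builder driven by `StreamExt::then`,
// an async `filter_map` applies the search term, and sorting happens after
// the awaited collect.
use futures::{stream, StreamExt};

struct Chunk {
    name: Option<String>,
    num_joined_members: u64,
}

// Stand-in for the new public_rooms_chunk helper: every field lookup awaits,
// and absence becomes None or a default instead of an early return.
async fn build_chunk(room_id: String) -> Chunk {
    Chunk {
        name: Some(room_id),
        num_joined_members: 1,
    }
}

async fn list_rooms(rooms: Vec<String>, query: Option<String>) -> Vec<Chunk> {
    let query = query.map(|q| q.to_lowercase());
    let mut all: Vec<Chunk> = stream::iter(rooms)
        .then(build_chunk)
        .filter_map(|chunk| {
            let query = query.clone();
            async move {
                match query {
                    None => Some(chunk), // no search term
                    Some(q) => chunk
                        .name
                        .as_deref()
                        .is_some_and(|name| name.to_lowercase().contains(&q))
                        .then_some(chunk),
                }
            }
        })
        .collect()
        .await;

    // All rows are needed before sorting by member count.
    all.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
    all
}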
+ } else if let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -419,3 +367,61 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId )); } } + +async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk { + PublicRoomsChunk { + canonical_alias: services + .rooms + .state_accessor + .get_canonical_alias(&room_id) + .await + .ok(), + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), + num_joined_members: services + .rooms + .state_cache + .room_joined_count(&room_id) + .await + .unwrap_or(0) + .try_into() + .expect("joined count overflows ruma UInt"), + topic: services + .rooms + .state_accessor + .get_room_topic(&room_id) + .await + .ok(), + world_readable: services + .rooms + .state_accessor + .is_world_readable(&room_id) + .await, + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + avatar_url: services + .rooms + .state_accessor + .get_avatar(&room_id) + .await + .into_option() + .unwrap_or_default() + .url, + join_rule: services + .rooms + .state_accessor + .room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "") + .map_ok(|c: RoomJoinRulesEventContent| match c.join_rule { + JoinRule::Public => PublicRoomJoinRule::Public, + JoinRule::Knock => PublicRoomJoinRule::Knock, + _ => "invite".into(), + }) + .await + .unwrap_or_default(), + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_id, + } +} diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 8b2690c6..2a8ebb9c 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -1,10 +1,8 @@ use axum::extract::State; -use ruma::api::client::{ - error::ErrorKind, - filter::{create_filter, get_filter}, -}; +use conduit::err; +use ruma::api::client::filter::{create_filter, get_filter}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// @@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? 
else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); - }; - Ok(get_filter::v3::Response::new(filter)) + services + .users + .get_filter(sender_user, &body.filter_id) + .await + .map(get_filter::v3::Response::new) + .map_err(|_| err!(Request(NotFound("Filter not found.")))) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` @@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(create_filter::v3::Response::new( - services.users.create_filter(sender_user, &body.filter)?, - )) + + let filter_id = services.users.create_filter(sender_user, &body.filter); + + Ok(create_filter::v3::Response::new(filter_id)) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index a426364a..abf2a22f 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -4,8 +4,8 @@ use std::{ }; use axum::extract::State; -use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::{ @@ -21,7 +21,10 @@ use ruma::{ use serde_json::json; use super::SESSION_ID_LENGTH; -use crate::{service::Services, Ruma}; +use crate::{ + service::{users::parse_master_key, Services}, + Ruma, +}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -39,7 +42,8 @@ pub(crate) async fn upload_keys_route( for (key_key, key_value) in &body.one_time_keys { services .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; + .add_one_time_key(sender_user, sender_device, key_key, key_value) + .await?; } if let Some(device_keys) = &body.device_keys { @@ -47,19 +51,22 @@ pub(crate) async fn upload_keys_route( // This check is needed to assure that signatures are kept if services .users - .get_device_keys(sender_user, sender_device)? 
- .is_none() + .get_device_keys(sender_user, sender_device) + .await + .is_err() { services .users - .add_device_keys(sender_user, sender_device, device_keys)?; + .add_device_keys(sender_user, sender_device, device_keys) + .await; } } Ok(upload_keys::v3::Response { one_time_key_counts: services .users - .count_one_time_keys(sender_user, sender_device)?, + .count_one_time_keys(sender_user, sender_device) + .await, }) } @@ -120,7 +127,9 @@ pub(crate) async fn upload_signing_keys_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -129,20 +138,24 @@ pub(crate) async fn upload_signing_keys_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } if let Some(master_key) = &body.master_key { - services.users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; + services + .users + .add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, // notify so that other users see the new keys + ) + .await?; } Ok(upload_signing_keys::v3::Response {}) @@ -179,9 +192,11 @@ pub(crate) async fn upload_signatures_route( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .to_owned(), ); + services .users - .sign_key(user_id, key_id, signature, sender_user)?; + .sign_key(user_id, key_id, signature, sender_user) + .await?; } } } @@ -204,56 +219,51 @@ pub(crate) async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); + let from = body + .from + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?; + + let to = body + .to + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?; + device_list_updates.extend( services .users - .keys_changed( - sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(sender_user.as_str(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); - for room_id in services - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(Result::ok) - { + let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed(); + + while let Some(room_id) = rooms_joined.next().await { device_list_updates.extend( services .users - .keys_changed( - room_id.as_ref(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(room_id.as_ref(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + Ok(get_key_changes::v3::Response { changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }) } -pub(crate) async fn get_keys_helper bool + Send>( +pub(crate) async fn 
get_keys_helper( services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap>, allowed_signatures: F, include_display_names: bool, -) -> Result { +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new(); @@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper bool + Send>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in services.users.all_device_ids(user_id) { - let device_id = device_id?; - if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { + let mut devices = services.users.all_device_ids(user_id).boxed(); + + while let Some(device_id) = devices.next().await { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, &device_id)? - .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; - container.insert(device_id, keys); + container.insert(device_id.to_owned(), keys); } } + device_keys.insert(user_id.to_owned(), container); } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, device_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to get keys for nonexistent device.", - ))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; + container.insert(device_id.to_owned(), keys); } + device_keys.insert(user_id.to_owned(), container); } } - if let Some(master_key) = services + if let Ok(master_key) = services .users - .get_master_key(sender_user, user_id, &allowed_signatures)? + .get_master_key(sender_user, user_id, &allowed_signatures) + .await { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = - services - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? + if let Ok(self_signing_key) = services + .users + .get_self_signing_key(sender_user, user_id, &allowed_signatures) + .await { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? 
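// Aside: a small sketch of the `.boxed()` plus `while let` idiom used here
// and in get_key_changes (keys_changed below fakes the real lookup). When
// the loop body needs to await other calls, the stream is pinned on the
// heap with StreamExt::boxed so `next()` can be polled from a plain loop.
use std::collections::HashSet;

use futures::{stream, StreamExt};

// Stand-in for `users.keys_changed(..)`.
async fn keys_changed(scope: &str) -> Vec<String> {
    vec![format!("@user:{scope}")]
}

async fn device_list_updates(rooms: Vec<String>) -> HashSet<String> {
    let mut updates = HashSet::new();

    let mut rooms_joined = stream::iter(rooms).boxed();
    while let Some(room_id) = rooms_joined.next().await {
        updates.extend(keys_changed(&room_id).await);
    }

    updates
}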
{ + if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -386,23 +400,26 @@ pub(crate) async fn get_keys_helper bool + Send>( while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; + let (master_key_id, mut master_key) = parse_master_key(&user, &masterkey)?; - if let Some(our_master_key) = - services - .users - .get_key(&master_key_id, sender_user, &user, &allowed_signatures)? + if let Ok(our_master_key) = services + .users + .get_key(&master_key_id, sender_user, &user, &allowed_signatures) + .await { - let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; + let (_, our_master_key) = parse_master_key(&user, &our_master_key)?; master_key.signatures.extend(our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services.users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, /* Dont notify. A notification would trigger another key request resulting in an - * endless loop */ - )?; + services + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, /* Dont notify. A notification would trigger another key request resulting in an + * endless loop */ + ) + .await?; master_keys.insert(user.clone(), raw); } @@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = services + if let Ok(one_time_keys) = services .users - .take_one_time_key(user_id, device_id, key_algorithm)? + .take_one_time_key(user_id, device_id, key_algorithm) + .await { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 470db669..5a5d436f 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -11,9 +11,10 @@ use conduit::{ debug, debug_error, debug_warn, err, error, info, pdu::{gen_event_id_canonical_json, PduBuilder}, trace, utils, - utils::math::continue_exponential_backoff_secs, + utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{ client::{ @@ -55,9 +56,9 @@ async fn banned_room_check( services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, client_ip: IpAddr, ) -> Result<()> { - if !services.users.is_admin(user_id)? { + if !services.users.is_admin(user_id).await { if let Some(room_id) = room_id { - if services.rooms.metadata.is_banned(room_id)? 
+ if services.rooms.metadata.is_banned(room_id).await || services .globals .config @@ -79,23 +80,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } } else if let Some(server_name) = server_name { if services @@ -119,23 +119,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This remote server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); } } } @@ -172,14 +171,16 @@ pub(crate) async fn join_room_by_id_route( .rooms .state_cache .servers_invite_via(&body.room_id) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; servers.extend( services .rooms .state_cache - .invite_state(sender_user, &body.room_id)? + .invite_state(sender_user, &body.room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -202,6 +203,7 @@ pub(crate) async fn join_room_by_id_route( body.third_party_signed.as_ref(), &body.appservice_info, ) + .boxed() .await } @@ -233,14 +235,17 @@ pub(crate) async fn join_room_by_id_or_alias_route( .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -270,19 +275,23 @@ pub(crate) async fn join_room_by_id_or_alias_route( if let Some(pre_servers) = &mut pre_servers { servers.append(pre_servers); } + servers.extend( services .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -305,6 +314,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( body.third_party_signed.as_ref(), appservice_info, ) + .boxed() .await?; Ok(join_room_by_id_or_alias::v3::Response { @@ -337,7 +347,7 @@ pub(crate) async fn invite_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.users.is_admin(sender_user)? 
&& services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!( "User {sender_user} is not an admin and attempted to send an invite to room {}", &body.room_id @@ -375,15 +385,13 @@ pub(crate) async fn kick_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|_| err!(Database("Invalid member event in database.")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -421,10 +429,13 @@ pub(crate) async fn ban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let blurhash = services.users.blurhash(&body.user_id).await.ok(); + let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await .map_or( Ok(RoomMemberEventContent { membership: MembershipState::Ban, @@ -432,7 +443,7 @@ pub(crate) async fn ban_user_route( avatar_url: None, is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, }), @@ -442,12 +453,12 @@ pub(crate) async fn ban_user_route( membership: MembershipState::Ban, displayname: None, avatar_url: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, ..event }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) + .map_err(|e| err!(Database("Invalid member event in database: {e:?}"))) }, )?; @@ -488,12 +499,13 @@ pub(crate) async fn unban_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -539,18 +551,16 @@ pub(crate) async fn forget_room_route( if services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? 
+ .is_joined(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "You must leave the room before forgetting it", - )); + return Err!(Request(Unknown("You must leave the room before forgetting it"))); } services .rooms .state_cache - .forget(&body.room_id, sender_user)?; + .forget(&body.room_id, sender_user); Ok(forget_room::v3::Response::new()) } @@ -568,8 +578,9 @@ pub(crate) async fn joined_rooms_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -587,12 +598,10 @@ pub(crate) async fn get_member_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } Ok(get_member_events::v3::Response { @@ -622,30 +631,27 @@ pub(crate) async fn joined_members_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } let joined: BTreeMap = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(|user| { - let user = user.ok()?; - - Some(( - user.clone(), + .then(|user| async move { + ( + user.to_owned(), RoomMember { - display_name: services.users.displayname(&user).unwrap_or_default(), - avatar_url: services.users.avatar_url(&user).unwrap_or_default(), + display_name: services.users.displayname(user).await.ok(), + avatar_url: services.users.avatar_url(user).await.ok(), }, - )) + ) }) - .collect(); + .collect() + .await; Ok(joined_members::v3::Response { joined, @@ -658,13 +664,23 @@ pub async fn join_room_by_id_helper( ) -> Result { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none(); + let user_is_guest = services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && appservice_info.is_none(); - if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest { + if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await { return Err!(Request(Forbidden("Guests are not allowed to join this room"))); } - if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), @@ -674,15 +690,17 @@ pub async fn join_room_by_id_helper( if services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? 
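// Aside: a sketch of the joined_members rewrite with stand-in types (the
// displayname lookup below is fake). The sync filter_map that unwrapped
// Results becomes StreamExt::then; each member row is built with awaited
// profile lookups whose absence maps to None via `.ok()`, and the pairs
// collect directly into a BTreeMap.
use std::collections::BTreeMap;

use futures::{stream, StreamExt};

// Stand-in for `users.displayname(..)`: Err means "no profile set".
async fn displayname(user: &str) -> Result<String, ()> {
    if user == "@alice:example.org" {
        Ok("Alice".to_owned())
    } else {
        Err(())
    }
}

async fn joined_members(members: Vec<String>) -> BTreeMap<String, Option<String>> {
    stream::iter(members)
        .then(|user| async move {
            let display = displayname(&user).await.ok();
            (user, display)
        })
        .collect()
        .await
}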
- || servers.is_empty() + .server_in_room(services.globals.server_name(), room_id) + .await || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } else { // Ask a remote server if we are not participating in this room join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } } @@ -739,11 +757,11 @@ async fn join_room_by_id_helper_remote( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server: join_authorized_via_users_server.clone(), }) @@ -791,10 +809,11 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -864,7 +883,11 @@ async fn join_room_by_id_helper_remote( } } - services.rooms.short.get_or_create_shortroomid(room_id)?; + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -895,12 +918,13 @@ async fn join_room_by_id_helper_remote( err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) })?; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); if let Some(state_key) = &pdu.state_key { let shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -916,50 +940,53 @@ async fn join_room_by_id_helper_remote( continue; }; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); } debug!("Running send_join auth check"); + let fetch_state = &state; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = fetch_state.get(&shortstatekey)?; + services.rooms.timeline.get_pdu(event_id).await.ok() + }; let auth_check = state_res::event_auth::auth_check( &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &parsed_join_pdu, - None::, // TODO: third party invite - |k, s| { - services - .rooms - .timeline - .get_pdu( - state.get( - &services - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? 
- }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); + return Err!(Request(Forbidden("Auth check failed"))); } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) - .collect::>()?, - ), - )?; + let (statehash_before_join, new, removed) = services + .rooms + .state_compressor + .save_state( + room_id, + Arc::new( + state + .into_iter() + .stream() + .then(|(k, id)| async move { + services + .rooms + .state_compressor + .compress_state_event(k, &id) + .await + }) + .collect() + .await, + ), + ) + .await?; services .rooms @@ -968,12 +995,20 @@ async fn join_room_by_id_helper_remote( .await?; info!("Updating joined counts for new room"); - services.rooms.state_cache.update_joined_count(room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services @@ -993,7 +1028,7 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; + .set_room_state(room_id, statehash_after_join, &state_lock); Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } @@ -1005,23 +1040,15 @@ async fn join_room_by_id_helper_local( ) -> Result { debug!("We can join locally"); - let join_rules_event = services + let join_rules_event_content = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|content: RoomJoinRulesEventContent| content); let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { + Ok(RoomJoinRulesEventContent { join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), }) => restricted .allow @@ -1034,29 +1061,34 @@ async fn join_room_by_id_helper_local( _ => Vec::new(), }; - let local_members = services + let local_members: Vec<_> = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|user| services.globals.user_is_local(user)) - .collect::>(); + .ready_filter(|user| services.globals.user_is_local(user)) + .map(ToOwned::to_owned) + .collect() + .await; let mut join_authorized_via_users_server: Option = None; - if restriction_rooms.iter().any(|restriction_room_id| { - services - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) 
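// Aside: a heavily simplified sketch of the send_join auth-check rewiring
// (the real call goes through ruma's state_res with proper event types;
// Pdu, auth_check and state_fetch here are stand-ins). The fetch callback
// can no longer read the database synchronously, so the handler builds an
// async closure over the freshly assembled state map and the checker
// awaits each lookup.
use std::collections::HashMap;
use std::future::Future;

use futures::future::ready;

#[derive(Clone)]
struct Pdu;

async fn auth_check<F, Fut>(state_fetch: F) -> bool
where
    F: Fn(&'static str, String) -> Fut,
    Fut: Future<Output = Option<Pdu>>,
{
    // Stand-in for the real auth rules: probe a single state entry.
    state_fetch("m.room.create", String::new()).await.is_some()
}

async fn run_check(state: &HashMap<(&'static str, String), Pdu>) -> bool {
    let state_fetch = |k: &'static str, s: String| {
        // Stand-in for shortstatekey -> event_id -> get_pdu().await.ok().
        ready(state.get(&(k, s)).cloned())
    };

    auth_check(state_fetch).await
}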
{ + if restriction_rooms + .iter() + .stream() + .any(|restriction_room_id| { + services + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + }) + .await + { for user in local_members { if services .rooms .state_accessor .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) + .await { join_authorized_via_users_server = Some(user); break; @@ -1066,11 +1098,11 @@ async fn join_room_by_id_helper_local( let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: reason.clone(), join_authorized_via_users_server, }; @@ -1144,11 +1176,11 @@ async fn join_room_by_id_helper_local( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server, }) @@ -1195,10 +1227,11 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -1369,7 +1402,7 @@ pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, ) -> Result<()> { - if !services.users.is_admin(user_id)? 
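// Aside: the restricted-join `any` conversion, sketched with plain
// `futures` (conduit's IterStream extension provides the equivalent
// `.iter().stream()` adapter; is_joined below is a stand-in). Because the
// membership probe is now async, the short-circuiting `any` runs over a
// stream instead of an iterator.
use futures::{stream, StreamExt};

// Stand-in for `state_cache.is_joined(..)`.
async fn is_joined(user: &str, room: &str) -> bool {
    user == "@alice:example.org" && room == "!allowed:example.org"
}

async fn may_join_via_restriction(user: &str, restriction_rooms: &[String]) -> bool {
    stream::iter(restriction_rooms)
        .any(|room| is_joined(user, room))
        .await
}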
&& services.globals.block_non_admin_invites() { + if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -1381,7 +1414,7 @@ pub(crate) async fn invite_helper( let (pdu, pdu_json, invite_room_state) = { let state_lock = services.rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services.users.avatar_url(user_id)?, + avatar_url: services.users.avatar_url(user_id).await.ok(), displayname: None, is_direct: Some(is_direct), membership: MembershipState::Invite, @@ -1392,28 +1425,32 @@ pub(crate) async fn invite_helper( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; - let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let response = services .sending @@ -1425,9 +1462,15 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: services .sending - .convert_to_outgoing_federation_event(pdu_json.clone()), + .convert_to_outgoing_federation_event(pdu_json.clone()) + .await, invite_room_state, - via: services.rooms.state_cache.servers_route_via(room_id).ok(), + via: services + .rooms + .state_cache + .servers_route_via(room_id) + .await + .ok(), }, ) .await?; @@ -1478,11 +1521,16 @@ pub(crate) async fn invite_helper( "Could not accept incoming PDU as timeline event.", ))?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); } - if !services.rooms.state_cache.is_joined(sender_user, room_id)? 
{ + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -1499,11 +1547,11 @@ pub(crate) async fn invite_helper( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: services.users.displayname(user_id)?, - avatar_url: services.users.avatar_url(user_id)?, + displayname: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), is_direct: Some(is_direct), third_party_invite: None, - blurhash: services.users.blurhash(user_id)?, + blurhash: services.users.blurhash(user_id).await.ok(), reason, join_authorized_via_users_server: None, }) @@ -1531,36 +1579,37 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { .rooms .state_cache .rooms_joined(user_id) + .map(ToOwned::to_owned) .chain( services .rooms .state_cache .rooms_invited(user_id) - .map(|t| t.map(|(r, _)| r)), + .map(|(r, _)| r), ) - .collect::>(); + .collect::>() + .await; for room_id in all_rooms { - let Ok(room_id) = room_id else { - continue; - }; - // ignore errors if let Err(e) = leave_room(services, user_id, &room_id, None).await { warn!(%room_id, %user_id, %e, "Failed to leave room"); } - if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) { - warn!(%room_id, %user_id, %e, "Failed to forget room"); - } + + services.rooms.state_cache.forget(&room_id, user_id); } } pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { + //use conduit::utils::stream::OptionStream; + use futures::TryFutureExt; + // Ask a remote server if we don't have this room if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? + .server_in_room(services.globals.server_name(), room_id) + .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { warn!("Failed to leave room {} remotely: {}", user_id, e); @@ -1570,34 +1619,42 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, let last_state = services .rooms .state_cache - .invite_state(user_id, room_id)? 
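// Aside: a sketch of the leave_all_rooms collection rewrite with room ids
// reduced to plain strings. The joined and invited streams no longer carry
// per-item Results; the invited stream is mapped down to its room id and
// chained onto the joined stream before a single awaited collect.
use futures::{stream, StreamExt};

async fn all_rooms(
    joined: Vec<String>,
    invited: Vec<(String, Vec<u8>)>, // (room id, invite state) stand-in
) -> Vec<String> {
    stream::iter(joined)
        .chain(stream::iter(invited).map(|(room, _state)| room))
        .collect()
        .await
}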
- .map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; + .invite_state(user_id, room_id) + .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .await + .ok(); // We always drop the invite, we can't rely on other servers - services.rooms.state_cache.update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - last_state, - None, - true, - )?; + services + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + last_state, + None, + true, + ) + .await?; } else { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let member_event = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; + let member_event = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await; // Fix for broken rooms - let member_event = match member_event { - None => { - error!("Trying to leave a room you are not a member of."); + let Ok(member_event) = member_event else { + error!("Trying to leave a room you are not a member of."); - services.rooms.state_cache.update_membership( + services + .rooms + .state_cache + .update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1605,16 +1662,14 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, None, None, true, - )?; - return Ok(()); - }, - Some(e) => e, + ) + .await?; + + return Ok(()); }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { - error!("Invalid room member event in database: {}", e); - Error::bad_database("Invalid member event in database.") - })?; + let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) + .map_err(|e| err!(Database(error!("Invalid room member event in database: {e}"))))?; event.membership = MembershipState::Leave; event.reason = reason; @@ -1647,15 +1702,17 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room let invite_state = services .rooms .state_cache - .invite_state(user_id, room_id)? 
-		.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
+		.invite_state(user_id, room_id)
+		.await
+		.map_err(|_| err!(Request(BadState("User is not invited."))))?;
 
 	let mut servers: HashSet<OwnedServerName> = services
 		.rooms
 		.state_cache
 		.servers_invite_via(room_id)
-		.filter_map(Result::ok)
-		.collect();
+		.map(ToOwned::to_owned)
+		.collect()
+		.await;
 
 	servers.extend(
 		invite_state
@@ -1760,7 +1817,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
 				event_id,
 				pdu: services
 					.sending
-					.convert_to_outgoing_federation_event(leave_event.clone()),
+					.convert_to_outgoing_federation_event(leave_event.clone())
+					.await,
 			},
 		)
 		.await?;
diff --git a/src/api/client/message.rs b/src/api/client/message.rs
index 51aee8c1..bab5fa54 100644
--- a/src/api/client/message.rs
+++ b/src/api/client/message.rs
@@ -1,7 +1,8 @@
 use std::collections::{BTreeMap, HashSet};
 
 use axum::extract::State;
-use conduit::PduCount;
+use conduit::{err, utils::ReadyExt, Err, PduCount};
+use futures::{FutureExt, StreamExt};
 use ruma::{
 	api::client::{
 		error::ErrorKind,
@@ -9,13 +10,14 @@ use ruma::{
 		message::{get_message_events, send_message_event},
 	},
 	events::{MessageLikeEventType, StateEventType},
-	RoomId, UserId,
+	UserId,
 };
 use serde_json::{from_str, Value};
+use service::rooms::timeline::PdusIterItem;
 
 use crate::{
 	service::{pdu::PduBuilder, Services},
-	utils, Error, PduEvent, Result, Ruma,
+	utils, Error, Result, Ruma,
 };
 
 /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
@@ -30,79 +32,78 @@ use crate::{
 pub(crate) async fn send_message_event_route(
 	State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>,
 ) -> Result<send_message_event::v3::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+	let sender_user = body.sender_user.as_deref().expect("user is authenticated");
 	let sender_device = body.sender_device.as_deref();
-
-	let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
+	let appservice_info = body.appservice_info.as_ref();
 
 	// Forbid m.room.encrypted if encryption is disabled
 	if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() {
-		return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled"));
+		return Err!(Request(Forbidden("Encryption has been disabled")));
 	}
 
-	if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? {
-		return Err(Error::BadRequest(
-			ErrorKind::forbidden(),
-			"Room call invites are not allowed in public rooms",
-		));
+	let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
+
+	if body.event_type == MessageLikeEventType::CallInvite
+		&& services.rooms.directory.is_public_room(&body.room_id).await
+	{
+		return Err!(Request(Forbidden("Room call invites are not allowed in public rooms")));
 	}
 
 	// Check if this is a new transaction id
-	if let Some(response) = services
+	if let Ok(response) = services
 		.transaction_ids
-		.existing_txnid(sender_user, sender_device, &body.txn_id)?
+		.existing_txnid(sender_user, sender_device, &body.txn_id)
+		.await
 	{
 		// The client might have sent a txnid of the /sendToDevice endpoint
 		// This txnid has no response associated with it
 		if response.is_empty() {
-			return Err(Error::BadRequest(
-				ErrorKind::InvalidParam,
-				"Tried to use txn id already used for an incompatible endpoint.",
-			));
+			return Err!(Request(InvalidParam(
+				"Tried to use txn id already used for an incompatible endpoint."
+			)));
 		}
 
-		let event_id = utils::string_from_bytes(&response)
-			.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
-			.try_into()
-			.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
 		return Ok(send_message_event::v3::Response {
-			event_id,
+			event_id: utils::string_from_bytes(&response)
+				.map(TryInto::try_into)
+				.map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??,
 		});
 	}
 
 	let mut unsigned = BTreeMap::new();
 	unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
 
+	let content = from_str(body.body.body.json().get())
+		.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?;
+
 	let event_id = services
 		.rooms
 		.timeline
 		.build_and_append_pdu(
 			PduBuilder {
 				event_type: body.event_type.to_string().into(),
-				content: from_str(body.body.body.json().get())
-					.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
+				content,
 				unsigned: Some(unsigned),
 				state_key: None,
 				redacts: None,
-				timestamp: if body.appservice_info.is_some() {
-					body.timestamp
-				} else {
-					None
-				},
+				timestamp: appservice_info.and(body.timestamp),
 			},
 			sender_user,
 			&body.room_id,
 			&state_lock,
 		)
-		.await?;
+		.await
+		.map(|event_id| (*event_id).to_owned())?;
 
 	services
 		.transaction_ids
-		.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
+		.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes());
 
 	drop(state_lock);
 
-	Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
+	Ok(send_message_event::v3::Response {
+		event_id,
+	})
 }
 
 /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
@@ -117,8 +118,12 @@
 pub(crate) async fn get_message_events_route(
 	State(services): State<crate::State>, body: Ruma<get_message_events::v3::Request>,
 ) -> Result<get_message_events::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 	let sender_device = body.sender_device.as_ref().expect("user is authenticated");
 
-	let from = match body.from.clone() {
-		Some(from) => PduCount::try_from_string(&from)?,
+	let room_id = &body.room_id;
+	let filter = &body.filter;
+
+	let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
+	let from = match body.from.as_ref() {
+		Some(from) => PduCount::try_from_string(from)?,
 		None => match body.dir {
 			ruma::api::Direction::Forward => PduCount::min(),
 			ruma::api::Direction::Backward => PduCount::max(),
@@ -133,30 +138,25 @@
 	services
 		.rooms
 		.lazy_loading
-		.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
-		.await?;
-
-	let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
-
-	let next_token;
+		.lazy_load_confirm_delivery(sender_user, sender_device, room_id, from);
 
 	let mut resp = get_message_events::v3::Response::new();
-
 	let mut lazy_loaded = HashSet::new();
-
+	let next_token;
 	match body.dir {
 		ruma::api::Direction::Forward => {
-			let events_after: Vec<_> = services
+			let events_after: Vec<PdusIterItem> = services
 				.rooms
 				.timeline
-				.pdus_after(sender_user, &body.room_id, from)?
-				.filter_map(Result::ok) // Filter out buggy events
-				.filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)
-
-				})
-				.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
-				.take(limit)
-				.collect();
+				.pdus_after(sender_user, room_id, from)
+				.await?
+				.ready_filter_map(|item| contains_url_filter(item, filter))
+				.filter_map(|item| visibility_filter(&services, item, sender_user))
+				.ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
+				.take(limit)
+				.collect()
+				.boxed()
+				.await;
 
 			for (_, event) in &events_after {
 				/* TODO: Remove the not "element_hacks" check when these are resolved:
@@ -164,16 +164,18 @@
 				 * https://github.com/vector-im/element-web/issues/21034
 				 */
 				if !cfg!(feature = "element_hacks")
-					&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
-						sender_user,
-						sender_device,
-						&body.room_id,
-						&event.sender,
-					)? {
+					&& !services
+						.rooms
+						.lazy_loading
+						.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
+						.await
+				{
 					lazy_loaded.insert(event.sender.clone());
 				}
 
-				lazy_loaded.insert(event.sender.clone());
+				if cfg!(feature = "element_hacks") {
+					lazy_loaded.insert(event.sender.clone());
+				}
 			}
 
 			next_token = events_after.last().map(|(count, _)| count).copied();
@@ -191,17 +193,22 @@
 			services
 				.rooms
 				.timeline
-				.backfill_if_required(&body.room_id, from)
+				.backfill_if_required(room_id, from)
+				.boxed()
 				.await?;
 
-			let events_before: Vec<_> = services
+
+			let events_before: Vec<PdusIterItem> = services
 				.rooms
 				.timeline
-				.pdus_until(sender_user, &body.room_id, from)?
-				.filter_map(Result::ok) // Filter out buggy events
-				.filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)})
-				.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
-				.take(limit)
-				.collect();
+				.pdus_until(sender_user, room_id, from)
+				.await?
+				.ready_filter_map(|item| contains_url_filter(item, filter))
+				.filter_map(|item| visibility_filter(&services, item, sender_user))
+				.ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
+				.take(limit)
+				.collect()
+				.boxed()
+				.await;
 
 			for (_, event) in &events_before {
 				/* TODO: Remove the not "element_hacks" check when these are resolved:
@@ -209,16 +216,18 @@
 				 * https://github.com/vector-im/element-web/issues/21034
 				 */
 				if !cfg!(feature = "element_hacks")
-					&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
-						sender_user,
-						sender_device,
-						&body.room_id,
-						&event.sender,
-					)? {
+					&& !services
+						.rooms
+						.lazy_loading
+						.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
+						.await
+				{
					lazy_loaded.insert(event.sender.clone());
 				}
 
-				lazy_loaded.insert(event.sender.clone());
+				if cfg!(feature = "element_hacks") {
+					lazy_loaded.insert(event.sender.clone());
+				}
 			}
 
 			next_token = events_before.last().map(|(count, _)| count).copied();
@@ -236,11 +245,11 @@
 	resp.state = Vec::new();
 
 	for ll_id in &lazy_loaded {
-		if let Some(member_event) =
-			services
-				.rooms
-				.state_accessor
-				.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
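Both pagination directions above now share one stream pipeline: a synchronous URL filter, an asynchronous per-event visibility check, and an early stop at the `to` token. `ready_filter_map` and `ready_take_while` come from the `ReadyExt` helpers imported at the top of this file; the sketch below approximates the same pipeline with stock `futures` combinators, using `(u64, String)` pairs as stand-ins for `PdusIterItem`:

```rust
use futures::{future, stream, StreamExt};

// Stand-in for the async visibility check in visibility_filter().
async fn user_can_see(_item: &(u64, String)) -> bool {
    true
}

#[tokio::main]
async fn main() {
    let to = Some(4u64); // stop marker, as the `to` token is used above
    let limit = 10;

    let events: Vec<(u64, String)> = stream::iter((1u64..=5).map(|n| (n, format!("pdu {n}"))))
        // ready_filter_map: a filter_map whose future is immediately ready
        .filter_map(|item| future::ready(item.1.contains("pdu").then_some(item)))
        // async per-item check, like visibility_filter(&services, item, user)
        .filter_map(|item| async move { user_can_see(&item).await.then_some(item) })
        // ready_take_while(|(count, _)| Some(*count) != to): stop at `to`
        .take_while(|(count, _)| future::ready(Some(*count) != to))
        .take(limit)
        .collect()
        .await;

    assert_eq!(events.len(), 3); // counts 1..=3 pass; the stream stops at 4
}
```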
+ if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str()) + .await { resp.state.push(member_event.to_state_event()); } @@ -249,34 +258,43 @@ pub(crate) async fn get_message_events_route( // remove the feature check when we are sure clients like element can handle it if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) - .await; + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_token, + ); } } Ok(resp) } -fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { +async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option { + let (_, pdu) = &item; + services .rooms .state_accessor - .user_can_see_event(user_id, room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(user_id, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) } -fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool { +fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { + let (_, pdu) = &item; + if filter.url_filter.is_none() { - return true; + return Some(item); } let content: Value = from_str(pdu.content.get()).unwrap(); - match filter.url_filter { + let res = match filter.url_filter { Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), None => true, - } + }; + + res.then_some(item) } diff --git a/src/api/client/presence.rs b/src/api/client/presence.rs index 8384d5ac..ba48808b 100644 --- a/src/api/client/presence.rs +++ b/src/api/client/presence.rs @@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route( services .presence - .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; + .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone()) + .await?; Ok(set_presence::v3::Response {}) } @@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route( let mut presence_event = None; - for _room_id in services + let has_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - { - if let Some(presence) = services.presence.get_presence(&body.user_id)? 
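The presence lookup being rewritten here (continuing just below) swaps enumerating shared rooms for a single boolean `has_shared_rooms` check that can short-circuit at the first match. A self-contained sketch of that shape, with stand-in user and room types:

```rust
use futures::{stream, StreamExt};

// Stand-in for rooms_joined(user): the rooms a user has joined.
fn rooms_joined(user: &str) -> Vec<&'static str> {
    match user {
        "@alice:example.org" => vec!["!a:x", "!b:x"],
        "@bob:example.org" => vec!["!b:x", "!c:x"],
        _ => vec![],
    }
}

// Answer the boolean question directly, stopping at the first shared
// room instead of materialising the whole shared-rooms list.
async fn has_shared_rooms(a: &str, b: &str) -> bool {
    let theirs = rooms_joined(b);
    stream::iter(rooms_joined(a))
        .any(|room| {
            let hit = theirs.contains(&room);
            async move { hit }
        })
        .await
}

#[tokio::main]
async fn main() {
    assert!(has_shared_rooms("@alice:example.org", "@bob:example.org").await);
    assert!(!has_shared_rooms("@alice:example.org", "@nobody:example.org").await);
}
```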
{ + .has_shared_rooms(sender_user, &body.user_id) + .await; + + if has_shared_rooms { + if let Ok(presence) = services.presence.get_presence(&body.user_id).await { presence_event = Some(presence); - break; } } diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index bf47a3f8..495bc8ec 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -1,5 +1,10 @@ use axum::extract::State; -use conduit::{pdu::PduBuilder, warn, Err, Error, Result}; +use conduit::{ + pdu::PduBuilder, + utils::{stream::TryIgnore, IterStream}, + warn, Err, Error, Result, +}; +use futures::{StreamExt, TryStreamExt}; use ruma::{ api::{ client::{ @@ -35,16 +40,18 @@ pub(crate) async fn set_displayname_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; - update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?; if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_display_name::v3::Response {}) @@ -72,22 +79,19 @@ pub(crate) async fn get_displayname_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_display_name::v3::Response { displayname: response.displayname, @@ -95,14 +99,14 @@ pub(crate) async fn get_displayname_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_display_name::v3::Response { - displayname: services.users.displayname(&body.user_id)?, + displayname: services.users.displayname(&body.user_id).await.ok(), }) } @@ -124,15 +128,16 @@ pub(crate) async fn set_avatar_url_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; update_avatar_url( &services, &body.user_id, body.avatar_url.clone(), body.blurhash.clone(), - all_joined_rooms, + &all_joined_rooms, ) .await?; @@ -140,7 +145,9 @@ pub(crate) async fn set_avatar_url_route( // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await + .ok(); } Ok(set_avatar_url::v3::Response {}) @@ -168,22 +175,21 @@ pub(crate) async fn get_avatar_url_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_avatar_url::v3::Response { avatar_url: response.avatar_url, @@ -192,15 +198,15 @@ pub(crate) async fn get_avatar_url_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_avatar_url::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), }) } @@ -226,31 +232,30 @@ pub(crate) async fn get_profile_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); for (profile_key, profile_key_value) in &response.custom_profile_fields { services .users - .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone())); } return Ok(get_profile::v3::Response { @@ -263,104 +268,93 @@ pub(crate) async fn get_profile_route( } } - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_profile::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, - displayname: services.users.displayname(&body.user_id)?, - tz: services.users.timezone(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), + displayname: services.users.displayname(&body.user_id).await.ok(), + tz: services.users.timezone(&body.user_id).await.ok(), custom_profile_fields: services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(), + .collect() + .await, }) } pub async fn update_displayname( - services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: Vec, + services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_display_name = services.users.displayname(user_id).unwrap_or_default(); + let current_display_name = services.users.displayname(user_id).await.ok(); if displayname == current_display_name { return Ok(()); } - services - .users - .set_displayname(user_id, displayname.clone()) - .await?; + services.users.set_displayname(user_id, displayname.clone()); // Send a new join membership event into all joined rooms - let all_joined_rooms: Vec<_> = all_joined_rooms - .iter() - .map(|room_id| { - Ok::<_, Error>(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - displayname: displayname.clone(), - join_authorized_via_users_server: None, - ..serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or_else(|| { - Error::bad_database("Tried to send display name update for user not in the room.") - })? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? 
- }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - room_id, - )) - }) - .filter_map(Result::ok) - .collect(); + let mut joined_rooms = Vec::new(); + for room_id in all_joined_rooms { + let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + else { + continue; + }; - update_all_rooms(services, all_joined_rooms, user_id).await; + let pdu = PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + displayname: displayname.clone(), + join_authorized_via_users_server: None, + ..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.") + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }; + + joined_rooms.push((pdu, room_id)); + } + + update_all_rooms(services, joined_rooms, user_id).await; Ok(()) } pub async fn update_avatar_url( services: &Services, user_id: &UserId, avatar_url: Option, blurhash: Option, - all_joined_rooms: Vec, + all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default(); - let current_blurhash = services.users.blurhash(user_id).unwrap_or_default(); + let current_avatar_url = services.users.avatar_url(user_id).await.ok(); + let current_blurhash = services.users.blurhash(user_id).await.ok(); if current_avatar_url == avatar_url && current_blurhash == blurhash { return Ok(()); } - services - .users - .set_avatar_url(user_id, avatar_url.clone()) - .await?; - services - .users - .set_blurhash(user_id, blurhash.clone()) - .await?; + services.users.set_avatar_url(user_id, avatar_url.clone()); + + services.users.set_blurhash(user_id, blurhash.clone()); // Send a new join membership event into all joined rooms + let avatar_url = &avatar_url; + let blurhash = &blurhash; let all_joined_rooms: Vec<_> = all_joined_rooms .iter() - .map(|room_id| { - Ok::<_, Error>(( + .try_stream() + .and_then(|room_id: &OwnedRoomId| async move { + Ok(( PduBuilder { event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -371,8 +365,9 @@ pub async fn update_avatar_url( services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or_else(|| { + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_err(|_| { Error::bad_database("Tried to send avatar URL update for user not in the room.") })? 
.content @@ -389,8 +384,9 @@ pub async fn update_avatar_url( room_id, )) }) - .filter_map(Result::ok) - .collect(); + .ignore_err() + .collect() + .await; update_all_rooms(services, all_joined_rooms, user_id).await; diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 8723e676..39095199 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -29,41 +29,37 @@ pub(crate) async fn get_pushrules_all_route( let global_ruleset: Ruleset; - let Ok(event) = - services - .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) - else { - // push rules event doesn't exist, create it and return default - return recreate_push_rules_and_return(&services, sender_user); - }; + let event = services + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await; - if let Some(event) = event { - let value = serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - - let Some(content_value) = value.get("content") else { - // user somehow has a push rule event with no content key, recreate it and - // return server default silently - return recreate_push_rules_and_return(&services, sender_user); - }; - - if content_value.to_string().is_empty() { - // user somehow has a push rule event with empty content, recreate it and return - // server default silently - return recreate_push_rules_and_return(&services, sender_user); - } - - let account_data_content = serde_json::from_value::(content_value.clone().into()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - - global_ruleset = account_data_content.global; - } else { + let Ok(event) = event else { // user somehow has non-existent push rule event. recreate it and return server // default silently - return recreate_push_rules_and_return(&services, sender_user); + return recreate_push_rules_and_return(&services, sender_user).await; + }; + + let value = serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + + let Some(content_value) = value.get("content") else { + // user somehow has a push rule event with no content key, recreate it and + // return server default silently + return recreate_push_rules_and_return(&services, sender_user).await; + }; + + if content_value.to_string().is_empty() { + // user somehow has a push rule event with empty content, recreate it and return + // server default silently + return recreate_push_rules_and_return(&services, sender_user).await; } + let account_data_content = serde_json::from_value::(content_value.clone().into()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + + global_ruleset = account_data_content.global; + Ok(get_pushrules_all::v3::Response { global: global_ruleset, }) @@ -79,8 +75,9 @@ pub(crate) async fn get_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
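Every push-rule route in this file now follows the same async read-modify-write cycle: `get(...).await` the `m.push_rules` account-data event, edit the parsed content, then `update(...).await` the whole event back. The cycle reduced to serde_json plumbing over a stand-in store (the real routes operate on a `Ruleset`, not raw JSON):

```rust
use serde_json::{json, Value};

// Stand-in for services.account_data: one m.push_rules event per user.
struct AccountData(Value);

impl AccountData {
    async fn get(&self) -> Result<Value, &'static str> {
        Ok(self.0.clone())
    }

    async fn update(&mut self, new: Value) -> Result<(), &'static str> {
        self.0 = new;
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), &'static str> {
    let mut store = AccountData(json!({
        "content": { "global": { "override": [] } }
    }));

    // 1. read the whole event
    let mut event: Value = store.get().await?;

    // 2. modify the parsed content in place
    event["content"]["global"]["override"]
        .as_array_mut()
        .ok_or("push rules event has no content")?
        .push(json!({ "rule_id": ".m.rule.example", "enabled": true }));

    // 3. write the whole event back
    store.update(event).await?;
    Ok(())
}
```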
@@ -118,8 +115,9 @@ pub(crate) async fn set_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -155,12 +153,15 @@ pub(crate) async fn set_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule::v3::Response {}) } @@ -182,8 +183,9 @@ pub(crate) async fn get_pushrule_actions_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -217,8 +219,9 @@ pub(crate) async fn set_pushrule_actions_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -232,12 +235,15 @@ pub(crate) async fn set_pushrule_actions_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_actions::v3::Response {}) } @@ -259,8 +265,9 @@ pub(crate) async fn get_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -293,8 +300,9 @@ pub(crate) async fn set_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -308,12 +316,15 @@ pub(crate) async fn set_pushrule_enabled_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_enabled::v3::Response {}) } @@ -335,8 +346,9 @@ pub(crate) async fn delete_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -357,12 +369,15 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(delete_pushrule::v3::Response {}) } @@ -376,7 +391,7 @@ pub(crate) async fn get_pushers_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: services.pusher.get_pushers(sender_user)?, + pushers: services.pusher.get_pushers(sender_user).await, }) } @@ -390,27 +405,30 @@ pub(crate) async fn set_pushers_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services.pusher.set_pusher(sender_user, &body.action)?; + services.pusher.set_pusher(sender_user, &body.action); Ok(set_pusher::v3::Response::default()) } /// user somehow has bad push rules, these must always exist per spec. 
/// so recreate it and return server default silently -fn recreate_push_rules_and_return( +async fn recreate_push_rules_and_return( services: &Services, sender_user: &ruma::UserId, ) -> Result { - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(PushRulesEvent { - content: PushRulesEventContent { - global: Ruleset::server_default(sender_user), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) + .await?; Ok(get_pushrules_all::v3::Response { global: Ruleset::server_default(sender_user), diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index f40f2493..f28b2aec 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } if let Some(event) = &body.private_read_receipt { let count = services .rooms .timeline - .get_pdu_count(event)? 
- .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(event) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); } if let Some(event) = &body.read_receipt { @@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; } Ok(set_read_marker::v3::Response {}) @@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route( services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } match body.receipt_type { @@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; }, create_receipt::v3::ReceiptType::Read => { let mut user_receipts = BTreeMap::new(); @@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.clone(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; }, create_receipt::v3::ReceiptType::ReadPrivate => { let count = services .rooms .timeline - .get_pdu_count(&body.event_id)? 
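Both receipt paths resolve the target event to a `PduCount` and reject private read receipts that point into the backfilled (remote-history) half of the timeline. The decision structure, with a two-arm stand-in enum mirroring the `PduCount` usage seen here (only the `Backfilled` arm is visible in this diff; the second arm is an assumption):

```rust
#[derive(Debug, Clone, Copy)]
enum PduCount {
    Backfilled(u64),
    Normal(u64),
}

fn private_receipt_count(count: PduCount) -> Result<u64, &'static str> {
    match count {
        // Events pulled in via backfill sort before the room's local
        // timeline; pointing a private read marker at one is rejected.
        PduCount::Backfilled(_) => Err("Read receipt is in backfilled timeline"),
        PduCount::Normal(c) => Ok(c),
    }
}

fn main() {
    assert!(private_receipt_count(PduCount::Backfilled(3)).is_err());
    assert_eq!(private_receipt_count(PduCount::Normal(7)), Ok(7));
}
```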
- .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); }, _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index ae645940..d4384730 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -9,20 +9,24 @@ use crate::{Result, Ruma}; pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &Some(body.event_type.clone()), - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + body.event_type.clone().into(), + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, @@ -36,20 +40,24 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -63,18 +71,22 @@ pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &None, - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - ) + services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + None, + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 588bd368..a40c35a2 100644 --- 
a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,6 +1,7 @@ use std::time::Duration; use axum::extract::State; +use conduit::{utils::ReadyExt, Err}; use rand::Rng; use ruma::{ api::client::{error::ErrorKind, room::report_content}, @@ -34,11 +35,8 @@ pub(crate) async fn report_event_route( delay_response().await; // check if we know about the reported event ID or if it's invalid - let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID is not known to us or Event ID is invalid", - )); + let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else { + return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid"))); }; is_report_valid( @@ -49,7 +47,8 @@ pub(crate) async fn report_event_route( &body.reason, body.score, &pdu, - )?; + ) + .await?; // send admin room message that we received the report with an @room ping for // urgency @@ -81,7 +80,8 @@ pub(crate) async fn report_event_route( HtmlEscape(body.reason.as_deref().unwrap_or("")) ), )) - .await; + .await + .ok(); Ok(report_content::v3::Response {}) } @@ -92,7 +92,7 @@ pub(crate) async fn report_event_route( /// check if score is in valid range /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room -fn is_report_valid( +async fn is_report_valid( services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option, score: Option, pdu: &std::sync::Arc, ) -> Result<()> { @@ -123,8 +123,8 @@ fn is_report_valid( .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .any(|user_id| user_id == *sender_user) + .ready_any(|user_id| user_id == sender_user) + .await { return Err(Error::BadRequest( ErrorKind::NotFound, diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 0112e76d..1edf85d8 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -2,6 +2,7 @@ use std::{cmp::max, collections::BTreeMap}; use axum::extract::State; use conduit::{debug_info, debug_warn, err, Err}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -74,7 +75,7 @@ pub(crate) async fn create_room_route( if !services.globals.allow_room_creation() && body.appservice_info.is_none() - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); } @@ -86,7 +87,7 @@ pub(crate) async fn create_room_route( }; // check if room ID doesn't already exist instead of erroring on auth check - if services.rooms.short.get_shortroomid(&room_id)?.is_some() { + if services.rooms.short.get_shortroomid(&room_id).await.is_ok() { return Err(Error::BadRequest( ErrorKind::RoomInUse, "Room with that custom room ID already exists", @@ -95,7 +96,7 @@ pub(crate) async fn create_room_route( if body.visibility == room::Visibility::Public && services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? 
+ && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -118,7 +119,11 @@ pub(crate) async fn create_room_route( return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed"))); } - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; let alias: Option = if let Some(alias) = &body.room_alias_name { @@ -218,6 +223,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 2. Let the room creator join @@ -229,11 +235,11 @@ pub(crate) async fn create_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -247,6 +253,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 3. Power levels @@ -284,6 +291,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 4. Canonical room alias @@ -308,6 +316,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -335,6 +344,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.2 History Visibility @@ -355,6 +365,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.3 Guest Access @@ -378,6 +389,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 6. Events listed in initial_state @@ -410,6 +422,7 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .boxed() .await?; } @@ -432,6 +445,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -455,13 +469,17 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct) + .boxed() + .await + { warn!(%e, "Failed to send invite"); } } @@ -475,7 +493,7 @@ pub(crate) async fn create_room_route( } if body.visibility == room::Visibility::Public { - services.rooms.directory.set_public(&room_id)?; + services.rooms.directory.set_public(&room_id); if services.globals.config.admin_room_notices { services @@ -505,13 +523,15 @@ pub(crate) async fn get_room_event_route( let event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &event.room_id, &body.event_id)? 
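The numbered event chain in `create_room` above now appends `.boxed()` to each `build_and_append_pdu(...)` call. `.boxed()` erases the compiler-generated future type into a heap-allocated `BoxFuture` of fixed size, which is commonly done to cap the size of deeply nested async call graphs and to make recursive async calls possible at all; a minimal illustration of the recursion case:

```rust
use futures::{future::BoxFuture, FutureExt};

// An async fn's future is an anonymous, possibly huge compiler-generated
// type; .boxed() erases it to a fixed-size, heap-allocated BoxFuture.
fn append_state_events(remaining: u32) -> BoxFuture<'static, u32> {
    async move {
        if remaining == 0 {
            0
        } else {
            // Without boxing, a recursive async call would make the
            // future type infinitely sized and fail to compile.
            append_state_events(remaining - 1).await
        }
    }
    .boxed()
}

#[tokio::main]
async fn main() {
    // Eight sequential "events", loosely mirroring the create_room chain.
    assert_eq!(append_state_events(8).await, 0);
}
```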
+ .user_can_see_event(sender_user, &event.room_id, &body.event_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -541,7 +561,8 @@ pub(crate) async fn get_room_aliases_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -554,8 +575,9 @@ pub(crate) async fn get_room_aliases_route( .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -591,7 +613,8 @@ pub(crate) async fn upgrade_room_route( let _short_id = services .rooms .short - .get_or_create_shortroomid(&replacement_room)?; + .get_or_create_shortroomid(&replacement_room) + .await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -629,12 +652,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .room_state_get(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? .content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Use the m.room.tombstone event as the predecessor let predecessor = Some(ruma::events::room::create::PreviousRoom::new( @@ -714,11 +737,11 @@ pub(crate) async fn upgrade_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -739,10 +762,11 @@ pub(crate) async fn upgrade_room_route( let event_content = match services .rooms .state_accessor - .room_state_get(&body.room_id, event_type, "")? + .room_state_get(&body.room_id, event_type, "") + .await { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. + Ok(v) => v.content.clone(), + Err(_) => continue, // Skipping missing events. }; services @@ -765,21 +789,23 @@ pub(crate) async fn upgrade_room_route( } // Moves any local aliases to the new room - for alias in services + let mut local_aliases = services .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - { + .boxed(); + + while let Some(alias) = local_aliases.next().await { services .rooms .alias - .remove_alias(&alias, sender_user) + .remove_alias(alias, sender_user) .await?; + services .rooms .alias - .set_alias(&alias, &replacement_room, sender_user)?; + .set_alias(alias, &replacement_room, sender_user)?; } // Get the old room power levels @@ -787,12 +813,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? 
+ .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? .content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Setting events_default and invite to the greater of 50 and users_default + 1 let new_level = max( @@ -800,9 +826,7 @@ pub(crate) async fn upgrade_room_route( power_levels_event_content .users_default .checked_add(int!(1)) - .ok_or_else(|| { - Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid") - })?, + .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, ); power_levels_event_content.events_default = new_level; power_levels_event_content.invite = new_level; @@ -921,8 +945,9 @@ async fn room_alias_check( if services .rooms .alias - .resolve_local_alias(&full_room_alias)? - .is_some() + .resolve_local_alias(&full_room_alias) + .await + .is_ok() { return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b143bd2c..7a061d49 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,6 +1,12 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{ + debug, + utils::{IterStream, ReadyExt}, + Err, +}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -13,7 +19,6 @@ use ruma::{ serde::Raw, uint, OwnedRoomId, }; -use tracing::debug; use crate::{Error, Result, Ruma}; @@ -32,14 +37,17 @@ pub(crate) async fn search_events_route( let filter = &search_criteria.filter; let include_state = &search_criteria.include_state; - let room_ids = filter.rooms.clone().unwrap_or_else(|| { + let room_ids = if let Some(room_ids) = &filter.rooms { + room_ids.clone() + } else { services .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) + .map(ToOwned::to_owned) .collect() - }); + .await + }; // Use limit or else 10, with maximum 100 let limit: usize = filter @@ -53,18 +61,21 @@ pub(crate) async fn search_events_route( if include_state.is_some_and(|include_state| include_state) { for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { + return Err!(Request(Forbidden("You don't have permission to view this room."))); } // check if sender_user can see state events if services .rooms .state_accessor - .user_can_see_state_events(sender_user, room_id)? + .user_can_see_state_events(sender_user, room_id) + .await { let room_state = services .rooms @@ -87,10 +98,15 @@ pub(crate) async fn search_events_route( } } - let mut searches = Vec::new(); + let mut search_vecs = Vec::new(); for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -100,12 +116,18 @@ pub(crate) async fn search_events_route( if let Some(search) = services .rooms .search - .search_pdus(room_id, &search_criteria.search_term)? 
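The search rewrite keeps one result list per room and, in the selection loop just below, repeatedly takes the greatest head among them via `peek()`, a k-way merge that yields globally ordered PDU IDs. The same loop, self-contained, assuming (as the loop itself does) that each per-room list is already sorted descending:

```rust
fn main() {
    // Stand-ins for per-room search results, each sorted descending.
    let search_vecs: Vec<Vec<&str>> = vec![
        vec!["pdu9", "pdu4", "pdu1"],
        vec!["pdu7", "pdu5"],
        vec!["pdu8", "pdu2"],
    ];

    let mut searches: Vec<_> = search_vecs
        .iter()
        .map(|v| v.iter().peekable())
        .collect();

    let mut results = Vec::new();
    for _ in 0..5 {
        // Pick the iterator whose next item is greatest, then advance it.
        if let Some(s) = searches
            .iter_mut()
            .map(|s| (s.peek().copied(), s))
            .max_by_key(|(peek, _)| *peek)
            .and_then(|(_, i)| i.next())
        {
            results.push(*s);
        }
    }

    assert_eq!(results, ["pdu9", "pdu8", "pdu7", "pdu5", "pdu4"]);
}
```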
+ .search_pdus(room_id, &search_criteria.search_term) + .await { - searches.push(search.0.peekable()); + search_vecs.push(search.0); } } + let mut searches: Vec<_> = search_vecs + .iter() + .map(|vec| vec.iter().peekable()) + .collect(); + let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { Some(Ok(s)) => s, Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), @@ -118,8 +140,8 @@ pub(crate) async fn search_events_route( for _ in 0..next_batch { if let Some(s) = searches .iter_mut() - .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) + .map(|s| (s.peek().copied(), s)) + .max_by_key(|(peek, _)| *peek) .and_then(|(_, i)| i.next()) { results.push(s); @@ -127,42 +149,38 @@ pub(crate) async fn search_events_route( } let results: Vec<_> = results - .iter() + .into_iter() .skip(skip) - .filter_map(|result| { + .stream() + .filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok)) + .ready_filter(|pdu| !pdu.is_redacted()) + .filter_map(|pdu| async move { services .rooms - .timeline - .get_pdu_from_id(result) - .ok()? - .filter(|pdu| { - !pdu.is_redacted() - && services - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .map(|pdu| pdu.to_room_event()) + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) }) - .map(|result| { - Ok::<_, Error>(SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), - }) - }) - .filter_map(Result::ok) .take(limit) - .collect(); + .map(|pdu| pdu.to_room_event()) + .map(|result| SearchResult { + context: EventContextResult { + end: None, + events_after: Vec::new(), + events_before: Vec::new(), + profile_info: BTreeMap::new(), + start: None, + }, + rank: None, + result: Some(result), + }) + .collect() + .boxed() + .await; let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); + let next_batch = more_unloaded_results.then(|| next_batch.to_string()); Ok(search_events::v3::Response::new(ResultCategories { diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 4702b0ec..6347a2c9 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -1,5 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; +use conduit::{debug, err, info, utils::ReadyExt, warn, Err}; +use futures::StreamExt; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +21,6 @@ use ruma::{ UserId, }; use serde::Deserialize; -use tracing::{debug, info, warn}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, utils::hash, Error, Result, Ruma}; @@ -79,21 +80,22 @@ pub(crate) async fn login_route( UserId::parse(user) } else { warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type.")); + return Err!(Request(Forbidden("Bad login type."))); } .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; let hash = services .users - .password_hash(&user_id)? 
- .ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; + .password_hash(&user_id) + .await + .map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?; if hash.is_empty() { - return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); + return Err!(Request(UserDeactivated("The user has been deactivated"))); } if hash::verify_password(password, &hash).is_err() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password.")); + return Err!(Request(Forbidden("Wrong username or password."))); } user_id @@ -112,15 +114,12 @@ pub(crate) async fn login_route( let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { - warn!("Failed to parse username from user logging in: {e}"); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })? + UserId::parse_with_server_name(username, services.globals.server_name()) + .map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))? } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); + return Err!(Request(Unknown( + "Token login is not supported (server has no jwt decoding key)." + ))); } }, #[allow(deprecated)] @@ -169,23 +168,32 @@ pub(crate) async fn login_route( let token = utils::random_string(TOKEN_LENGTH); // Determine if device_id was provided and exists in the db for this user - let device_exists = body.device_id.as_ref().map_or(false, |device_id| { + let device_exists = if body.device_id.is_some() { services .users .all_device_ids(&user_id) - .any(|x| x.as_ref().map_or(false, |v| v == device_id)) - }); + .ready_any(|v| v == device_id) + .await + } else { + false + }; if device_exists { - services.users.set_token(&user_id, &device_id, &token)?; + services + .users + .set_token(&user_id, &device_id, &token) + .await?; } else { - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; } // send client well-known if specified so the client knows to reconfigure itself @@ -228,10 +236,13 @@ pub(crate) async fn logout_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services.users.remove_device(sender_user, sender_device)?; + services + .users + .remove_device(sender_user, sender_device) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout::v3::Response::new()) } @@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services.users.all_device_ids(sender_user).flatten() { - services.users.remove_device(sender_user, &device_id)?; - } + services + .users + .all_device_ids(sender_user) + .for_each(|device_id| services.users.remove_device(sender_user, device_id)) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; 
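`logout_all_route` above now drains the device-id stream with `StreamExt::for_each` rather than collecting and looping over `Result`s. A runnable sketch under that assumption (`remove_device` is a hypothetical stand-in for `services.users.remove_device`); `for_each` awaits each returned future in turn, so removals stay sequential:

```rust
use futures::{stream, StreamExt};

// Stand-in for the async removal performed per device.
async fn remove_device(device_id: &str) {
    println!("removing {device_id}");
}

fn main() {
    let device_ids = ["ABCDEF", "GHIJKL"];
    futures::executor::block_on(
        stream::iter(device_ids).for_each(|device_id| remove_device(device_id)),
    );
}
```

Where overlap is acceptable, `for_each_concurrent` would run the removal futures in parallel instead.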
Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index fd049663..f9a4a763 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; +use conduit::{err, error, pdu::PduBuilder, Err, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } Ok(get_state_events::v3::Response { @@ -120,22 +118,25 @@ pub(crate) async fn get_state_events_for_key_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &body.event_type, &body.state_key)? - .ok_or_else(|| { - debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") + .room_state_get(&body.room_id, &body.event_type, &body.state_key) + .await + .map_err(|_| { + err!(Request(NotFound(error!( + room_id = ?body.room_id, + event_type = ?body.event_type, + "State event not found in room.", + )))) })?; + if body .format .as_ref() @@ -204,7 +205,7 @@ async fn send_state_event_for_key_helper( async fn allowed_to_send_state_event( services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, -) -> Result<()> { +) -> Result { match event_type { // Forbid m.room.encryption if encryption is disabled StateEventType::RoomEncryption => { @@ -214,7 +215,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made public StateEventType::RoomJoinRules => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) { if join_rule.join_rule == JoinRule::Public { @@ -229,7 +230,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made world readable StateEventType::RoomHistoryVisibility => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(visibility_content) = serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get()) @@ -254,23 +255,27 @@ async fn allowed_to_send_state_event( } for alias in aliases { - if !services.globals.server_is_ours(alias.server_name()) || services .rooms .alias - .resolve_local_alias(&alias)?
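The `room_state_get` hunk above illustrates the convention running through these call sites: getters that formerly returned `Result<Option<T>>` now return `Result<T>`, so "not found" arrives as an `Err` handled with `map_err`, `.ok()`, or `is_ok()` instead of `ok_or_else`. A minimal before/after sketch with hypothetical lookups:

```rust
// Old shape: absence is Ok(None), forcing an ok_or_else at every call site.
fn old_get(key: &str) -> Result<Option<u64>, &'static str> {
    Ok((key == "hit").then_some(1))
}

// New shape: absence is already an Err, so .is_ok()/.map_err compose directly.
fn new_get(key: &str) -> Result<u64, &'static str> {
    old_get(key)?.ok_or("not found")
}

fn main() {
    assert!(matches!(old_get("miss"), Ok(None)));
    assert!(new_get("miss").is_err());
    assert_eq!(new_get("hit"), Ok(1));
}
```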
- .filter(|room| room == room_id) // Make sure it's the right room - .is_none() + if !services.globals.server_is_ours(alias.server_name()) { + return Err!(Request(Forbidden("canonical_alias must be for this server"))); + } + + if !services + .rooms + .alias + .resolve_local_alias(&alias) + .await + .is_ok_and(|room| room == room_id) + // Make sure it's the right room { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You are only allowed to send canonical_alias events when its aliases already exist", - )); + return Err!(Request(Forbidden( + "You are only allowed to send canonical_alias events when its aliases already exist" + ))); } } } }, _ => (), } + Ok(()) } diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index eb534205..53d4f3c3 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -6,10 +6,14 @@ use std::{ use axum::extract::State; use conduit::{ - debug, error, - utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - warn, Err, PduCount, + debug, err, error, is_equal_to, + utils::{ + math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + IterStream, ReadyExt, + }, + warn, PduCount, }; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -108,7 +112,8 @@ pub(crate) async fn sync_events_route( if services.globals.allow_local_presence() { services .presence - .ping_presence(&sender_user, &body.set_presence)?; + .ping_presence(&sender_user, &body.set_presence) + .await?; } // Setup watchers, so if there's no response, we can wait for them @@ -124,7 +129,8 @@ pub(crate) async fn sync_events_route( Some(Filter::FilterDefinition(filter)) => filter, Some(Filter::FilterId(filter_id)) => services .users - .get_filter(&sender_user, &filter_id)? + .get_filter(&sender_user, &filter_id) + .await .unwrap_or_default(), }; @@ -157,7 +163,9 @@ pub(crate) async fn sync_events_route( services .users .keys_changed(sender_user.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); if services.globals.allow_local_presence() { @@ -168,13 +176,14 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_joined(&sender_user) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; // Coalesce database writes for the remainder of this scope. 
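The `rooms_joined` and `keys_changed` conversions above show the most common mechanical change in this file: `.filter_map(Result::ok).collect()` over a blocking iterator becomes an owned collect over a stream, awaited at the end. A standalone sketch of that shape, with synthetic ids in place of the real stream:

```rust
use futures::{stream, StreamExt};

fn main() {
    // The stream yields borrowed ids; ToOwned materializes them before the
    // borrow of the underlying storage ends, then collect() is awaited.
    let all_joined_rooms: Vec<String> = futures::executor::block_on(
        stream::iter(["!a:example.org", "!b:example.org"])
            .map(ToOwned::to_owned)
            .collect::<Vec<_>>(),
    );
    assert_eq!(all_joined_rooms, ["!a:example.org", "!b:example.org"]);
}
```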
let _cork = services.db.cork_and_flush(); for room_id in all_joined_rooms { - let room_id = room_id?; if let Ok(joined_room) = load_joined_room( &services, &sender_user, @@ -203,12 +212,14 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_left(&sender_user) - .collect(); + .collect() + .await; + for result in all_left_rooms { handle_left_room( &services, since, - &result?.0, + &result.0, &sender_user, &mut left_rooms, &next_batch_string, @@ -224,10 +235,10 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_invited(&sender_user) - .collect(); - for result in all_invited_rooms { - let (room_id, invite_state_events) = result?; + .collect() + .await; + for (room_id, invite_state_events) in all_invited_rooms { // Get and drop the lock to wait for remaining operations to finish let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); @@ -235,7 +246,9 @@ pub(crate) async fn sync_events_route( let invite_count = services .rooms .state_cache - .get_invite_count(&room_id, &sender_user)?; + .get_invite_count(&room_id, &sender_user) + .await + .ok(); // Invited before last sync if Some(since) >= invite_count { @@ -253,22 +266,8 @@ pub(crate) async fn sync_events_route( } for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -279,7 +278,8 @@ pub(crate) async fn sync_events_route( // Remove all to-device events the device received *last time* services .users - .remove_to_device_events(&sender_user, &sender_device, since)?; + .remove_to_device_events(&sender_user, &sender_device, since) + .await; let response = sync_events::v3::Response { next_batch: next_batch_string, @@ -298,7 +298,8 @@ pub(crate) async fn sync_events_route( account_data: GlobalAccountData { events: services .account_data - .changes_since(None, &sender_user, since)? + .changes_since(None, &sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(), @@ -309,11 +310,14 @@ pub(crate) async fn sync_events_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, to_device: ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, }, // Fallback keys are not yet supported device_unused_fallback_key_types: None, @@ -351,14 +355,16 @@ async fn handle_left_room( let left_count = services .rooms .state_cache - .get_left_count(room_id, sender_user)?; + .get_left_count(room_id, sender_user) + .await + .ok(); // Left before last sync if Some(since) >= left_count { return Ok(()); } - if !services.rooms.metadata.exists(room_id)? 
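`get_invite_count` above is converted to `Result` and immediately shrunk back to an `Option` with `.ok()`, which keeps the `Some(since) >= invite_count` guard working since `Option`'s derived ordering places `None` below every `Some`. A small sketch (the `get_invite_count` stub is hypothetical):

```rust
// Result-returning getter, as after the refactor.
fn get_invite_count(known: bool) -> Result<u64, &'static str> {
    known.then_some(7).ok_or("not invited")
}

fn main() {
    let since = 5_u64;
    // Err collapses to None, which compares below any Some(count).
    let invite_count = get_invite_count(true).ok();
    assert!(Some(since) < invite_count);
    assert!(Some(since) >= get_invite_count(false).ok());
}
```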
{ + if !services.rooms.metadata.exists(room_id).await { // This is just a rejected invite, not a room we know // Insert a leave event anyways let event = PduEvent { @@ -408,27 +414,29 @@ async fn handle_left_room( let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await; let since_state_ids = match since_shortstatehash { - Some(s) => services.rooms.state_accessor.state_full_ids(s).await?, - None => HashMap::new(), + Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?, + Err(_) => HashMap::new(), }; - let Some(left_event_id) = - services - .rooms - .state_accessor - .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())? + let Ok(left_event_id) = services + .rooms + .state_accessor + .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) + .await else { error!("Left room but no left state event"); return Ok(()); }; - let Some(left_shortstatehash) = services + let Ok(left_shortstatehash) = services .rooms .state_accessor - .pdu_shortstatehash(&left_event_id)? + .pdu_shortstatehash(&left_event_id) + .await else { error!(event_id = %left_event_id, "Leave event has no state"); return Ok(()); @@ -443,14 +451,15 @@ async fn handle_left_room( let leave_shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str()) + .await; left_state_ids.insert(leave_shortstatekey, left_event_id); let mut i: u8 = 0; for (key, id) in left_state_ids { if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?; + let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key).await?; if !lazy_load_enabled || event_type != StateEventType::RoomMember @@ -458,7 +467,7 @@ async fn handle_left_room( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { error!("Pdu in state not found: {}", id); continue; }; @@ -495,19 +504,25 @@ async fn handle_left_room( async fn process_presence_updates( services: &Services, presence_updates: &mut HashMap, since: u64, syncing_user: &UserId, ) -> Result<()> { + let presence_since = services.presence.presence_since(since); + // Take presence updates - for (user_id, _, presence_bytes) in services.presence.presence_since(since) { + pin_mut!(presence_since); + while let Some((user_id, _, presence_bytes)) = presence_since.next().await { if !services .rooms .state_cache - .user_sees_user(syncing_user, &user_id)? 
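`presence_since` now returns a stream, and the loop above polls it in place: `pin_mut!` pins the stream on the stack so `.next().await` can be called without boxing. Sketch with a synthetic stream standing in for `presence_since(since)`:

```rust
use futures::{pin_mut, stream, StreamExt};

fn main() {
    futures::executor::block_on(async {
        // Synthetic (user_id, count) pairs in place of presence updates.
        let presence_since = stream::iter([("@alice:example.org", 10_u64), ("@bob:example.org", 12)]);
        pin_mut!(presence_since);
        while let Some((user_id, count)) = presence_since.next().await {
            println!("{user_id} updated presence at count {count}");
        }
    });
}
```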
+ .user_sees_user(syncing_user, &user_id) + .await { continue; } let presence_event = services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); @@ -551,14 +566,14 @@ async fn load_joined_room( let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10).await?; let send_notification_counts = !timeline_pdus.is_empty() || services .rooms .user - .last_notification_read(sender_user, room_id)? - > since; + .last_notification_read(sender_user, room_id) + .await > since; let mut timeline_users = HashSet::new(); for (_, event) in &timeline_pdus { @@ -568,355 +583,384 @@ async fn load_joined_room( services .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) - .await?; + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount); // Database queries: - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - return Err!(Database(error!("Room {room_id} has no state"))); - }; + let current_shortstatehash = services + .rooms + .state + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Database(error!("Room {room_id} has no state"))))?; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await + .ok(); - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0); + let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus + .is_empty() + && (since_shortstatehash.is_none() || since_shortstatehash.is_some_and(is_equal_to!(current_shortstatehash))) + { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || async { + let joined_member_count = services + .rooms + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(0); - // Recalculate heroes (first 5 members) - let mut heroes: Vec = Vec::with_capacity(5); + let invited_member_count = services + .rooms + .state_cache + .room_invited_count(room_id) + .await + .unwrap_or(0); - if joined_member_count.saturating_add(invited_member_count) <= 5 { - // Go through all PDUs and for each member event, check if the user is still - // joined or invited until we have 5 or we reach the end + if joined_member_count.saturating_add(invited_member_count) > 5 { + return Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), Vec::new())); + } - for hero in services - .rooms - .timeline - .all_pdus(sender_user, room_id)? 
- .filter_map(Result::ok) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + // Go through all PDUs and for each member event, check if the user is still + // joined or invited until we have 5 or we reach the end - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + // Recalculate heroes (first 5 members) + let heroes = services + .rooms + .timeline + .all_pdus(sender_user, room_id) + .await? + .ready_filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) + .filter_map(|(_, pdu)| async move { + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return None; + }; - // The membership was and still is invite or join - if matches!(content.membership, MembershipState::Join | MembershipState::Invite) - && (services.rooms.state_cache.is_joined(&user_id, room_id)? - || services.rooms.state_cache.is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(user_id)) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - .filter_map(Result::ok) - // Filter for possible heroes - .flatten() + let Some(state_key) = &pdu.state_key else { + return None; + }; + + let Ok(user_id) = UserId::parse(state_key) else { + return None; + }; + + if user_id == sender_user { + return None; + } + + // The membership was and still is invite or join + if !matches!(content.membership, MembershipState::Join | MembershipState::Invite) { + return None; + } + + if !services + .rooms + .state_cache + .is_joined(&user_id, room_id) + .await && services + .rooms + .state_cache + .is_invited(&user_id, room_id) + .await { - if heroes.contains(&hero) || hero == sender_user { - continue; - } + return None; + } - heroes.push(hero); + Some(user_id) + }) + .collect::>() + .await; + + Ok::<_, Error>(( + Some(joined_member_count), + Some(invited_member_count), + heroes.into_iter().collect::>(), + )) + }; + + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; + + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync + + let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?; + + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + let mut i: u8 = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; + + if event_type != StateEventType::RoomMember { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + state_events.push(pdu); + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled + || full_state + || timeline_users.contains(&state_key) + // TODO: Delete the 
following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 + || (cfg!(feature = "element_hacks") && *sender_user == state_key) + { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + + // This check is in case a bad user ID made it into the database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); + } + state_events.push(pdu); + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; } } + } - Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) - }; + // Reset lazy loading because this is an initial sync + services + .rooms + .lazy_loading + .lazy_load_reset(sender_user, sender_device, room_id) + .await; - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + // The state_events above should contain all timeline_users, let's mark them as + // lazy loaded. + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.expect("missing since_shortstatehash on incremental sync"); - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync - - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + let mut delta_state_events = Vec::new(); + if since_shortstatehash != current_shortstatehash { let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey)?; - - if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - state_events.push(pdu); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled - || full_state - || timeline_users.contains(&state_key) - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
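The state-loading loops above keep a small wrapping counter and yield to the scheduler every 100 events, so serving one large initial sync cannot monopolize an executor thread. A sketch of the same cooperative-yield pattern, assuming the tokio runtime (with the `rt` and `macros` features) that the server already uses:

```rust
async fn sum_with_yields(items: &[u64]) -> u64 {
    let mut sum: u64 = 0;
    for (i, item) in items.iter().enumerate() {
        sum = sum.wrapping_add(*item);
        // Periodically hand the thread back to the scheduler.
        if i % 100 == 0 {
            tokio::task::yield_now().await;
        }
    }
    sum
}

#[tokio::main]
async fn main() {
    let items: Vec<u64> = (0..1_000).collect();
    println!("{}", sum_with_yields(&items).await);
}
```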
else { - error!("Pdu in state not found: {}", id); - continue; - }; - - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } - - // Reset lazy loading because this is an initial sync - services - .rooms - .lazy_loading - .lazy_load_reset(sender_user, sender_device, room_id)?; - - // The state_events above should contain all timeline_users, let's mark them as - // lazy loaded. - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - (heroes, joined_member_count, invited_member_count, true, state_events) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); - - let mut delta_state_events = Vec::new(); - - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; - - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - - delta_state_events.push(pdu); - tokio::task::yield_now().await; - } - } - } - - let encrypted_room = services + let since_state_ids = services .rooms .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + .state_full_ids(since_shortstatehash) + .await?; - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; - - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); - - let send_member_count = delta_state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); - - if encrypted_room { - for state_event in &delta_state_events { - if state_event.kind != TimelineEventType::RoomMember { + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; - } + }; - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - if user_id == sender_user { - continue; - } - - let new_membership = - serde_json::from_str::(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; - - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(services, sender_user, &user_id, room_id)? 
{ - device_list_updates.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } - } + delta_state_events.push(pdu); + tokio::task::yield_now().await; } } + } - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? - } else { - (None, None, Vec::new()) - }; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; - let mut state_events = delta_state_events; - let mut lazy_loaded = HashSet::new(); + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_err(); - // Mark all member events we're returning as lazy-loaded - for pdu in &state_events { - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - }, - Err(e) => error!("Invalid state key for member event: {}", e), - } - } - } + let send_member_count = delta_state_events + .iter() + .any(|event| event.kind == TimelineEventType::RoomMember); - // Fetch contextual member state events for events from the timeline, and - // mark them as lazy-loaded as well. - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { + if encrypted_room { + for state_event in &delta_state_events { + if state_event.kind != TimelineEventType::RoomMember { continue; } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant - { - if let Some(member_event) = services.rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + if user_id == sender_user { + continue; + } + + let new_membership = serde_json::from_str::(state_event.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? 
+ .membership; + + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(services, sender_user, &user_id, Some(room_id)).await { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, } } } + } - services + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services + .rooms + .state_cache + .room_members(room_id) + .ready_filter(|user_id| { + // Don't send key updates from the sender to the sender + sender_user != *user_id + }) + .filter_map(|user_id| async move { + // Only send keys if the sender doesn't share an encrypted room with the target + // already + (!share_encrypted_room(services, sender_user, user_id, Some(room_id)).await) + .then_some(user_id.to_owned()) + }) + .collect::>() + .await, + ); + } + + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts().await? + } else { + (None, None, Vec::new()) + }; + + let mut state_events = delta_state_events; + let mut lazy_loaded = HashSet::new(); + + // Mark all member events we're returning as lazy-loaded + for pdu in &state_events { + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), + } + } + } + + // Fetch contextual member state events for events from the timeline, and + // mark them as lazy-loaded as well. + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } + + if !services .rooms .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant + { + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, event.sender.as_str()) + .await + { + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); + } + } } - }; + + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); + + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; // Look for device list updates in this room device_list_updates.extend( services .users .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); let notification_count = if send_notification_counts { @@ -924,7 +968,8 @@ async fn load_joined_room( services .rooms .user - .notification_count(sender_user, room_id)? + .notification_count(sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ) @@ -937,7 +982,8 @@ async fn load_joined_room( services .rooms .user - .highlight_count(sender_user, room_id)? 
+ .highlight_count(sender_user, room_id) + .await .try_into() .expect("highlight count can't go that high"), ) @@ -966,9 +1012,9 @@ async fn load_joined_room( .rooms .read_receipt .readreceipts_since(room_id, since) - .filter_map(Result::ok) // Filter out buggy events .map(|(_, _, v)| v) - .collect(); + .collect() + .await; if services.rooms.typing.last_typing_update(room_id).await? > since { edus.push( @@ -985,13 +1031,15 @@ async fn load_joined_room( services .rooms .user - .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; + .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash) + .await; Ok(JoinedRoom { account_data: RoomAccountData { events: services .account_data - .changes_since(Some(room_id), sender_user, since)? + .changes_since(Some(room_id), sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1023,41 +1071,37 @@ async fn load_joined_room( }) } -fn load_timeline( +async fn load_timeline( services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; let limited = if services .rooms .timeline - .last_timeline_count(sender_user, room_id)? + .last_timeline_count(sender_user, room_id) + .await? > roomsincecount { let mut non_timeline_pdus = services .rooms .timeline - .pdus_until(sender_user, room_id, PduCount::max())? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .take_while(|(pducount, _)| pducount > &roomsincecount); + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .ready_take_while(|(pducount, _)| pducount > &roomsincecount); // Take the last events for the timeline timeline_pdus = non_timeline_pdus .by_ref() .take(usize_from_u64_truncated(limit)) .collect::>() + .await .into_iter() .rev() .collect::>(); // They /sync response doesn't always return all messages, so we say the output // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().is_some() + non_timeline_pdus.next().await.is_some() } else { timeline_pdus = Vec::new(); false @@ -1065,26 +1109,23 @@ fn load_timeline( Ok((timeline_pdus, limited)) } -fn share_encrypted_room( - services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, -) -> Result { - Ok(services +async fn share_encrypted_room( + services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>, +) -> bool { + services .rooms .user - .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? - .filter_map(Result::ok) - .filter(|room_id| room_id != ignore_room) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? 
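`load_timeline` above takes its window with `by_ref().take(n)` and then probes the same stream once more: if `next().await` still yields an event, the response is flagged `limited`. Sketch over a synthetic descending stream; `ready_take_while` is the patch's helper for synchronous predicates, which plain `futures` expresses with `future::ready`:

```rust
use futures::{stream, StreamExt};

fn main() {
    futures::executor::block_on(async {
        let pdus = stream::iter([5_u64, 4, 3, 2, 1]);
        // Keep only events newer than the since-count (here: counts > 1).
        let mut recent = pdus.take_while(|&count| futures::future::ready(count > 1));
        // Borrow the stream so it can be probed again after the window.
        let window: Vec<u64> = recent.by_ref().take(2).collect().await;
        let limited = recent.next().await.is_some(); // more events remained
        assert_eq!(window, [5, 4]);
        assert!(limited);
    });
}
```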
- .is_some(), - ) + .get_shared_rooms(sender_user, user_id) + .ready_filter(|&room_id| Some(room_id) != ignore_room) + .any(|other_room_id| async move { + services + .rooms + .state_accessor + .room_state_get(other_room_id, &StateEventType::RoomEncryption, "") + .await + .is_ok() }) - .any(|encrypted| encrypted)) + .await } /// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` @@ -1114,7 +1155,7 @@ pub(crate) async fn sync_events_v4_route( if globalsince != 0 && !services - .users + .sync .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) { debug!("Restarting sync stream because it was gone from the database"); @@ -1127,41 +1168,43 @@ pub(crate) async fn sync_events_v4_route( if globalsince == 0 { services - .users + .sync .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); } // Get sticky parameters from cache let known_rooms = services - .users + .sync .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); let all_joined_rooms = services .rooms .state_cache .rooms_joined(&sender_user) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect::>() + .await; let all_invited_rooms = services .rooms .state_cache .rooms_invited(&sender_user) - .filter_map(Result::ok) .map(|r| r.0) - .collect::>(); + .collect::>() + .await; let all_rooms = all_joined_rooms .iter() - .cloned() - .chain(all_invited_rooms.clone()) + .chain(all_invited_rooms.iter()) + .map(Clone::clone) .collect(); if body.extensions.to_device.enabled.unwrap_or(false) { services .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; + .remove_to_device_events(&sender_user, &sender_device, globalsince) + .await; } let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in @@ -1179,7 +1222,8 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.account_data.enabled.unwrap_or(false) { account_data.global = services .account_data - .changes_since(None, &sender_user, globalsince)? + .changes_since(None, &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(); @@ -1190,7 +1234,8 @@ pub(crate) async fn sync_events_v4_route( room.clone(), services .account_data - .changes_since(Some(&room), &sender_user, globalsince)? + .changes_since(Some(&room), &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1205,40 +1250,42 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(sender_user.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); for room_id in &all_joined_rooms { - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); + let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { + error!("Room {room_id} has no state"); continue; }; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, globalsince)?; + .get_token_shortstatehash(room_id, globalsince) + .await + .ok(); - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? 
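`share_encrypted_room` above becomes a short-circuiting `StreamExt::any` over the shared-rooms stream, filtering out the ignored room first. A standalone sketch with synthetic room ids and an async check standing in for the `m.room.encryption` state lookup:

```rust
use futures::{stream, StreamExt};

// Stand-in for the awaited room-state lookup.
async fn is_encrypted(room: &str) -> bool {
    room == "!enc:example.org"
}

fn main() {
    futures::executor::block_on(async {
        let mut shared_rooms = stream::iter(["!plain:example.org", "!enc:example.org"])
            // ready_filter in the patch; plain futures wraps the bool.
            .filter(|&room| futures::future::ready(room != "!ignored:example.org"));
        let any_encrypted = shared_rooms
            .any(|room| async move { is_encrypted(room).await })
            .await;
        assert!(any_encrypted);
    });
}
```

`any` stops polling as soon as one room passes the check, so the remaining state lookups are never issued.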
- .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; let encrypted_room = services .rooms .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); if let Some(since_shortstatehash) = since_shortstatehash { // Skip if there are only timeline changes @@ -1246,22 +1293,24 @@ pub(crate) async fn sync_events_v4_route( continue; } - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; let joined_since_last_sync = since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + if encrypted_room { let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; + let since_state_ids = services .rooms .state_accessor @@ -1270,8 +1319,8 @@ pub(crate) async fn sync_events_v4_route( for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; if pdu.kind == TimelineEventType::RoomMember { @@ -1291,7 +1340,9 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&services, &sender_user, &user_id, room_id)? 
{ + if !share_encrypted_room(&services, &sender_user, &user_id, Some(room_id)) + .await + { device_list_changes.insert(user_id); } }, @@ -1306,22 +1357,25 @@ pub(crate) async fn sync_events_v4_route( } } if joined_since_last_sync || new_encrypted_room { + let sender_user = &sender_user; // If the user is in a new encrypted room, give them all joined users device_list_changes.extend( services .rooms .state_cache .room_members(room_id) - .flatten() - .filter(|user_id| { + .ready_filter(|user_id| { // Don't send key updates from the sender to the sender - &sender_user != user_id + sender_user != user_id }) - .filter(|user_id| { + .filter_map(|user_id| async move { // Only send keys if the sender doesn't share an encrypted room with the target // already - !share_encrypted_room(&services, &sender_user, user_id, room_id).unwrap_or(false) - }), + (!share_encrypted_room(&services, sender_user, user_id, Some(room_id)).await) + .then_some(user_id.to_owned()) + }) + .collect::>() + .await, ); } } @@ -1331,26 +1385,15 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(room_id.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -1362,7 +1405,7 @@ pub(crate) async fn sync_events_v4_route( let mut lists = BTreeMap::new(); let mut todo_rooms = BTreeMap::new(); // and required state - for (list_id, list) in body.lists { + for (list_id, list) in &body.lists { let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { Some(true) => &all_invited_rooms, Some(false) => &all_joined_rooms, @@ -1371,23 +1414,23 @@ pub(crate) async fn sync_events_v4_route( let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(active_rooms, State(services), &value, true), + Some(value) => filter_rooms(active_rooms, State(services), &value, true).await, None => active_rooms.clone(), }; let active_rooms = match list.filters.clone().map(|f| f.room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(&active_rooms, State(services), &value, false), + Some(value) => filter_rooms(&active_rooms, State(services), &value, false).await, None => active_rooms, }; let mut new_known_rooms = BTreeSet::new(); + let ranges = list.ranges.clone(); lists.insert( list_id.clone(), sync_events::v4::SyncList { - ops: list - .ranges + ops: ranges .into_iter() .map(|mut r| { r.0 = r.0.clamp( @@ -1396,29 +1439,34 @@ pub(crate) async fn sync_events_v4_route( ); r.1 = r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); + let room_ids = if !active_rooms.is_empty() { active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() } else { Vec::new() }; + new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { let todo_room = 
todo_rooms .entry(room_id.clone()) .or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = list .room_details .timeline_limit .map_or(10, u64::from) .min(100); + todo_room .0 .extend(list.room_details.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( known_rooms - .get(&list_id) + .get(list_id.as_str()) .and_then(|k| k.get(room_id)) .copied() .unwrap_or(0), @@ -1438,11 +1486,11 @@ pub(crate) async fn sync_events_v4_route( ); if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), - list_id, + list_id.clone(), new_known_rooms, globalsince, ); @@ -1451,7 +1499,7 @@ pub(crate) async fn sync_events_v4_route( let mut known_subscription_rooms = BTreeSet::new(); for (room_id, room) in &body.room_subscriptions { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { continue; } let todo_room = todo_rooms @@ -1477,7 +1525,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1488,7 +1536,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_subscriptions( + services.sync.update_sync_subscriptions( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1509,12 +1557,13 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_cache .invite_state(&sender_user, room_id) - .unwrap_or(None); + .await + .ok(); (timeline_pdus, limited) = (Vec::new(), true); } else { (timeline_pdus, limited) = - match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit) { + match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit).await { Ok(value) => value, Err(err) => { warn!("Encountered missing timeline in {}, error {}", room_id, err); @@ -1527,17 +1576,20 @@ pub(crate) async fn sync_events_v4_route( room_id.clone(), services .account_data - .changes_since(Some(room_id), &sender_user, *roomsince)? + .changes_since(Some(room_id), &sender_user, *roomsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), ); - let room_receipts = services + let vector: Vec<_> = services .rooms .read_receipt - .readreceipts_since(room_id, *roomsince); - let vector: Vec<_> = room_receipts.into_iter().collect(); + .readreceipts_since(room_id, *roomsince) + .collect() + .await; + let receipt_size = vector.len(); receipts .rooms @@ -1584,41 +1636,42 @@ pub(crate) async fn sync_events_v4_route( let required_state = required_state_request .iter() - .map(|state| { + .stream() + .filter_map(|state| async move { services .rooms .state_accessor .room_state_get(room_id, &state.0, &state.1) + .await + .map(|s| s.to_sync_state_event()) + .ok() }) - .filter_map(Result::ok) - .flatten() - .map(|state| state.to_sync_state_event()) - .collect(); + .collect() + .await; // Heroes let heroes = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member != &sender_user) - .map(|member| { - Ok::<_, Error>( - services - .rooms - .state_accessor - .get_member(room_id, &member)? 
- .map(|memberevent| SlidingSyncRoomHero { - user_id: member, - name: memberevent.displayname, - avatar: memberevent.avatar_url, - }), - ) + .ready_filter(|member| member != &sender_user) + .filter_map(|member| async move { + services + .rooms + .state_accessor + .get_member(room_id, member) + .await + .map(|memberevent| SlidingSyncRoomHero { + user_id: member.to_owned(), + name: memberevent.displayname, + avatar: memberevent.avatar_url, + }) + .ok() }) - .filter_map(Result::ok) - .flatten() .take(5) - .collect::>(); + .collect::>() + .await; + let name = match heroes.len().cmp(&(1_usize)) { Ordering::Greater => { let firsts = heroes[1..] @@ -1626,10 +1679,12 @@ pub(crate) async fn sync_events_v4_route( .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) .collect::>() .join(", "); + let last = heroes[0] .name .clone() .unwrap_or_else(|| heroes[0].user_id.to_string()); + Some(format!("{firsts} and {last}")) }, Ordering::Equal => Some( @@ -1650,11 +1705,17 @@ pub(crate) async fn sync_events_v4_route( rooms.insert( room_id.clone(), sync_events::v4::SlidingSyncRoom { - name: services.rooms.state_accessor.get_name(room_id)?.or(name), + name: services + .rooms + .state_accessor + .get_name(room_id) + .await + .ok() + .or(name), avatar: if let Some(heroes_avatar) = heroes_avatar { ruma::JsOption::Some(heroes_avatar) } else { - match services.rooms.state_accessor.get_avatar(room_id)? { + match services.rooms.state_accessor.get_avatar(room_id).await { ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), ruma::JsOption::Null => ruma::JsOption::Null, ruma::JsOption::Undefined => ruma::JsOption::Undefined, @@ -1668,7 +1729,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .highlight_count(&sender_user, room_id)? + .highlight_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1676,7 +1738,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .notification_count(&sender_user, room_id)? + .notification_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1689,7 +1752,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_joined_count(room_id)? + .room_joined_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1698,7 +1762,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_invited_count(room_id)? 
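The hero-based fallback name above joins all heroes after the first and appends the first, branching on `len().cmp(&1)`. The same formatting over plain strings (the real code reads `SlidingSyncRoomHero` display names with user-id fallbacks):

```rust
use std::cmp::Ordering;

fn fallback_name(heroes: &[&str]) -> Option<String> {
    match heroes.len().cmp(&1) {
        Ordering::Greater => {
            let firsts = heroes[1..].join(", ");
            Some(format!("{firsts} and {}", heroes[0]))
        },
        Ordering::Equal => Some(heroes[0].to_owned()),
        Ordering::Less => None, // no heroes; fall back to m.room.name
    }
}

fn main() {
    assert_eq!(
        fallback_name(&["@a:x", "@b:x", "@c:x"]).as_deref(),
        Some("@b:x, @c:x and @a:x")
    );
    assert_eq!(fallback_name(&["@a:x"]).as_deref(), Some("@a:x"));
    assert_eq!(fallback_name(&[]), None);
}
```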
+ .room_invited_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1732,7 +1797,9 @@ pub(crate) async fn sync_events_v4_route( Some(sync_events::v4::ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, next_batch: next_batch.to_string(), }) } else { @@ -1745,7 +1812,8 @@ pub(crate) async fn sync_events_v4_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, // Fallback keys are not yet supported device_unused_fallback_key_types: None, }, @@ -1759,25 +1827,26 @@ pub(crate) async fn sync_events_v4_route( }) } -fn filter_rooms( +async fn filter_rooms( rooms: &[OwnedRoomId], State(services): State, filter: &[RoomTypeFilter], negate: bool, ) -> Vec { - return rooms + rooms .iter() - .filter(|r| match services.rooms.state_accessor.get_room_type(r) { - Err(e) => { - warn!("Requested room type for {}, but could not retrieve with error {}", r, e); - false - }, - Ok(result) => { - let result = RoomTypeFilter::from(result); - if negate { - !filter.contains(&result) - } else { - filter.is_empty() || filter.contains(&result) - } - }, + .stream() + .filter_map(|r| async move { + match services.rooms.state_accessor.get_room_type(r).await { + Err(_) => false, + Ok(result) => { + let result = RoomTypeFilter::from(Some(result)); + if negate { + !filter.contains(&result) + } else { + filter.is_empty() || filter.contains(&result) + } + }, + } + .then_some(r.to_owned()) }) - .cloned() - .collect(); + .collect() + .await } diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 301568e5..bcd0f817 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -23,10 +23,11 @@ pub(crate) async fn update_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -41,12 +42,15 @@ pub(crate) async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(create_tag::v3::Response {}) } @@ -63,10 +67,11 @@ pub(crate) async fn delete_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -78,12 +83,15 @@ pub(crate) async fn delete_tag_route( tags_event.content.tags.remove(&body.tag.clone().into()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + 
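`filter_rooms` above lifts the slice iterator into a stream (the patch's `IterStream::stream`; plain `futures` spells it `stream::iter`) so each `get_room_type` lookup can be awaited inside `filter_map`. A sketch under that assumption, with a hypothetical `room_type` lookup:

```rust
use futures::{stream, StreamExt};

// Stand-in for the awaited per-room type lookup.
async fn room_type(room: &str) -> Option<&'static str> {
    (room == "!space:x").then_some("m.space")
}

fn main() {
    let spaces: Vec<String> = futures::executor::block_on(
        stream::iter(["!space:x", "!chat:x"])
            .filter_map(|r| async move {
                matches!(room_type(r).await, Some("m.space")).then(|| r.to_owned())
            })
            .collect(),
    );
    assert_eq!(spaces, ["!space:x"]);
}
```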
&serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(delete_tag::v3::Response {}) } @@ -100,10 +108,11 @@ pub(crate) async fn get_tags_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 8100f0e6..50f6cdfb 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::PduEvent; +use futures::StreamExt; use ruma::{ api::client::{error::ErrorKind, threads::get_threads}, uint, @@ -27,20 +29,23 @@ pub(crate) async fn get_threads_route( u64::MAX }; - let threads = services + let room_id = &body.room_id; + let threads: Vec<(u64, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include)? + .threads_until(sender_user, &body.room_id, from, &body.include) + .await? .take(limit) - .filter_map(Result::ok) - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect::>(); + .collect() + .await; let next_batch = threads.last().map(|(count, _)| count.to_string()); diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 1f557ad7..2b37a9ec 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route( // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? 
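`get_threads_route` above pages with `take(limit)` and derives `next_batch` from the count of the last returned thread; the client echoes that token back as `from` on the next request. Sketch with synthetic `(count, event)` pairs:

```rust
use futures::{stream, StreamExt};

fn main() {
    futures::executor::block_on(async {
        // Threads arrive newest-first as (count, event_id) pairs.
        let threads: Vec<(u64, &str)> = stream::iter([(30_u64, "$c"), (20, "$b"), (10, "$a")])
            .take(2)
            .collect()
            .await;
        // The last delivered count becomes the pagination token.
        let next_batch = threads.last().map(|(count, _)| count.to_string());
        assert_eq!(next_batch.as_deref(), Some("20"));
    });
}
```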
- .is_some() + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await + .is_ok() { return Ok(send_event_to_device::v3::Response {}); } @@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route( continue; } + let event_type = &body.event_type.to_string(); + + let event = event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?; + match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; + services + .users + .add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id?, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; - } + let (event_type, event) = (&event_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + event_type, + event.clone(), + ) + }) + .await; }, } } @@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route( // Save transaction id with empty data services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + .add_txnid(sender_user, sender_device, &body.txn_id, &[]); Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client/typing.rs b/src/api/client/typing.rs index a06648e0..932d221e 100644 --- a/src/api/client/typing.rs +++ b/src/api/client/typing.rs @@ -16,7 +16,8 @@ pub(crate) async fn create_typing_event_route( if !services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? + .is_joined(sender_user, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room.")); } diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index ab4703fd..dc570295 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -2,7 +2,8 @@ use std::collections::BTreeMap; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{warn, Err}; +use conduit::Err; +use futures::StreamExt; use ruma::{ api::{ client::{ @@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route( )); } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { return Ok(mutual_rooms::unstable::Response { joined: vec![], next_batch_token: None, @@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route( let mutual_rooms: Vec = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - .filter_map(Result::ok) - .collect(); + .get_shared_rooms(sender_user, &body.user_id) + .map(ToOwned::to_owned) + .collect() + .await; Ok(mutual_rooms::unstable::Response { joined: mutual_rooms, @@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary( let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; - if !services.rooms.metadata.exists(&room_id)? 
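
The AllDevices arm above now drives the device-id stream with StreamExt::for_each, awaiting one insertion per device, instead of looping over a Result-yielding iterator and re-deserializing the event each time. A small sketch of that fan-out, assuming futures and tokio; deliver() and the device list are illustrative stand-ins for users.add_to_device_event() and all_device_ids():

    use futures::{stream, StreamExt};

    // Hypothetical async write standing in for users.add_to_device_event().
    async fn deliver(device: &str, event: &str) {
        println!("queue {event:?} for {device}");
    }

    #[tokio::main]
    async fn main() {
        let devices = ["DEV1", "DEV2", "DEV3"];
        let event = "m.room_key";

        // for_each awaits one future per item, in order, borrowing shared state.
        stream::iter(devices)
            .for_each(|device| deliver(device, event))
            .await;
    }
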
{ + if !services.rooms.metadata.exists(&room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } @@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_canonical_alias(&room_id) - .unwrap_or(None), + .await + .ok(), avatar_url: services .rooms .state_accessor - .get_avatar(&room_id)? + .get_avatar(&room_id) + .await .into_option() .unwrap_or_default() .url, - guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, - name: services - .rooms - .state_accessor - .get_name(&room_id) - .unwrap_or(None), + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), num_joined_members: services .rooms .state_cache .room_joined_count(&room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) - .try_into() - .expect("user count should not be that big"), + .await + .unwrap_or(0) + .try_into()?, topic: services .rooms .state_accessor .get_room_topic(&room_id) - .unwrap_or(None), + .await + .ok(), world_readable: services .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false), - join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, - room_type: services.rooms.state_accessor.get_room_type(&room_id)?, - room_version: Some(services.rooms.state.get_room_version(&room_id)?), + .await, + join_rule: services + .rooms + .state_accessor + .get_join_rule(&room_id) + .await + .unwrap_or_default() + .0, + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_version: services.rooms.state.get_room_version(&room_id).await.ok(), membership: if let Some(sender_user) = sender_user { services .rooms .state_accessor - .get_member(&room_id, sender_user)? 
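
The room-summary hunks above repeat one conversion over and over: the Ruma response fields are Option, while the de-wrapped getters now return Result, so the old `.unwrap_or(None)` over Result<Option<_>> collapses into `.await.ok()`. In miniature; get_name() is a stand-in for the state_accessor getters:

    // Toy Result-returning getter; an error now just means "no value".
    fn get_name(exists: bool) -> Result<String, &'static str> {
        exists.then(|| "Room".to_string()).ok_or("state event not found")
    }

    fn main() {
        // Optional response field: a plain `.ok()` over Result<_>,
        // as in the summary hunks above.
        let name: Option<String> = get_name(false).ok();
        assert!(name.is_none());
    }
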
- .map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership)) + .get_member(&room_id, sender_user) + .await + .map_or_else(|_| MembershipState::Leave, |content| content.membership) + .into() } else { None }, @@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_room_encryption(&room_id) - .unwrap_or_else(|_e| None), + .await + .ok(), }) } @@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services.users.set_timezone(&body.user_id, None).await?; + services.users.set_timezone(&body.user_id, None); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_timezone_key::unstable::Response {}) @@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services - .users - .set_timezone(&body.user_id, body.tz.clone()) - .await?; + services.users.set_timezone(&body.user_id, body.tz.clone()); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_timezone_key::unstable::Response {}) @@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?; } else if body.key == "avatar_url" { let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string()); @@ -291,21 +302,23 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone())); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_profile_key::unstable::Response {}) @@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, None, all_joined_rooms).await?; + update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?; } else if body.key == "avatar_url" { let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?; } else { services .users - 
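
Another recurring shape in the profile hunks above: rooms_joined() now yields borrowed items from the stream, so call sites collect them into owned values (`.map(Into::into).collect().await`) and then hand helpers a slice (&all_joined_rooms) instead of moving the Vec. The same shape with plain strings, assuming futures and tokio:

    use futures::{stream, StreamExt};

    // Borrowed items out of a stream, owned items into the Vec.
    async fn collect_owned(names: &[&str]) -> Vec<String> {
        stream::iter(names.iter().copied())
            .map(ToOwned::to_owned) // &str -> String, like &RoomId -> OwnedRoomId
            .collect()
            .await
    }

    // Downstream helpers borrow a slice instead of consuming the Vec.
    fn update_all(rooms: &[String]) -> usize { rooms.len() }

    #[tokio::main]
    async fn main() {
        let rooms = collect_owned(&["!a:x", "!b:x"]).await;
        assert_eq!(update_all(&rooms), 2);
    }
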
.set_profile_key(&body.user_id, &body.key, None)?; + .set_profile_key(&body.user_id, &body.key, None); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_profile_key::unstable::Response {}) @@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); return Ok(get_timezone_key::unstable::Response { tz: response.tz, @@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_timezone_key::unstable::Response { - tz: services.users.timezone(&body.user_id)?, + tz: services.users.timezone(&body.user_id).await.ok(), }) } @@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); if let Some(value) = response.custom_profile_fields.get(&body.key) { profile_key_value.insert(body.key.clone(), value.clone()); services .users - .set_profile_key(&body.user_id, &body.key, Some(value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(value.clone())); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); } @@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + return Err!(Request(NotFound("Profile was not found."))); } - if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? 
{ + if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await { profile_key_value.insert(body.key.clone(), value); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index d714fda5..d5bb14e5 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use axum::{extract::State, response::IntoResponse, Json}; +use futures::StreamExt; use ruma::api::client::{ discovery::{ discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, @@ -173,7 +174,7 @@ pub(crate) async fn conduwuit_server_version() -> Result { /// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// only includes active users (not deactivated, no guests, etc) pub(crate) async fn conduwuit_local_user_count(State(services): State) -> Result { - let user_count = services.users.list_local_users()?.len(); + let user_count = services.users.list_local_users().count().await; Ok(Json(serde_json::json!({ "count": user_count diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 87d4062c..8ea7f1b8 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -21,14 +22,12 @@ pub(crate) async fn search_users_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 - let mut users = services.users.iter().filter_map(|user_id| { + let users = services.users.stream().filter_map(|user_id| async { // Filter out buggy users (they should not exist, but you never know...) - let user_id = user_id.ok()?; - let user = search_users::v3::User { - user_id: user_id.clone(), - display_name: services.users.displayname(&user_id).ok()?, - avatar_url: services.users.avatar_url(&user_id).ok()?, + user_id: user_id.to_owned(), + display_name: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), }; let user_id_matches = user @@ -56,20 +55,19 @@ pub(crate) async fn search_users_route( let user_is_in_public_rooms = services .rooms .state_cache - .rooms_joined(&user_id) - .filter_map(Result::ok) - .any(|room| { + .rooms_joined(&user.user_id) + .any(|room| async move { services .rooms .state_accessor - .room_state_get(&room, &StateEventType::RoomJoinRules, "") + .room_state_get(room, &StateEventType::RoomJoinRules, "") + .await .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) - }) + serde_json::from_str(event.content.get()) + .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) }) - }); + }) + .await; if user_is_in_public_rooms { user_visible = true; @@ -77,25 +75,22 @@ pub(crate) async fn search_users_route( let user_is_in_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), user_id]) - .ok()? 
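
With list_local_users() turned into a stream above, counting no longer materializes a Vec just to take .len(): StreamExt::count() drives the stream to completion and tallies items as they pass. For example, assuming futures and tokio:

    use futures::{stream, StreamExt};

    #[tokio::main]
    async fn main() {
        // A stream has no len(); count() consumes it item by item.
        let users = stream::iter(["@a:x", "@b:x", "@c:x"]);
        assert_eq!(users.count().await, 3);
    }
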
- .next() - .is_some(); + .has_shared_rooms(sender_user, &user.user_id) + .await; if user_is_in_shared_rooms { user_visible = true; } } - if !user_visible { - return None; - } - - Some(user) + user_visible.then_some(user) }); - let results = users.by_ref().take(limit).collect(); - let limited = users.next().is_some(); + pin_mut!(users); + + let limited = users.by_ref().next().await.is_some(); + + let results = users.take(limit).collect().await; Ok(search_users::v3::Response { results, diff --git a/src/api/router.rs b/src/api/router.rs index 4264e01d..c4275f05 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -22,101 +22,101 @@ use crate::{client, server}; pub fn build(router: Router, server: &Server) -> Router { let config = &server.config; let mut router = router - .ruma_route(client::get_timezone_key_route) - .ruma_route(client::get_profile_key_route) - .ruma_route(client::set_profile_key_route) - .ruma_route(client::delete_profile_key_route) - .ruma_route(client::set_timezone_key_route) - .ruma_route(client::delete_timezone_key_route) - .ruma_route(client::appservice_ping) - .ruma_route(client::get_supported_versions_route) - .ruma_route(client::get_register_available_route) - .ruma_route(client::register_route) - .ruma_route(client::get_login_types_route) - .ruma_route(client::login_route) - .ruma_route(client::whoami_route) - .ruma_route(client::logout_route) - .ruma_route(client::logout_all_route) - .ruma_route(client::change_password_route) - .ruma_route(client::deactivate_route) - .ruma_route(client::third_party_route) - .ruma_route(client::request_3pid_management_token_via_email_route) - .ruma_route(client::request_3pid_management_token_via_msisdn_route) - .ruma_route(client::check_registration_token_validity) - .ruma_route(client::get_capabilities_route) - .ruma_route(client::get_pushrules_all_route) - .ruma_route(client::set_pushrule_route) - .ruma_route(client::get_pushrule_route) - .ruma_route(client::set_pushrule_enabled_route) - .ruma_route(client::get_pushrule_enabled_route) - .ruma_route(client::get_pushrule_actions_route) - .ruma_route(client::set_pushrule_actions_route) - .ruma_route(client::delete_pushrule_route) - .ruma_route(client::get_room_event_route) - .ruma_route(client::get_room_aliases_route) - .ruma_route(client::get_filter_route) - .ruma_route(client::create_filter_route) - .ruma_route(client::create_openid_token_route) - .ruma_route(client::set_global_account_data_route) - .ruma_route(client::set_room_account_data_route) - .ruma_route(client::get_global_account_data_route) - .ruma_route(client::get_room_account_data_route) - .ruma_route(client::set_displayname_route) - .ruma_route(client::get_displayname_route) - .ruma_route(client::set_avatar_url_route) - .ruma_route(client::get_avatar_url_route) - .ruma_route(client::get_profile_route) - .ruma_route(client::set_presence_route) - .ruma_route(client::get_presence_route) - .ruma_route(client::upload_keys_route) - .ruma_route(client::get_keys_route) - .ruma_route(client::claim_keys_route) - .ruma_route(client::create_backup_version_route) - .ruma_route(client::update_backup_version_route) - .ruma_route(client::delete_backup_version_route) - .ruma_route(client::get_latest_backup_info_route) - .ruma_route(client::get_backup_info_route) - .ruma_route(client::add_backup_keys_route) - .ruma_route(client::add_backup_keys_for_room_route) - .ruma_route(client::add_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_for_room_route) - .ruma_route(client::delete_backup_keys_for_session_route) 
- .ruma_route(client::delete_backup_keys_route) - .ruma_route(client::get_backup_keys_for_room_route) - .ruma_route(client::get_backup_keys_for_session_route) - .ruma_route(client::get_backup_keys_route) - .ruma_route(client::set_read_marker_route) - .ruma_route(client::create_receipt_route) - .ruma_route(client::create_typing_event_route) - .ruma_route(client::create_room_route) - .ruma_route(client::redact_event_route) - .ruma_route(client::report_event_route) - .ruma_route(client::create_alias_route) - .ruma_route(client::delete_alias_route) - .ruma_route(client::get_alias_route) - .ruma_route(client::join_room_by_id_route) - .ruma_route(client::join_room_by_id_or_alias_route) - .ruma_route(client::joined_members_route) - .ruma_route(client::leave_room_route) - .ruma_route(client::forget_room_route) - .ruma_route(client::joined_rooms_route) - .ruma_route(client::kick_user_route) - .ruma_route(client::ban_user_route) - .ruma_route(client::unban_user_route) - .ruma_route(client::invite_user_route) - .ruma_route(client::set_room_visibility_route) - .ruma_route(client::get_room_visibility_route) - .ruma_route(client::get_public_rooms_route) - .ruma_route(client::get_public_rooms_filtered_route) - .ruma_route(client::search_users_route) - .ruma_route(client::get_member_events_route) - .ruma_route(client::get_protocols_route) + .ruma_route(&client::get_timezone_key_route) + .ruma_route(&client::get_profile_key_route) + .ruma_route(&client::set_profile_key_route) + .ruma_route(&client::delete_profile_key_route) + .ruma_route(&client::set_timezone_key_route) + .ruma_route(&client::delete_timezone_key_route) + .ruma_route(&client::appservice_ping) + .ruma_route(&client::get_supported_versions_route) + .ruma_route(&client::get_register_available_route) + .ruma_route(&client::register_route) + .ruma_route(&client::get_login_types_route) + .ruma_route(&client::login_route) + .ruma_route(&client::whoami_route) + .ruma_route(&client::logout_route) + .ruma_route(&client::logout_all_route) + .ruma_route(&client::change_password_route) + .ruma_route(&client::deactivate_route) + .ruma_route(&client::third_party_route) + .ruma_route(&client::request_3pid_management_token_via_email_route) + .ruma_route(&client::request_3pid_management_token_via_msisdn_route) + .ruma_route(&client::check_registration_token_validity) + .ruma_route(&client::get_capabilities_route) + .ruma_route(&client::get_pushrules_all_route) + .ruma_route(&client::set_pushrule_route) + .ruma_route(&client::get_pushrule_route) + .ruma_route(&client::set_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_actions_route) + .ruma_route(&client::set_pushrule_actions_route) + .ruma_route(&client::delete_pushrule_route) + .ruma_route(&client::get_room_event_route) + .ruma_route(&client::get_room_aliases_route) + .ruma_route(&client::get_filter_route) + .ruma_route(&client::create_filter_route) + .ruma_route(&client::create_openid_token_route) + .ruma_route(&client::set_global_account_data_route) + .ruma_route(&client::set_room_account_data_route) + .ruma_route(&client::get_global_account_data_route) + .ruma_route(&client::get_room_account_data_route) + .ruma_route(&client::set_displayname_route) + .ruma_route(&client::get_displayname_route) + .ruma_route(&client::set_avatar_url_route) + .ruma_route(&client::get_avatar_url_route) + .ruma_route(&client::get_profile_route) + .ruma_route(&client::set_presence_route) + .ruma_route(&client::get_presence_route) + 
.ruma_route(&client::upload_keys_route) + .ruma_route(&client::get_keys_route) + .ruma_route(&client::claim_keys_route) + .ruma_route(&client::create_backup_version_route) + .ruma_route(&client::update_backup_version_route) + .ruma_route(&client::delete_backup_version_route) + .ruma_route(&client::get_latest_backup_info_route) + .ruma_route(&client::get_backup_info_route) + .ruma_route(&client::add_backup_keys_route) + .ruma_route(&client::add_backup_keys_for_room_route) + .ruma_route(&client::add_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_for_room_route) + .ruma_route(&client::delete_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_route) + .ruma_route(&client::get_backup_keys_for_room_route) + .ruma_route(&client::get_backup_keys_for_session_route) + .ruma_route(&client::get_backup_keys_route) + .ruma_route(&client::set_read_marker_route) + .ruma_route(&client::create_receipt_route) + .ruma_route(&client::create_typing_event_route) + .ruma_route(&client::create_room_route) + .ruma_route(&client::redact_event_route) + .ruma_route(&client::report_event_route) + .ruma_route(&client::create_alias_route) + .ruma_route(&client::delete_alias_route) + .ruma_route(&client::get_alias_route) + .ruma_route(&client::join_room_by_id_route) + .ruma_route(&client::join_room_by_id_or_alias_route) + .ruma_route(&client::joined_members_route) + .ruma_route(&client::leave_room_route) + .ruma_route(&client::forget_room_route) + .ruma_route(&client::joined_rooms_route) + .ruma_route(&client::kick_user_route) + .ruma_route(&client::ban_user_route) + .ruma_route(&client::unban_user_route) + .ruma_route(&client::invite_user_route) + .ruma_route(&client::set_room_visibility_route) + .ruma_route(&client::get_room_visibility_route) + .ruma_route(&client::get_public_rooms_route) + .ruma_route(&client::get_public_rooms_filtered_route) + .ruma_route(&client::search_users_route) + .ruma_route(&client::get_member_events_route) + .ruma_route(&client::get_protocols_route) .route("/_matrix/client/unstable/thirdparty/protocols", get(client::get_protocols_route_unstable)) - .ruma_route(client::send_message_event_route) - .ruma_route(client::send_state_event_for_key_route) - .ruma_route(client::get_state_events_route) - .ruma_route(client::get_state_events_for_key_route) + .ruma_route(&client::send_message_event_route) + .ruma_route(&client::send_state_event_for_key_route) + .ruma_route(&client::get_state_events_route) + .ruma_route(&client::get_state_events_for_key_route) // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // share one Ruma request / response type pair with {get,send}_state_event_for_key_route .route( @@ -140,46 +140,46 @@ pub fn build(router: Router, server: &Server) -> Router { get(client::get_state_events_for_empty_key_route) .put(client::send_state_event_for_empty_key_route), ) - .ruma_route(client::sync_events_route) - .ruma_route(client::sync_events_v4_route) - .ruma_route(client::get_context_route) - .ruma_route(client::get_message_events_route) - .ruma_route(client::search_events_route) - .ruma_route(client::turn_server_route) - .ruma_route(client::send_event_to_device_route) - .ruma_route(client::create_content_route) - .ruma_route(client::get_content_thumbnail_route) - .ruma_route(client::get_content_route) - .ruma_route(client::get_content_as_filename_route) - .ruma_route(client::get_media_preview_route) - .ruma_route(client::get_media_config_route) - .ruma_route(client::get_devices_route) - 
.ruma_route(client::get_device_route) - .ruma_route(client::update_device_route) - .ruma_route(client::delete_device_route) - .ruma_route(client::delete_devices_route) - .ruma_route(client::get_tags_route) - .ruma_route(client::update_tag_route) - .ruma_route(client::delete_tag_route) - .ruma_route(client::upload_signing_keys_route) - .ruma_route(client::upload_signatures_route) - .ruma_route(client::get_key_changes_route) - .ruma_route(client::get_pushers_route) - .ruma_route(client::set_pushers_route) - .ruma_route(client::upgrade_room_route) - .ruma_route(client::get_threads_route) - .ruma_route(client::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client::get_relating_events_with_rel_type_route) - .ruma_route(client::get_relating_events_route) - .ruma_route(client::get_hierarchy_route) - .ruma_route(client::get_mutual_rooms_route) - .ruma_route(client::get_room_summary) + .ruma_route(&client::sync_events_route) + .ruma_route(&client::sync_events_v4_route) + .ruma_route(&client::get_context_route) + .ruma_route(&client::get_message_events_route) + .ruma_route(&client::search_events_route) + .ruma_route(&client::turn_server_route) + .ruma_route(&client::send_event_to_device_route) + .ruma_route(&client::create_content_route) + .ruma_route(&client::get_content_thumbnail_route) + .ruma_route(&client::get_content_route) + .ruma_route(&client::get_content_as_filename_route) + .ruma_route(&client::get_media_preview_route) + .ruma_route(&client::get_media_config_route) + .ruma_route(&client::get_devices_route) + .ruma_route(&client::get_device_route) + .ruma_route(&client::update_device_route) + .ruma_route(&client::delete_device_route) + .ruma_route(&client::delete_devices_route) + .ruma_route(&client::get_tags_route) + .ruma_route(&client::update_tag_route) + .ruma_route(&client::delete_tag_route) + .ruma_route(&client::upload_signing_keys_route) + .ruma_route(&client::upload_signatures_route) + .ruma_route(&client::get_key_changes_route) + .ruma_route(&client::get_pushers_route) + .ruma_route(&client::set_pushers_route) + .ruma_route(&client::upgrade_room_route) + .ruma_route(&client::get_threads_route) + .ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(&client::get_relating_events_with_rel_type_route) + .ruma_route(&client::get_relating_events_route) + .ruma_route(&client::get_hierarchy_route) + .ruma_route(&client::get_mutual_rooms_route) + .ruma_route(&client::get_room_summary) .route( "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", get(client::get_room_summary_legacy) ) - .ruma_route(client::well_known_support) - .ruma_route(client::well_known_client) + .ruma_route(&client::well_known_support) + .ruma_route(&client::well_known_client) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) @@ -187,35 +187,35 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_federation { router = router - .ruma_route(server::get_server_version_route) + .ruma_route(&server::get_server_version_route) .route("/_matrix/key/v2/server", get(server::get_server_keys_route)) .route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route)) - .ruma_route(server::get_public_rooms_route) - .ruma_route(server::get_public_rooms_filtered_route) - .ruma_route(server::send_transaction_message_route) - 
.ruma_route(server::get_event_route) - .ruma_route(server::get_backfill_route) - .ruma_route(server::get_missing_events_route) - .ruma_route(server::get_event_authorization_route) - .ruma_route(server::get_room_state_route) - .ruma_route(server::get_room_state_ids_route) - .ruma_route(server::create_leave_event_template_route) - .ruma_route(server::create_leave_event_v1_route) - .ruma_route(server::create_leave_event_v2_route) - .ruma_route(server::create_join_event_template_route) - .ruma_route(server::create_join_event_v1_route) - .ruma_route(server::create_join_event_v2_route) - .ruma_route(server::create_invite_route) - .ruma_route(server::get_devices_route) - .ruma_route(server::get_room_information_route) - .ruma_route(server::get_profile_information_route) - .ruma_route(server::get_keys_route) - .ruma_route(server::claim_keys_route) - .ruma_route(server::get_openid_userinfo_route) - .ruma_route(server::get_hierarchy_route) - .ruma_route(server::well_known_server) - .ruma_route(server::get_content_route) - .ruma_route(server::get_content_thumbnail_route) + .ruma_route(&server::get_public_rooms_route) + .ruma_route(&server::get_public_rooms_filtered_route) + .ruma_route(&server::send_transaction_message_route) + .ruma_route(&server::get_event_route) + .ruma_route(&server::get_backfill_route) + .ruma_route(&server::get_missing_events_route) + .ruma_route(&server::get_event_authorization_route) + .ruma_route(&server::get_room_state_route) + .ruma_route(&server::get_room_state_ids_route) + .ruma_route(&server::create_leave_event_template_route) + .ruma_route(&server::create_leave_event_v1_route) + .ruma_route(&server::create_leave_event_v2_route) + .ruma_route(&server::create_join_event_template_route) + .ruma_route(&server::create_join_event_v1_route) + .ruma_route(&server::create_join_event_v2_route) + .ruma_route(&server::create_invite_route) + .ruma_route(&server::get_devices_route) + .ruma_route(&server::get_room_information_route) + .ruma_route(&server::get_profile_information_route) + .ruma_route(&server::get_keys_route) + .ruma_route(&server::claim_keys_route) + .ruma_route(&server::get_openid_userinfo_route) + .ruma_route(&server::get_hierarchy_route) + .ruma_route(&server::well_known_server) + .ruma_route(&server::get_content_route) + .ruma_route(&server::get_content_thumbnail_route) .route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count)); } else { router = router @@ -227,11 +227,11 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_legacy_media { router = router - .ruma_route(client::get_media_config_legacy_route) - .ruma_route(client::get_media_preview_legacy_route) - .ruma_route(client::get_content_legacy_route) - .ruma_route(client::get_content_as_filename_legacy_route) - .ruma_route(client::get_content_thumbnail_legacy_route) + .ruma_route(&client::get_media_config_legacy_route) + .ruma_route(&client::get_media_preview_legacy_route) + .ruma_route(&client::get_content_legacy_route) + .ruma_route(&client::get_content_as_filename_legacy_route) + .ruma_route(&client::get_content_thumbnail_legacy_route) .route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route)) .route("/_matrix/media/v1/upload", post(client::create_content_legacy_route)) .route( diff --git a/src/api/router/args.rs b/src/api/router/args.rs index a3d09dff..7381a55f 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -10,7 +10,10 @@ use super::{auth, auth::Auth, request, request::Request}; use 
crate::{service::appservice::RegistrationInfo, State};
 
 /// Extractor for Ruma request structs
-pub(crate) struct Args<T> {
+pub(crate) struct Args<T>
+where
+	T: IncomingRequest + Send + Sync + 'static,
+{
 	/// Request struct body
 	pub(crate) body: T,
 
@@ -38,7 +41,7 @@ pub(crate) struct Args<T> {
 #[async_trait]
 impl<T> FromRequest<State> for Args<T>
 where
-	T: IncomingRequest,
+	T: IncomingRequest + Send + Sync + 'static,
 {
 	type Rejection = Error;
 
@@ -57,7 +60,10 @@ where
 	}
 }
 
-impl<T> Deref for Args<T> {
+impl<T> Deref for Args<T>
+where
+	T: IncomingRequest + Send + Sync + 'static,
+{
 	type Target = T;
 
 	fn deref(&self) -> &Self::Target { &self.body }
@@ -67,7 +73,7 @@ fn make_body<T>(
 	services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
 ) -> Result<T>
 where
-	T: IncomingRequest,
+	T: IncomingRequest + Send + Sync + 'static,
 {
 	let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
 		let user_id = auth.sender_user.clone().unwrap_or_else(|| {
@@ -77,15 +83,13 @@ where
 		let uiaa_request = json_body
 			.get("auth")
-			.and_then(|auth| auth.as_object())
+			.and_then(CanonicalJsonValue::as_object)
 			.and_then(|auth| auth.get("session"))
-			.and_then(|session| session.as_str())
+			.and_then(CanonicalJsonValue::as_str)
 			.and_then(|session| {
-				services.uiaa.get_uiaa_request(
-					&user_id,
-					&auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()),
-					session,
-				)
+				services
+					.uiaa
+					.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
 			});
 
 		if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs
index 670f72ba..8d76b4be 100644
--- a/src/api/router/auth.rs
+++ b/src/api/router/auth.rs
@@ -44,8 +44,8 @@ pub(super) async fn auth(
 	let token = if let Some(token) = token {
 		if let Some(reg_info) = services.appservice.find_from_token(token).await {
 			Token::Appservice(Box::new(reg_info))
-		} else if let Some((user_id, device_id)) = services.users.find_from_token(token)? {
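
The route registrations earlier now pass each handler as a &'static reference (ruma_route(&client::...)), and the RumaHandler changes just below are what make that work: the trait methods take &'static self, and the old FnOnce + Clone bound becomes a plain Fn, because the router stores the reference instead of cloning the handler into each route. A simplified sketch of registering by static reference, using only std; Handler and register() are illustrative stand-ins, not conduwuit's trait:

    // Handlers are stored as &'static references, so no Clone bound is needed.
    trait Handler {
        fn call(&'static self, input: &str) -> String;
    }

    impl<F> Handler for F
    where
        F: Fn(&str) -> String + Send + Sync + 'static,
    {
        fn call(&'static self, input: &str) -> String { self(input) }
    }

    fn register(handlers: &mut Vec<&'static dyn Handler>, h: &'static dyn Handler) {
        // The reference itself is 'static and Copy; nothing is cloned.
        handlers.push(h);
    }

    fn hello(input: &str) -> String { format!("hello {input}") }

    fn main() {
        let mut routes: Vec<&'static dyn Handler> = Vec::new();
        register(&mut routes, &hello); // fn items promote to &'static values
        assert_eq!(routes[0].call("world"), "hello world");
    }
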
-			Token::User((user_id, OwnedDeviceId::from(device_id)))
+		} else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await {
+			Token::User((user_id, device_id))
 		} else {
 			Token::Invalid
 		}
@@ -98,7 +98,7 @@ pub(super) async fn auth(
 				))
 			}
 		},
-		(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?),
+		(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?),
 		(AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => {
 			Ok(Auth {
 				origin: None,
@@ -150,7 +150,7 @@ pub(super) async fn auth(
 	}
 }
 
-fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
+async fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
 	let user_id = request
 		.query
 		.user_id
@@ -170,7 +170,7 @@ fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
diff --git a/src/api/router/handler.rs b/src/api/router/handler.rs
--- a/src/api/router/handler.rs
+++ b/src/api/router/handler.rs
+pub(in super::super) trait RumaHandler<T> {
+	fn add_route(&'static self, router: Router, path: &str) -> Router;
+	fn add_routes(&'static self, router: Router) -> Router;
+}
+
 pub(in super::super) trait RouterExt {
-	fn ruma_route<H, T>(self, handler: H) -> Self
+	fn ruma_route<H, T>(self, handler: &'static H) -> Self
 	where
 		H: RumaHandler<T>;
 }
 
 impl RouterExt for Router {
-	fn ruma_route<H, T>(self, handler: H) -> Self
+	fn ruma_route<H, T>(self, handler: &'static H) -> Self
 	where
 		H: RumaHandler<T>,
 	{
@@ -27,34 +31,28 @@ impl RouterExt for Router {
 	}
 }
 
-pub(in super::super) trait RumaHandler<T> {
-	fn add_routes(&self, router: Router) -> Router;
-
-	fn add_route(&self, router: Router, path: &str) -> Router;
-}
-
 macro_rules! ruma_handler {
 	( $($tx:ident),* $(,)? ) => {
 		#[allow(non_snake_case)]
-		impl<Req, Ret, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
+		impl<Req, Err, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
 		where
-			Req: IncomingRequest + Send + 'static,
-			Ret: IntoResponse,
-			Fut: Future<Output = Result<Ret, Error>> + Send,
-			Fun: FnOnce($($tx,)* Ruma<Req>,) -> Fut + Clone + Send + Sync + 'static,
-			$( $tx: FromRequestParts<State> + Send + 'static, )*
+			Fun: Fn($($tx,)* Ruma<Req>,) -> Fut + Send + Sync + 'static,
+			Fut: Future<Output = Result<Req::OutgoingResponse, Err>> + Send,
+			Req: IncomingRequest + Send + Sync,
+			Err: IntoResponse + Send,
+			<Req as IncomingRequest>::OutgoingResponse: Send,
+			$( $tx: FromRequestParts<State> + Send + Sync + 'static, )*
 		{
-			fn add_routes(&self, router: Router) -> Router {
+			fn add_routes(&'static self, router: Router) -> Router {
 				Req::METADATA
 					.history
 					.all_paths()
 					.fold(router, |router, path| self.add_route(router, path))
 			}
 
-			fn add_route(&self, router: Router, path: &str) -> Router {
-				let handle = self.clone();
+			fn add_route(&'static self, router: Router, path: &str) -> Router {
+				let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse);
 				let method = method_to_filter(&Req::METADATA.method);
-				let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) };
 				router.route(path, on(method, action))
 			}
 		}
diff --git a/src/api/router/response.rs b/src/api/router/response.rs
index 2aaa79fa..70bbb936 100644
--- a/src/api/router/response.rs
+++ b/src/api/router/response.rs
@@ -5,13 +5,18 @@ use http::StatusCode;
 use http_body_util::Full;
 use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse};
 
-pub(crate) struct RumaResponse<T>(pub(crate) T);
+pub(crate) struct RumaResponse<T>(pub(crate) T)
+where
+	T: OutgoingResponse;
 
 impl From<Error> for RumaResponse<UiaaResponse> {
 	fn from(t: Error) -> Self { Self(t.into()) }
 }
 
-impl<T> IntoResponse for RumaResponse<T> {
+impl<T> IntoResponse for RumaResponse<T>
+where
+	T: OutgoingResponse,
+{
 	fn into_response(self) -> Response {
 		self.0
 			.try_into_http_response::<BytesMut>()
diff --git
a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 1b665c19..2bbc95ca 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -1,9 +1,13 @@ +use std::cmp; + use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::backfill::get_backfill}, - uint, user_id, MilliSecondsSinceUnixEpoch, +use conduit::{ + is_equal_to, + utils::{IterStream, ReadyExt}, + Err, PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; use crate::Ruma; @@ -19,27 +23,35 @@ pub(crate) async fn get_backfill_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let until = body .v .iter() - .map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) - .filter_map(|r| r.ok().flatten()) - .max() - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; + .stream() + .filter_map(|event_id| { + services + .rooms + .timeline + .get_pdu_count(event_id) + .map(Result::ok) + }) + .ready_fold(PduCount::Backfilled(0), cmp::max) + .await; let limit = body .limit @@ -47,31 +59,37 @@ pub(crate) async fn get_backfill_route( .try_into() .expect("UInt could not be converted to usize"); - let all_events = services + let pdus = services .rooms .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? - .take(limit); + .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .await? 
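
The backfill `until` computation above replaces Iterator::max with a fold over a stream of PDU counts, using ReadyExt::ready_fold, the stream helper this patch adds for folding with a synchronous closure. The equivalent shape with plain StreamExt::fold, assuming futures and tokio:

    use std::cmp;
    use futures::{stream, StreamExt};

    #[tokio::main]
    async fn main() {
        // ready_fold has the same shape but skips allocating an async block
        // for a closure that never awaits.
        let max = stream::iter([3_u64, 9, 4])
            .fold(0_u64, |acc, n| async move { cmp::max(acc, n) })
            .await;
        assert_eq!(max, 9);
    }
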
+ .take(limit) + .filter_map(|(_, pdu)| async move { + if !services + .rooms + .state_accessor + .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) + .await + .is_ok_and(is_equal_to!(true)) + { + return None; + } - let events = all_events - .filter_map(Result::ok) - .filter(|(_, e)| { - matches!( - services - .rooms - .state_accessor - .server_can_see_event(origin, &e.room_id, &e.event_id,), - Ok(true), - ) + services + .rooms + .timeline + .get_pdu_json(&pdu.event_id) + .await + .ok() }) - .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(); + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_backfill::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus: events, + pdus, }) } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e11a01a2..e4eac794 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,9 +1,6 @@ use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::event::get_event}, - MilliSecondsSinceUnixEpoch, RoomId, -}; +use conduit::{err, Err, Result}; +use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; use crate::Ruma; @@ -21,34 +18,46 @@ pub(crate) async fn get_event_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event not found."))))?; let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database."))?; + .ok_or_else(|| err!(Database("Invalid event in database.")))?; let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; + <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - if !services.rooms.state_accessor.is_world_readable(room_id)? - && !services.rooms.state_cache.server_in_room(origin, room_id)? + if !services + .rooms + .state_accessor + .is_world_readable(room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } if !services .rooms .state_accessor - .server_can_see_event(origin, room_id, &body.event_id)? + .server_can_see_event(origin, room_id, &body.event_id) + .await? 
{ - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event.")); + return Err!(Request(Forbidden("Server is not allowed to see event."))); } Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: services.sending.convert_to_outgoing_federation_event(event), + pdu: services + .sending + .convert_to_outgoing_federation_event(event) + .await, }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 4b0f6bc0..6ec00b50 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, @@ -22,16 +23,18 @@ pub(crate) async fn get_event_authorization_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } @@ -39,8 +42,9 @@ pub(crate) async fn get_event_authorization_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; let room_id_str = event .get("room_id") @@ -50,16 +54,17 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .await? + .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index e2c3c93c..7ae0ff60 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -18,16 +18,18 @@ pub(crate) async fn get_missing_events_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? 
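
get_event_authorization_route() above now assembles its response in a single stream pipeline: event ids in, filter_map drops events that fail to load, then() maps each survivor through the async outgoing-event conversion, collect at the end. A self-contained sketch of the pipeline, assuming futures and tokio; fetch() and convert() are stand-ins for get_pdu_json() and convert_to_outgoing_federation_event():

    use futures::{stream, StreamExt};

    // Hypothetical async fetch: odd ids are "missing" and get dropped.
    async fn fetch(id: u64) -> Option<String> { (id % 2 == 0).then(|| format!("pdu-{id}")) }

    // Hypothetical async conversion applied to every surviving item.
    async fn convert(pdu: String) -> String { format!("outgoing({pdu})") }

    #[tokio::main]
    async fn main() {
        // filter_map drops misses; then() maps each hit through a second future.
        let out: Vec<String> = stream::iter([1_u64, 2, 3, 4])
            .filter_map(fetch)
            .then(convert)
            .collect()
            .await;
        assert_eq!(out, ["outgoing(pdu-2)", "outgoing(pdu-4)"]);
    }
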
+ .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); } @@ -43,7 +45,12 @@ pub(crate) async fn get_missing_events_route( let mut i: usize = 0; while i < queued_events.len() && events.len() < limit { - if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? { + if let Ok(pdu) = services + .rooms + .timeline + .get_pdu_json(&queued_events[i]) + .await + { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -64,7 +71,8 @@ pub(crate) async fn get_missing_events_route( if !services .rooms .state_accessor - .server_can_see_event(origin, &body.room_id, &queued_events[i])? + .server_can_see_event(origin, &body.room_id, &queued_events[i]) + .await? { i = i.saturating_add(1); continue; @@ -81,7 +89,12 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(services.sending.convert_to_outgoing_federation_event(pdu)); + events.push( + services + .sending + .convert_to_outgoing_federation_event(pdu) + .await, + ); } i = i.saturating_add(1); } diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 530ed145..002bd763 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -12,7 +12,7 @@ pub(crate) async fn get_hierarchy_route( ) -> Result { let origin = body.origin.as_ref().expect("server is authenticated"); - if services.rooms.metadata.exists(&body.room_id)? { + if services.rooms.metadata.exists(&body.room_id).await { services .rooms .spaces diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 688e026c..9968bdf7 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -24,7 +24,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .globals @@ -98,7 +99,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(invited_user.server_name(), &body.room_id)?; + .acl_check(invited_user.server_name(), &body.room_id) + .await?; ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -128,14 +130,14 @@ pub(crate) async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; - if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { + if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This room is banned on this homeserver.", )); } - if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? { + if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This server does not allow room invites.", @@ -159,22 +161,28 @@ pub(crate) async fn create_invite_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), &body.room_id)? 
+ .server_in_room(services.globals.server_name(), &body.room_id) + .await { - services.rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - body.via.clone(), - true, - )?; + services + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + &sender, + Some(invite_state), + body.via.clone(), + true, + ) + .await?; } Ok(create_invite::v2::Response { event: services .sending - .convert_to_outgoing_federation_event(signed_event), + .convert_to_outgoing_federation_event(signed_event) + .await, }) } diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index 021016be..ba081aad 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::utils::{IterStream, ReadyExt}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, events::{ @@ -24,7 +26,7 @@ use crate::{ pub(crate) async fn create_join_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -40,7 +42,8 @@ pub(crate) async fn create_join_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if services .globals @@ -73,7 +76,7 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -81,22 +84,24 @@ pub(crate) async fn create_join_event_template_route( .rooms .state_cache .is_left(&body.user_id, &body.room_id) - .unwrap_or(true)) - && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? + .await) + && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await? 
{ let auth_user = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(Result::ok) - .filter(|user| user.server_name() == services.globals.server_name()) - .find(|user| { + .ready_filter(|user| user.server_name() == services.globals.server_name()) + .filter(|user| { services .rooms .state_accessor .user_can_invite(&body.room_id, user, &body.user_id, &state_lock) - .unwrap_or(false) - }); + }) + .boxed() + .next() + .await + .map(ToOwned::to_owned); if auth_user.is_some() { auth_user @@ -110,7 +115,7 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { @@ -132,19 +137,23 @@ pub(crate) async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -161,7 +170,7 @@ pub(crate) async fn create_join_event_template_route( /// This doesn't check the current user's membership. This should be done /// externally, either by using the state cache or attempting to authorize the /// event. -pub(crate) fn user_can_perform_restricted_join( +pub(crate) async fn user_can_perform_restricted_join( services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result { use RoomVersionId::*; @@ -169,18 +178,15 @@ pub(crate) fn user_can_perform_restricted_join( let join_rules_event = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await; - let Some(join_rules_event_content) = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event in database: {e}"); - Error::bad_database("Invalid join rules event in database") - }) + let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event in database: {e}"); + Error::bad_database("Invalid join rules event in database") }) - .transpose()? 
- else { + }) else { return Ok(false); }; @@ -201,13 +207,10 @@ pub(crate) fn user_can_perform_restricted_join( None } }) - .any(|m| { - services - .rooms - .state_cache - .is_joined(user_id, &m.room_id) - .unwrap_or(false) - }) { + .stream() + .any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id)) + .await + { Ok(true) } else { Err(Error::BadRequest( diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 3eb0d77a..41ea1c80 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -18,7 +18,7 @@ use crate::{service::pdu::PduBuilder, Ruma}; pub(crate) async fn create_leave_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -34,9 +34,10 @@ pub(crate) async fn create_leave_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, @@ -50,19 +51,23 @@ pub(crate) async fn create_leave_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); diff --git a/src/api/server/openid.rs b/src/api/server/openid.rs index 6a1b99b7..9b54807a 100644 --- a/src/api/server/openid.rs +++ b/src/api/server/openid.rs @@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route( State(services): State, body: Ruma, ) -> Result { Ok(get_openid_userinfo::v1::Response::new( - services.users.find_from_openid_token(&body.access_token)?, + services + .users + .find_from_openid_token(&body.access_token) + .await?, )) } diff --git a/src/api/server/query.rs b/src/api/server/query.rs index c2b78bde..348b8c6e 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,7 +1,8 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{err, Error, Result}; +use futures::StreamExt; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route( let room_id = services .rooms .alias - .resolve_local_alias(&body.room_alias)? 
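
In user_can_perform_restricted_join() above, the allow-list probe becomes a stream: .stream().any(..) drives one async membership lookup per candidate room and short-circuits on the first room the user is joined to. In miniature, assuming futures and tokio; is_joined() stands in for state_cache.is_joined():

    use futures::{stream, StreamExt};

    // Hypothetical async membership probe.
    async fn is_joined(user: &str, room: &str) -> bool {
        user == "@u:x" && room == "!allowed:x"
    }

    #[tokio::main]
    async fn main() {
        let allow = ["!other:x", "!allowed:x"];

        // StreamExt::any resolves as soon as one future yields true.
        let ok = stream::iter(allow)
            .any(|room| is_joined("@u:x", room))
            .await;
        assert!(ok);
    }
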
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; + .resolve_local_alias(&body.room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room alias not found."))))?; let mut servers: Vec = services .rooms .state_cache .room_servers(&room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.sort_unstable(); servers.dedup(); @@ -82,30 +85,31 @@ pub(crate) async fn get_profile_information_route( match &body.field { Some(ProfileField::DisplayName) => { - displayname = services.users.displayname(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); }, Some(ProfileField::AvatarUrl) => { - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); }, Some(custom_field) => { - if let Some(value) = services + if let Ok(value) = services .users - .profile_key(&body.user_id, custom_field.as_str())? + .profile_key(&body.user_id, custom_field.as_str()) + .await { custom_profile_fields.insert(custom_field.to_string(), value); } }, None => { - displayname = services.users.displayname(&body.user_id)?; - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; - tz = services.users.timezone(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); + tz = services.users.timezone(&body.user_id).await.ok(); custom_profile_fields = services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(); + .collect() + .await; }, } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 15f82faa..bb424988 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -2,7 +2,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, debug_warn, err, trace, warn, Err}; +use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::error::ErrorKind, @@ -23,10 +24,13 @@ use tokio::sync::RwLock; use crate::{ services::Services, utils::{self}, - Error, Result, Ruma, + Ruma, }; -type ResolvedMap = BTreeMap>; +const PDU_LIMIT: usize = 50; +const EDU_LIMIT: usize = 100; + +type ResolvedMap = BTreeMap>; /// # `PUT /_matrix/federation/v1/send/{txnId}` /// @@ -44,12 +48,16 @@ pub(crate) async fn send_transaction_message_route( ))); } - if body.pdus.len() > 50_usize { - return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); + if body.pdus.len() > PDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction" + ))); } - if body.edus.len() > 100_usize { - return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); + if body.edus.len() > EDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {EDU_LIMIT} EDUs in one transaction" + ))); } let txn_start_time = Instant::now(); @@ -62,8 +70,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; - 
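// Two idioms from the hunks above, restated: the async getters now return
// Result, so call sites wanting the old Option shape convert at the edge
// with .await.ok(); and the Err! message appears to capture identifiers from
// the enclosing scope like format_args!, so the new limit constants
// interpolate directly:
//
//     let displayname = services.users.displayname(&body.user_id).await.ok();
//
//     if body.pdus.len() > PDU_LIMIT {
//         return Err!(Request(Forbidden(
//             "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
//         )));
//     }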
handle_edus(&services, &client, &body, origin).await?; + let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await; + handle_edus(&services, &client, &body, origin).await; debug!( pdus = ?body.pdus.len(), @@ -85,10 +93,10 @@ pub(crate) async fn send_transaction_message_route( async fn handle_pdus( services: &Services, _client: &IpAddr, body: &Ruma, origin: &ServerName, txn_start_time: &Instant, -) -> Result { +) -> ResolvedMap { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { - parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { + parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); @@ -151,38 +159,34 @@ async fn handle_pdus( } } - Ok(resolved_map) + resolved_map } async fn handle_edus( services: &Services, client: &IpAddr, body: &Ruma, origin: &ServerName, -) -> Result<()> { +) { for edu in body .edus .iter() .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { match edu { - Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, - Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, - Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, - Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?, - Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?, - Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, + Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await, + Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await, + Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, + Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await, + Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await, + Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await, Edu::_Custom(ref _custom) => { debug_warn!(?body.edus, "received custom/unknown EDU"); }, } } - - Ok(()) } -async fn handle_edu_presence( - services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent, -) -> Result<()> { +async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) { if !services.globals.allow_incoming_presence() { - return Ok(()); + return; } for update in presence.push { @@ -194,23 +198,24 @@ async fn handle_edu_presence( continue; } - services.presence.set_presence( - &update.user_id, - &update.presence, - Some(update.currently_active), - Some(update.last_active_ago), - update.status_msg.clone(), - )?; + services + .presence + .set_presence( + &update.user_id, + &update.presence, + Some(update.currently_active), + Some(update.last_active_ago), + update.status_msg.clone(), + ) + .await + .log_err() + .ok(); } - - Ok(()) } -async fn handle_edu_receipt( - services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent, -) -> Result<()> { +async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) { if !services.globals.allow_incoming_read_receipts() { - return Ok(()); + return; } for (room_id, 
room_updates) in receipt.receipts { @@ -218,6 +223,7 @@ async fn handle_edu_receipt( .rooms .event_handler .acl_check(origin, &room_id) + .await .is_err() { debug_warn!( @@ -240,8 +246,8 @@ async fn handle_edu_receipt( .rooms .state_cache .room_members(&room_id) - .filter_map(Result::ok) - .any(|member| member.server_name() == user_id.server_name()) + .ready_any(|member| member.server_name() == user_id.server_name()) + .await { for event_id in &user_updates.event_ids { let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); @@ -255,7 +261,8 @@ async fn handle_edu_receipt( services .rooms .read_receipt - .readreceipt_update(&user_id, &room_id, &event)?; + .readreceipt_update(&user_id, &room_id, &event) + .await; } } else { debug_warn!( @@ -266,15 +273,11 @@ async fn handle_edu_receipt( } } } - - Ok(()) } -async fn handle_edu_typing( - services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent, -) -> Result<()> { +async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) { if !services.globals.config.allow_incoming_typing { - return Ok(()); + return; } if typing.user_id.server_name() != origin { @@ -282,26 +285,28 @@ async fn handle_edu_typing( %typing.user_id, %origin, "received typing EDU for user not belonging to origin" ); - return Ok(()); + return; } if services .rooms .event_handler .acl_check(typing.user_id.server_name(), &typing.room_id) + .await .is_err() { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for ACL'd user's server" ); - return Ok(()); + return; } if services .rooms .state_cache - .is_joined(&typing.user_id, &typing.room_id)? + .is_joined(&typing.user_id, &typing.room_id) + .await { if typing.typing { let timeout = utils::millis_since_unix_epoch().saturating_add( @@ -315,28 +320,29 @@ async fn handle_edu_typing( .rooms .typing .typing_add(&typing.user_id, &typing.room_id, timeout) - .await?; + .await + .log_err() + .ok(); } else { services .rooms .typing .typing_remove(&typing.user_id, &typing.room_id) - .await?; + .await + .log_err() + .ok(); } } else { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for user not in room" ); - return Ok(()); } - - Ok(()) } async fn handle_edu_device_list_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, -) -> Result<()> { +) { let DeviceListUpdateContent { user_id, .. @@ -347,17 +353,15 @@ async fn handle_edu_device_list_update( %user_id, %origin, "received device list update EDU for user not belonging to origin" ); - return Ok(()); + return; } - services.users.mark_device_key_update(&user_id)?; - - Ok(()) + services.users.mark_device_key_update(&user_id).await; } async fn handle_edu_direct_to_device( services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, -) -> Result<()> { +) { let DirectDeviceContent { sender, ev_type, @@ -370,45 +374,52 @@ async fn handle_edu_direct_to_device( %sender, %origin, "received direct to device EDU for user not belonging to origin" ); - return Ok(()); + return; } // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(&sender, None, &message_id)? 
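// The EDU handlers above now return () instead of Result: a malformed or
// failing EDU presumably should not abort the whole /send transaction, so
// fallible calls are logged and discarded, as in the typing hunk above:
//
//     services
//         .rooms
//         .typing
//         .typing_add(&typing.user_id, &typing.room_id, timeout)
//         .await
//         .log_err() // LogErr: record the error
//         .ok(); // then drop it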
- .is_some() + .existing_txnid(&sender, None, &message_id) + .await + .is_ok() { - return Ok(()); + return; } for (target_user_id, map) in &messages { for (target_device_id_maybe, event) in map { + let Ok(event) = event + .deserialize_as() + .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}"))))) + else { + continue; + }; + + let ev_type = ev_type.to_string(); match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?, - )?; + services + .users + .add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, - )?; - } + let (sender, ev_type, event) = (&sender, &ev_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender, + target_user_id, + target_device_id, + ev_type, + event.clone(), + ) + }) + .await; }, } } @@ -417,14 +428,12 @@ async fn handle_edu_direct_to_device( // Save transaction id with empty data services .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - - Ok(()) + .add_txnid(&sender, None, &message_id, &[]); } async fn handle_edu_signing_key_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, -) -> Result<()> { +) { let SigningKeyUpdateContent { user_id, master_key, @@ -436,14 +445,15 @@ async fn handle_edu_signing_key_update( %user_id, %origin, "received signing key update EDU from server that does not belong to user's server" ); - return Ok(()); + return; } if let Some(master_key) = master_key { services .users - .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true) + .await + .log_err() + .ok(); } - - Ok(()) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c4d016f6..639fcafd 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -3,7 +3,8 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; +use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ @@ -22,27 +23,32 @@ use crate::Ruma; async fn create_join_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !services.rooms.metadata.exists(room_id)? 
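// Note on the AllDevices arm above: re-binding (sender, ev_type, event) as
// references keeps the owned values outside the FnMut closure handed to
// for_each, so it can run once per device while only `event` is cloned on
// each iteration.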
{ + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin server - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; // We need to return the state prior to joining, let's keep a reference to that // here let shortstatehash = services .rooms .state - .get_room_shortstatehash(room_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?; + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Request(NotFound("Event state not found."))))?; let pub_key_map = RwLock::new(BTreeMap::new()); // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json @@ -97,7 +103,8 @@ async fn create_join_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; // check if origin server is trying to send for another server if sender.server_name() != origin { @@ -126,7 +133,9 @@ async fn create_join_event( if content .join_authorized_via_users_server .is_some_and(|user| services.globals.user_is_local(&user)) - && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() + && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id) + .await + .unwrap_or_default() { ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -158,12 +167,14 @@ async fn create_join_event( .mutex_federation .lock(room_id) .await; + let pdu_id: Vec = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) .await? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + drop(mutex_lock); let state_ids = services @@ -171,29 +182,43 @@ async fn create_join_event( .state_accessor .state_full_ids(shortstatehash) .await?; - let auth_chain_ids = services + + let state = state_ids + .iter() + .try_stream() + .and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id)) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() + .await?; + + let auth_chain = services .rooms .auth_chain .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .await? 
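// The shape of the `state` pipeline above, annotated:
//
//     state_ids
//         .iter()
//         .try_stream() // IterStream: wrap each item in Ok
//         .and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id))
//         .and_then(|pdu| {
//             services
//                 .sending
//                 .convert_to_outgoing_federation_event(pdu)
//                 .map(Ok) // FutureExt::map adapts the infallible conversion into the TryStream
//         })
//         .try_collect() // short-circuits on the first database error
//         .await?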
+ .map(Ok) + .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, + state, // Event field is required if the room version supports restricted join rules. - event: Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ), + event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), }) } diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index e77c5d78..81f41af0 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -49,18 +49,22 @@ pub(crate) async fn create_leave_event_v2_route( async fn create_leave_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result<()> { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -114,7 +118,8 @@ async fn create_leave_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; if sender.server_name() != origin { return Err(Error::BadRequest( @@ -173,10 +178,9 @@ async fn create_leave_event( .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| !services.globals.server_is_ours(server)); + .ready_filter(|server| !services.globals.server_is_ours(server)); - services.sending.send_pdu_servers(servers, &pdu_id)?; + services.sending.send_pdu_servers(servers, &pdu_id).await?; Ok(()) } diff --git a/src/api/server/state.rs b/src/api/server/state.rs index d215236a..37a14a3f 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{Error, Result}; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; +use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; +use 
futures::{FutureExt, StreamExt, TryStreamExt}; +use ruma::api::federation::event::get_room_state; use crate::Ruma; @@ -17,56 +18,66 @@ pub(crate) async fn get_room_state_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; let pdus = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| { + .await + .log_err() + .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? + .values() + .try_stream() + .and_then(|id| services.rooms.timeline.get_pdu_json(id)) + .and_then(|pdu| { services .sending - .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + .convert_to_outgoing_federation_event(pdu) + .map(Ok) }) - .collect(); + .try_collect() + .await?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await? + .map(Ok) + .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| { - services - .rooms - .timeline - .get_pdu_json(&id) - .ok()? - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - }) - .collect(), + auth_chain, pdus, }) } diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index d22f2df4..95ca65aa 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,9 +1,11 @@ use std::sync::Arc; use axum::extract::State; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; +use conduit::{err, Err}; +use futures::StreamExt; +use ruma::api::federation::event::get_room_state_ids; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -17,31 +19,35 @@ pub(crate) async fn get_room_state_ids_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; let pdu_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? + .await + .map_err(|_| err!(Request(NotFound("State ids not found"))))? .into_values() .map(|id| (*id).to_owned()) .collect(); @@ -50,10 +56,13 @@ pub(crate) async fn get_room_state_ids_route( .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .await? + .map(|id| (*id).to_owned()) + .collect() + .await; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + auth_chain_ids, pdu_ids, }) } diff --git a/src/api/server/user.rs b/src/api/server/user.rs index e9a400a7..0718da58 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,5 +1,6 @@ use axum::extract::State; use conduit::{Error, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -28,41 +29,51 @@ pub(crate) async fn get_devices_route( let origin = body.origin.as_ref().expect("server is authenticated"); + let user_id = &body.user_id; Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), + user_id: user_id.clone(), stream_id: services .users - .get_devicelist_version(&body.user_id)? + .get_devicelist_version(user_id) + .await .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), + .try_into()?, devices: services .users - .all_devices_metadata(&body.user_id) - .filter_map(Result::ok) - .filter_map(|metadata| { - let device_id_string = metadata.device_id.as_str().to_owned(); + .all_devices_metadata(user_id) + .filter_map(|metadata| async move { + let device_id = metadata.device_id.clone(); + let device_id_clone = device_id.clone(); + let device_id_string = device_id.as_str().to_owned(); let device_display_name = if services.globals.allow_device_name_federation() { - metadata.display_name + metadata.display_name.clone() } else { Some(device_id_string) }; - Some(UserDevice { - keys: services - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name, - }) + + services + .users + .get_device_keys(user_id, &device_id_clone) + .map_ok(|keys| UserDevice { + device_id, + keys, + device_display_name, + }) + .map(Result::ok) + .await }) - .collect(), + .collect() + .await, master_key: services .users - .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_master_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), self_signing_key: services .users - .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), }) } diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 71364734..cb957bc9 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -67,6 +67,7 @@ ctor.workspace = true cyborgtime.workspace = true either.workspace = true figment.workspace = true +futures.workspace = true http-body-util.workspace = true http.workspace = true image.workspace = true diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 48b9b58f..79e3d5b4 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -86,7 +86,7 @@ pub enum Error { #[error("There was a problem with the '{0}' directive in your configuration: {1}")] 
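// A sketch of the async filter_map above, compressed (names as in the hunk;
// clone placement is illustrative): each device's metadata is joined against
// a fallible key lookup, TryFutureExt::map_ok shapes the success value, and
// map(Result::ok) drops devices that have no keys.
//
//     .all_devices_metadata(user_id)
//     .filter_map(|metadata| async move {
//         let device_id = metadata.device_id.clone();
//         services
//             .users
//             .get_device_keys(user_id, &metadata.device_id)
//             .map_ok(|keys| UserDevice {
//                 device_id,
//                 keys,
//                 device_display_name: metadata.display_name.clone(),
//             })
//             .map(Result::ok)
//             .await
//     })
//     .collect()
//     .await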
Config(&'static str, Cow<'static, str>), #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists + Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists #[error(transparent)] ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError), #[error("{0}")] @@ -107,6 +107,8 @@ pub enum Error { Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), #[error(transparent)] Ruma(#[from] ruma::api::client::error::Error), + #[error(transparent)] + StateRes(#[from] ruma::state_res::Error), #[error("uiaa")] Uiaa(ruma::api::client::uiaa::UiaaInfo), diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 439c831a..cf9ffe64 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -3,8 +3,6 @@ mod count; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; -pub use builder::PduBuilder; -pub use count::PduCount; use ruma::{ canonical_json::redact_content_in_place, events::{ @@ -23,7 +21,8 @@ use serde_json::{ value::{to_raw_value, RawValue as RawJsonValue}, }; -use crate::{err, warn, Error}; +pub use self::{builder::PduBuilder, count::PduCount}; +use crate::{err, warn, Error, Result}; #[derive(Deserialize)] struct ExtractRedactedBecause { @@ -65,11 +64,12 @@ pub struct PduEvent { impl PduEvent { #[tracing::instrument(skip(self), level = "debug")] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; @@ -98,31 +98,38 @@ impl PduEvent { unsigned.redacted_because.is_some() } - pub fn remove_transaction_id(&mut self) -> crate::Result<()> { - if let Some(unsigned) = &self.unsigned { - let mut unsigned: BTreeMap> = serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - } + pub fn remove_transaction_id(&mut self) -> Result<()> { + let Some(unsigned) = &self.unsigned else { + return Ok(()); + }; + + let mut unsigned: BTreeMap> = + serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + unsigned.remove("transaction_id"); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } - pub fn add_age(&mut self) -> crate::Result<()> { + pub fn add_age(&mut self) -> Result<()> { let mut unsigned: BTreeMap> = self .unsigned .as_ref() .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; // deliberately allowing for the possibility of negative age let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); let then: i128 = self.origin_server_ts.into(); let this_age = now.saturating_sub(then); - unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); + self.unsigned = 
to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } @@ -369,9 +376,9 @@ impl state_res::Event for PduEvent { fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - fn prev_events(&self) -> Box + '_> { Box::new(self.prev_events.iter()) } + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter() } - fn auth_events(&self) -> Box + '_> { Box::new(self.auth_events.iter()) } + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter() } fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } } @@ -395,7 +402,7 @@ impl Ord for PduEvent { /// CanonicalJsonValue>`. pub fn gen_event_id_canonical_json( pdu: &RawJsonValue, room_version_id: &RoomVersionId, -) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { +) -> Result<(OwnedEventId, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; diff --git a/src/core/result/log_debug_err.rs b/src/core/result/log_debug_err.rs index be2000ae..8835afd1 100644 --- a/src/core/result/log_debug_err.rs +++ b/src/core/result/log_debug_err.rs @@ -1,18 +1,14 @@ -use std::fmt; +use std::fmt::Debug; use tracing::Level; use super::{DebugInspect, Result}; use crate::error; -pub trait LogDebugErr -where - E: fmt::Debug, -{ +pub trait LogDebugErr { #[must_use] fn err_debug_log(self, level: Level) -> Self; - #[inline] #[must_use] fn log_debug_err(self) -> Self where @@ -22,15 +18,9 @@ where } } -impl LogDebugErr for Result -where - E: fmt::Debug, -{ +impl LogDebugErr for Result { #[inline] - fn err_debug_log(self, level: Level) -> Self - where - Self: Sized, - { + fn err_debug_log(self, level: Level) -> Self { self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) } } diff --git a/src/core/result/log_err.rs b/src/core/result/log_err.rs index 079571f5..374a5e59 100644 --- a/src/core/result/log_err.rs +++ b/src/core/result/log_err.rs @@ -1,18 +1,14 @@ -use std::fmt; +use std::fmt::Display; use tracing::Level; use super::Result; use crate::error; -pub trait LogErr -where - E: fmt::Display, -{ +pub trait LogErr { #[must_use] fn err_log(self, level: Level) -> Self; - #[inline] #[must_use] fn log_err(self) -> Self where @@ -22,15 +18,7 @@ where } } -impl LogErr for Result -where - E: fmt::Display, -{ +impl LogErr for Result { #[inline] - fn err_log(self, level: Level) -> Self - where - Self: Sized, - { - self.inspect_err(|error| error::inspect_log_level(&error, level)) - } + fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) } } diff --git a/src/core/utils/algorithm.rs b/src/core/utils/algorithm.rs deleted file mode 100644 index 9bc1bc8a..00000000 --- a/src/core/utils/algorithm.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::cmp::Ordering; - -#[allow(clippy::impl_trait_in_params)] -pub fn common_elements( - mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, -) -> Option>> { - let first_iterator = iterators.next()?; - let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); - - Some(first_iterator.filter(move |target| { - other_iterators.iter_mut().all(|it| { - while let Some(element) = it.peek() { - match check_order(element, target) { - Ordering::Greater => return false, // We went too far - Ordering::Equal => return true, // Element is in both iters - Ordering::Less => { - // Keep searching - it.next(); - }, - } - } - false - }) - 
})) -} diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 03b755e9..b1ea3709 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,4 +1,3 @@ -pub mod algorithm; pub mod bytes; pub mod content_disposition; pub mod debug; @@ -9,25 +8,30 @@ pub mod json; pub mod math; pub mod mutex_map; pub mod rand; +pub mod set; +pub mod stream; pub mod string; pub mod sys; mod tests; pub mod time; +pub use ::conduit_macros::implement; pub use ::ctor::{ctor, dtor}; -pub use algorithm::common_elements; -pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; -pub use conduit_macros::implement; -pub use debug::slice_truncated as debug_slice_truncated; -pub use hash::calculate_hash; -pub use html::Escape as HtmlEscape; -pub use json::{deserialize_from_str, to_canonical_object}; -pub use math::clamp; -pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; -pub use rand::string as random_string; -pub use string::{str_from_bytes, string_from_bytes}; -pub use sys::available_parallelism; -pub use time::now_millis as millis_since_unix_epoch; + +pub use self::{ + bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, + debug::slice_truncated as debug_slice_truncated, + hash::calculate_hash, + html::Escape as HtmlEscape, + json::{deserialize_from_str, to_canonical_object}, + math::clamp, + mutex_map::{Guard as MutexMapGuard, MutexMap}, + rand::string as random_string, + stream::{IterStream, ReadyExt, TryReadyExt}, + string::{str_from_bytes, string_from_bytes}, + sys::available_parallelism, + time::now_millis as millis_since_unix_epoch, +}; #[inline] pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs new file mode 100644 index 00000000..563f9df5 --- /dev/null +++ b/src/core/utils/set.rs @@ -0,0 +1,47 @@ +use std::cmp::{Eq, Ord}; + +use crate::{is_equal_to, is_less_than}; + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs do not have to +/// be sorted. If inputs are sorted a more optimized function is available in +/// this suite and should be used. +pub fn intersection(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Send, +{ + input.next().into_iter().flat_map(move |first| { + let input = input.clone(); + first.filter(move |targ| { + input + .clone() + .all(|mut other| other.any(is_equal_to!(*targ))) + }) + }) +} + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs must be sorted. 
+pub fn intersection_sorted(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Ord + Send, +{ + input.next().into_iter().flat_map(move |first| { + let mut input = input.clone().collect::>(); + first.filter(move |targ| { + input.iter_mut().all(|it| { + it.by_ref() + .skip_while(is_less_than!(targ)) + .peekable() + .peek() + .is_some_and(is_equal_to!(targ)) + }) + }) + }) +} diff --git a/src/core/utils/stream/cloned.rs b/src/core/utils/stream/cloned.rs new file mode 100644 index 00000000..d6a0e647 --- /dev/null +++ b/src/core/utils/stream/cloned.rs @@ -0,0 +1,20 @@ +use std::clone::Clone; + +use futures::{stream::Map, Stream, StreamExt}; + +pub trait Cloned<'a, T, S> +where + S: Stream, + T: Clone + 'a, +{ + fn cloned(self) -> Map T>; +} + +impl<'a, T, S> Cloned<'a, T, S> for S +where + S: Stream, + T: Clone + 'a, +{ + #[inline] + fn cloned(self) -> Map T> { self.map(Clone::clone) } +} diff --git a/src/core/utils/stream/expect.rs b/src/core/utils/stream/expect.rs new file mode 100644 index 00000000..3ab7181a --- /dev/null +++ b/src/core/utils/stream/expect.rs @@ -0,0 +1,17 @@ +use futures::{Stream, StreamExt, TryStream}; + +use crate::Result; + +pub trait TryExpect<'a, Item> { + fn expect_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> TryExpect<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, +{ + #[inline] + fn expect_ok(self: T) -> impl Stream + Send + 'a { + self.map(|res| res.expect("stream expectation failure")) + } +} diff --git a/src/core/utils/stream/ignore.rs b/src/core/utils/stream/ignore.rs new file mode 100644 index 00000000..997aa4ba --- /dev/null +++ b/src/core/utils/stream/ignore.rs @@ -0,0 +1,21 @@ +use futures::{future::ready, Stream, StreamExt, TryStream}; + +use crate::{Error, Result}; + +pub trait TryIgnore<'a, Item> { + fn ignore_err(self) -> impl Stream + Send + 'a; + + fn ignore_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> TryIgnore<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, + Item: Send + 'a, +{ + #[inline] + fn ignore_err(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.ok())) } + + #[inline] + fn ignore_ok(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.err())) } +} diff --git a/src/core/utils/stream/iter_stream.rs b/src/core/utils/stream/iter_stream.rs new file mode 100644 index 00000000..69edf64f --- /dev/null +++ b/src/core/utils/stream/iter_stream.rs @@ -0,0 +1,27 @@ +use futures::{ + stream, + stream::{Stream, TryStream}, + StreamExt, +}; + +pub trait IterStream { + /// Convert an Iterator into a Stream + fn stream(self) -> impl Stream::Item> + Send; + + /// Convert an Iterator into a TryStream + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send; +} + +impl IterStream for I +where + I: IntoIterator + Send, + ::IntoIter: Send, +{ + #[inline] + fn stream(self) -> impl Stream::Item> + Send { stream::iter(self) } + + #[inline] + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send { + self.stream().map(Ok) + } +} diff --git a/src/core/utils/stream/mod.rs b/src/core/utils/stream/mod.rs new file mode 100644 index 00000000..781bd522 --- /dev/null +++ b/src/core/utils/stream/mod.rs @@ -0,0 +1,13 @@ +mod cloned; +mod expect; +mod ignore; +mod iter_stream; +mod ready; +mod try_ready; + +pub use cloned::Cloned; +pub use expect::TryExpect; +pub use ignore::TryIgnore; +pub use iter_stream::IterStream; +pub use ready::ReadyExt; +pub use try_ready::TryReadyExt; 
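As a usage note for the suite re-exported above, a minimal sketch (the data and error variant are illustrative; ready_filter and ready_fold come from ready.rs below):

    use conduit::{err, utils::stream::{IterStream, ReadyExt, TryIgnore}};

    async fn demo() {
        let results: Vec<conduit::Result<i32>> = vec![Ok(1), Err(err!(Database("skipped"))), Ok(2), Ok(3)];

        let even_sum = results
            .into_iter()
            .stream() // IterStream: Iterator -> Stream
            .ignore_err() // TryIgnore: silently drop Err items
            .ready_filter(|n| n % 2 == 0) // ReadyExt: plain sync predicate
            .ready_fold(0, |acc, n| acc + n) // ReadyExt: synchronous fold
            .await;

        assert_eq!(even_sum, 2);
    }

Cloned and TryReadyExt::ready_and_then fill the same role for streams of references and for TryStreams, respectively.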
diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs new file mode 100644 index 00000000..13f730a7 --- /dev/null +++ b/src/core/utils/stream/ready.rs @@ -0,0 +1,109 @@ +//! Synchronous combinator extensions to futures::Stream + +use futures::{ + future::{ready, Ready}, + stream::{Any, Filter, FilterMap, Fold, ForEach, SkipWhile, Stream, StreamExt, TakeWhile}, +}; + +/// Synchronous combinators to augment futures::StreamExt. Most Stream +/// combinators take asynchronous arguments, but often only simple predicates +/// are required to steer a Stream like an Iterator. This suite provides a +/// convenience to reduce boilerplate by de-cluttering non-async predicates. +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait ReadyExt +where + S: Stream + Send + ?Sized, + Self: Stream + Send + Sized, +{ + fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + where + F: Fn(S::Item) -> bool; + + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; + + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(S::Item) -> Ready>> + where + F: Fn(S::Item) -> Option; + + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + where + F: Fn(T, S::Item) -> T; + + fn ready_for_each(self, f: F) -> ForEach, impl FnMut(S::Item) -> Ready<()>> + where + F: FnMut(S::Item); + + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; + + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a; +} + +impl ReadyExt for S +where + S: Stream + Send + ?Sized, + Self: Stream + Send + Sized, +{ + #[inline] + fn ready_any(self, f: F) -> Any, impl FnMut(S::Item) -> Ready> + where + F: Fn(S::Item) -> bool, + { + self.any(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.filter(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(S::Item) -> Ready>> + where + F: Fn(S::Item) -> Option, + { + self.filter_map(move |t| ready(f(t))) + } + + #[inline] + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, S::Item) -> Ready> + where + F: Fn(T, S::Item) -> T, + { + self.fold(init, move |a, t| ready(f(a, t))) + } + + #[inline] + #[allow(clippy::unit_arg)] + fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(S::Item) -> Ready<()>> + where + F: FnMut(S::Item), + { + self.for_each(move |t| ready(f(t))) + } + + #[inline] + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.take_while(move |t| ready(f(t))) + } + + #[inline] + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&S::Item) -> Ready + 'a> + where + F: Fn(&S::Item) -> bool + 'a, + { + self.skip_while(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs new file mode 100644 index 00000000..ab37d9b3 --- /dev/null +++ b/src/core/utils/stream/try_ready.rs @@ -0,0 +1,35 @@ +//! Synchronous combinator extensions to futures::TryStream + +use futures::{ + future::{ready, Ready}, + stream::{AndThen, TryStream, TryStreamExt}, +}; + +use crate::Result; + +/// Synchronous combinators to augment futures::TryStreamExt. 
+/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait TryReadyExt +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result; +} + +impl TryReadyExt for S +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + #[inline] + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result, + { + self.and_then(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index 5880470a..84d35936 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -107,3 +107,133 @@ async fn mutex_map_contend() { tokio::try_join!(join_b, join_a).expect("joined"); assert!(map.is_empty(), "Must be empty"); } + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_none() { + use utils::set::intersection; + + let a: [&str; 0] = []; + let b: [&str; 0] = []; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a: [&str; 0] = []; + let b = ["abc", "def"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a = ["foo", "bar", "baz"]; + let b = ["def", "hij", "klm", "nop"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_all() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["foo", "bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar"].iter())); + let i = [b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["foo", "bar", "baz"]; + let b = ["baz", "foo", "bar"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar", "baz"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_some() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["abcdef", "foo", "hijkl", "abc"]; + let b = ["hij", "bar", "baz", "abc", "foo"]; + let c = ["abc", "xyz", "foo", "ghi"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "abc"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_some() { + use utils::set::intersection_sorted; + + let a = ["bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + 
+ let a = ["aaa", "ccc", "eee", "ggg"]; + let b = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let c = ["bbb", "ccc", "eee", "fff"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["ccc", "eee"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_all() { + use utils::set::intersection_sorted; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["bar", "foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + let i = [b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["bar", "baz", "foo"]; + let b = ["bar", "baz", "foo"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "baz", "foo"].iter())); +} diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 34d98416..b5eb7612 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -37,8 +37,11 @@ zstd_compression = [ [dependencies] conduit-core.workspace = true const-str.workspace = true +futures.workspace = true log.workspace = true rust-rocksdb.workspace = true +serde.workspace = true +serde_json.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/database.rs b/src/database/database.rs index c357d50f..ac6f62e9 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -37,7 +37,7 @@ impl Database { pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } #[inline] - pub fn iter_maps(&self) -> impl Iterator + '_ { self.map.iter() } + pub fn iter_maps(&self) -> impl Iterator + Send + '_ { self.map.iter() } } impl Index<&str> for Database { diff --git a/src/database/de.rs b/src/database/de.rs new file mode 100644 index 00000000..8ce25aa3 --- /dev/null +++ b/src/database/de.rs @@ -0,0 +1,261 @@ +use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result}; +use serde::{ + de, + de::{DeserializeSeed, Visitor}, + Deserialize, +}; + +pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer { + buf, + pos: 0, + }; + + T::deserialize(&mut deserializer).debug_inspect(|_| { + deserializer + .finished() + .expect("deserialization failed to consume trailing bytes"); + }) +} + +pub(crate) struct Deserializer<'de> { + buf: &'de [u8], + pos: usize, +} + +/// Directive to ignore a record. This type can be used to skip deserialization +/// until the next separator is found. +#[derive(Debug, Deserialize)] +pub struct Ignore; + +impl<'de> Deserializer<'de> { + const SEP: u8 = b'\xFF'; + + fn finished(&self) -> Result<()> { + let pos = self.pos; + let len = self.buf.len(); + let parsed = &self.buf[0..pos]; + let unparsed = &self.buf[pos..]; + let remain = checked!(len - pos)?; + let trailing_sep = remain == 1 && unparsed[0] == Self::SEP; + (remain == 0 || trailing_sep) + .then_some(()) + .ok_or(err!(SerdeDe( + "{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}", + ))) + } + + #[inline] + fn record_next(&mut self) -> &'de [u8] { + self.buf[self.pos..] 
+ .split(|b| *b == Deserializer::SEP) + .inspect(|record| self.inc_pos(record.len())) + .next() + .expect("remainder of buf even if SEP was not found") + } + + #[inline] + fn record_trail(&mut self) -> &'de [u8] { + let record = &self.buf[self.pos..]; + self.inc_pos(record.len()); + record + } + + #[inline] + fn record_start(&mut self) { + let started = self.pos != 0; + debug_assert!( + !started || self.buf[self.pos] == Self::SEP, + "Missing expected record separator at current position" + ); + + self.inc_pos(started.into()); + } + + #[inline] + fn inc_pos(&mut self, n: usize) { + self.pos = self.pos.saturating_add(n); + debug_assert!(self.pos <= self.buf.len(), "pos out of range"); + } +} + +impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_map(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Map not implemented") + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_struct( + self, _name: &'static str, _fields: &'static [&'static str], _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Struct not implemented") + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + match name { + "Ignore" => self.record_next(), + _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), + }; + + visitor.visit_unit() + } + + fn deserialize_newtype_struct(self, _name: &'static str, _visitor: V) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Newtype Struct not implemented") + } + + fn deserialize_enum( + self, _name: &'static str, _variants: &'static [&'static str], _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Enum not implemented") + } + + fn deserialize_option>(self, _visitor: V) -> Result { + unimplemented!("deserialize Option not implemented") + } + + fn deserialize_bool>(self, _visitor: V) -> Result { + unimplemented!("deserialize bool not implemented") + } + + fn deserialize_i8>(self, _visitor: V) -> Result { + unimplemented!("deserialize i8 not implemented") + } + + fn deserialize_i16>(self, _visitor: V) -> Result { + unimplemented!("deserialize i16 not implemented") + } + + fn deserialize_i32>(self, _visitor: V) -> Result { + unimplemented!("deserialize i32 not implemented") + } + + fn deserialize_i64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_i64(i64::from_be_bytes(bytes)) + } + + fn deserialize_u8>(self, _visitor: V) -> Result { + unimplemented!("deserialize u8 not implemented") + } + + fn deserialize_u16>(self, _visitor: V) -> Result { + unimplemented!("deserialize u16 not implemented") + } + + fn deserialize_u32>(self, _visitor: V) -> Result { + unimplemented!("deserialize u32 not implemented") + } + + fn deserialize_u64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_u64(u64::from_be_bytes(bytes)) + } + + fn 
deserialize_f32>(self, _visitor: V) -> Result { + unimplemented!("deserialize f32 not implemented") + } + + fn deserialize_f64>(self, _visitor: V) -> Result { + unimplemented!("deserialize f64 not implemented") + } + + fn deserialize_char>(self, _visitor: V) -> Result { + unimplemented!("deserialize char not implemented") + } + + fn deserialize_str>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::str_from_bytes(input)?; + visitor.visit_borrowed_str(out) + } + + fn deserialize_string>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::string_from_bytes(input)?; + visitor.visit_string(out) + } + + fn deserialize_bytes>(self, visitor: V) -> Result { + let input = self.record_trail(); + visitor.visit_borrowed_bytes(input) + } + + fn deserialize_byte_buf>(self, _visitor: V) -> Result { + unimplemented!("deserialize Byte Buf not implemented") + } + + fn deserialize_unit>(self, _visitor: V) -> Result { + unimplemented!("deserialize Unit Struct not implemented") + } + + fn deserialize_identifier>(self, _visitor: V) -> Result { + unimplemented!("deserialize Identifier not implemented") + } + + fn deserialize_ignored_any>(self, _visitor: V) -> Result { + unimplemented!("deserialize Ignored Any not implemented") + } + + fn deserialize_any>(self, _visitor: V) -> Result { + unimplemented!("deserialize any not implemented") + } +} + +impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.pos >= self.buf.len() { + return Ok(None); + } + + self.record_start(); + seed.deserialize(&mut **self).map(Some) + } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs new file mode 100644 index 00000000..7da112d5 --- /dev/null +++ b/src/database/deserialized.rs @@ -0,0 +1,34 @@ +use std::convert::identity; + +use conduit::Result; +use serde::Deserialize; + +pub trait Deserialized { + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + #[inline] + fn deserialized(self) -> Result + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_de(identity::) + } + + #[inline] + fn deserialized_json(self) -> Result + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_json(identity::) + } +} diff --git a/src/database/engine.rs b/src/database/engine.rs index 3850c1d3..067232e6 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -106,7 +106,7 @@ impl Engine { })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "trace")] pub(crate) fn open_cf(&self, name: &str) -> Result>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { diff --git a/src/database/handle.rs b/src/database/handle.rs index 0b45a75f..89d87137 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -1,6 +1,10 @@ -use std::ops::Deref; +use std::{fmt, fmt::Debug, ops::Deref}; +use conduit::Result; use rocksdb::DBPinnableSlice; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::{keyval::deserialize_val, Deserialized, Slice}; pub struct Handle<'a> { val: DBPinnableSlice<'a>, @@ -14,14 +18,91 @@ impl<'a> From> for Handle<'a> { } } +impl Debug for Handle<'_> { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { + let val: &Slice = self; + let ptr = val.as_ptr(); + 
let len = val.len(); + write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}") + } +} + +impl Serialize for Handle<'_> { + #[inline] + fn serialize(&self, serializer: S) -> Result { + let bytes: &Slice = self; + serializer.serialize_bytes(bytes) + } +} + impl Deref for Handle<'_> { - type Target = [u8]; + type Target = Slice; #[inline] fn deref(&self) -> &Self::Target { &self.val } } -impl AsRef<[u8]> for Handle<'_> { +impl AsRef for Handle<'_> { #[inline] - fn as_ref(&self) -> &[u8] { &self.val } + fn as_ref(&self) -> &Slice { &self.val } +} + +impl Deserialized for Result> { + #[inline] + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_json(f) + } + + #[inline] + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_de(f) + } +} + +impl<'a> Deserialized for Result<&'a Handle<'a>> { + #[inline] + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_json(f)) + } + + #[inline] + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_de(f)) + } +} + +impl<'a> Deserialized for &'a Handle<'a> { + fn map_json(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + serde_json::from_slice::(self.as_ref()) + .map_err(Into::into) + .map(f) + } + + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + deserialize_val(self.as_ref()).map(f) + } } diff --git a/src/database/iter.rs b/src/database/iter.rs deleted file mode 100644 index 4845e977..00000000 --- a/src/database/iter.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::{iter::FusedIterator, sync::Arc}; - -use conduit::Result; -use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, IteratorMode, ReadOptions}; - -use crate::{ - engine::Db, - result, - slice::{OwnedKeyVal, OwnedKeyValPair}, - Engine, -}; - -type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>; - -struct State<'cursor> { - cursor: Cursor<'cursor>, - direction: Direction, - valid: bool, - init: bool, -} - -impl<'cursor> State<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts); - let direction = into_direction(mode); - let valid = seek_init(&mut cursor, mode); - Self { - cursor, - direction, - valid, - init: true, - } - } -} - -pub struct Iter<'cursor> { - state: State<'cursor>, -} - -impl<'cursor> Iter<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - Self { - state: State::new(db, cf, opts, mode), - } - } -} - -impl Iterator for Iter<'_> { - type Item = OwnedKeyValPair; - - fn next(&mut self) -> Option { - if !self.state.init && self.state.valid { - seek_next(&mut self.state.cursor, self.state.direction); - } else if self.state.init { - self.state.init = false; - } - - self.state - .cursor - .item() - .map(OwnedKeyVal::from) - .map(OwnedKeyVal::to_tuple) - .or_else(|| { - when_invalid(&mut self.state).expect("iterator invalidated due to error"); - None - }) - } -} - -impl FusedIterator for Iter<'_> {} - -fn when_invalid(state: &mut State<'_>) -> Result<()> { - state.valid = false; - result(state.cursor.status()) -} - -fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) { - match direction { - 
Direction::Forward => cursor.next(), - Direction::Reverse => cursor.prev(), - } -} - -fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start => cursor.seek_to_first(), - End => cursor.seek_to_last(), - From(key, Forward) => cursor.seek(key), - From(key, Reverse) => cursor.seek_for_prev(key), - }; - - cursor.valid() -} - -fn into_direction(mode: &IteratorMode<'_>) -> Direction { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start | From(_, Forward) => Forward, - End | From(_, Reverse) => Reverse, - } -} diff --git a/src/database/keyval.rs b/src/database/keyval.rs new file mode 100644 index 00000000..c9d25977 --- /dev/null +++ b/src/database/keyval.rs @@ -0,0 +1,83 @@ +use conduit::Result; +use serde::Deserialize; + +use crate::de; + +pub(crate) type OwnedKeyVal = (Vec, Vec); +pub(crate) type OwnedKey = Vec; +pub(crate) type OwnedVal = Vec; + +pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); +pub type Key<'a, T = &'a Slice> = T; +pub type Val<'a, T = &'a Slice> = T; + +pub type Slice = [u8]; + +#[inline] +pub(crate) fn _expect_deserialize<'a, K, V>(kv: Result>) -> KeyVal<'a, K, V> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + result_deserialize(kv).expect("failed to deserialize result key/val") +} + +#[inline] +pub(crate) fn _expect_deserialize_key<'a, K>(key: Result>) -> Key<'a, K> +where + K: Deserialize<'a>, +{ + result_deserialize_key(key).expect("failed to deserialize result key") +} + +#[inline] +pub(crate) fn result_deserialize<'a, K, V>(kv: Result>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + deserialize(kv?) +} + +#[inline] +pub(crate) fn result_deserialize_key<'a, K>(key: Result>) -> Result> +where + K: Deserialize<'a>, +{ + deserialize_key(key?) 
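+ // (editorial, hypothetical usage) raw key streams lift to typed streams
+ // with e.g. `stream.map(keyval::result_deserialize_key::<u64>)`; whole
+ // (key, value) pairs go through result_deserialize the same way.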
+} + +#[inline] +pub(crate) fn deserialize<'a, K, V>(kv: KeyVal<'a>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + Ok((deserialize_key::(kv.0)?, deserialize_val::(kv.1)?)) +} + +#[inline] +pub(crate) fn deserialize_key<'a, K>(key: Key<'a>) -> Result> +where + K: Deserialize<'a>, +{ + de::from_slice::(key) +} + +#[inline] +pub(crate) fn deserialize_val<'a, V>(val: Val<'a>) -> Result> +where + V: Deserialize<'a>, +{ + de::from_slice::(val) +} + +#[inline] +#[must_use] +pub fn to_owned(kv: KeyVal<'_>) -> OwnedKeyVal { (kv.0.to_owned(), kv.1.to_owned()) } + +#[inline] +pub fn key(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } + +#[inline] +pub fn val(kv: KeyVal<'_, K, V>) -> Val<'_, V> { kv.1 } diff --git a/src/database/map.rs b/src/database/map.rs index ddae8c81..a3cf32d4 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,15 +1,39 @@ -use std::{ffi::CStr, future::Future, mem::size_of, pin::Pin, sync::Arc}; +mod count; +mod keys; +mod keys_from; +mod keys_prefix; +mod rev_keys; +mod rev_keys_from; +mod rev_keys_prefix; +mod rev_stream; +mod rev_stream_from; +mod rev_stream_prefix; +mod stream; +mod stream_from; +mod stream_prefix; -use conduit::{utils, Result}; -use rocksdb::{ - AsColumnFamilyRef, ColumnFamily, Direction, IteratorMode, ReadOptions, WriteBatchWithTransaction, WriteOptions, +use std::{ + convert::AsRef, + ffi::CStr, + fmt, + fmt::{Debug, Display}, + future::Future, + io::Write, + pin::Pin, + sync::Arc, }; +use conduit::{err, Result}; +use futures::future; +use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteBatchWithTransaction, WriteOptions}; +use serde::Serialize; + use crate::{ - or_else, result, - slice::{Byte, Key, KeyVal, OwnedKey, OwnedKeyValPair, OwnedVal, Val}, + keyval::{OwnedKey, OwnedVal}, + ser, + util::{map_err, or_else}, watchers::Watchers, - Engine, Handle, Iter, + Engine, Handle, }; pub struct Map { @@ -21,8 +45,6 @@ pub struct Map { read_options: ReadOptions, } -type OwnedKeyValPairIter<'a> = Box + Send + 'a>; - impl Map { pub(crate) fn open(db: &Arc, name: &str) -> Result> { Ok(Arc::new(Self { @@ -35,14 +57,125 @@ impl Map { })) } - pub fn get(&self, key: &Key) -> Result>> { - let read_options = &self.read_options; - let res = self.db.db.get_pinned_cf_opt(&self.cf(), key, read_options); - - Ok(result(res)?.map(Handle::from)) + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn del(&self, key: &K) + where + K: Serialize + ?Sized + Debug, + { + let mut buf = Vec::::with_capacity(64); + self.bdel(key, &mut buf); } - pub fn multi_get(&self, keys: &[&Key]) -> Result>> { + #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] + pub fn bdel(&self, key: &K, buf: &mut B) + where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, + { + let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); + self.remove(&key); + } + + #[tracing::instrument(level = "trace")] + pub fn remove(&self, key: &K) + where + K: AsRef<[u8]> + ?Sized + Debug, + { + let write_options = &self.write_options; + self.db + .db + .delete_cf_opt(&self.cf(), key, write_options) + .or_else(or_else) + .expect("database remove error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + } + + #[tracing::instrument(skip(self, value), fields(%self), level = "trace")] + pub fn insert(&self, key: &K, value: &V) + where + K: AsRef<[u8]> + ?Sized + Debug, + V: AsRef<[u8]> + ?Sized, + { + let write_options = &self.write_options; + self.db + .db + 
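+ // (editorial) note the de-wrapped contract here: an engine write failure
+ // now aborts via the expect() below instead of returning a Result.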
.put_cf_opt(&self.cf(), key, value, write_options) + .or_else(or_else) + .expect("database insert error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + + self.watchers.wake(key.as_ref()); + } + + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) + where + I: Iterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + V: AsRef<[u8]> + Sized + 'a, + { + let mut batch = WriteBatchWithTransaction::::default(); + for (key, val) in iter { + batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); + } + + let write_options = &self.write_options; + self.db + .db + .write_opt(batch, write_options) + .or_else(or_else) + .expect("database insert batch error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + } + + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn qry(&self, key: &K) -> impl Future>> + Send + where + K: Serialize + ?Sized + Debug, + { + let mut buf = Vec::::with_capacity(64); + self.bqry(key, &mut buf) + } + + #[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] + pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send + where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, + { + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + let val = self.get(key); + future::ready(val) + } + + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn get(&self, key: &K) -> Result> + where + K: AsRef<[u8]> + ?Sized + Debug, + { + self.db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options) + .map_err(map_err)? + .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) + } + + #[tracing::instrument(skip(self), fields(%self), level = "trace")] + pub fn multi_get<'a, I, K>(&self, keys: I) -> Vec> + where + I: Iterator + ExactSizeIterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + { // Optimization can be `true` if key vector is pre-sorted **by the column // comparator**. 
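+ // (editorial, hypothetical usage) callers batch point lookups into one
+ // engine call, e.g. `let vals = map.multi_get(keys.iter());` for some
+ // illustrative `keys: Vec<Vec<u8>>`.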
const SORTED: bool = false; @@ -57,140 +190,25 @@ impl Map { match res { Ok(Some(res)) => ret.push(Some((*res).to_vec())), Ok(None) => ret.push(None), - Err(e) => return or_else(e), + Err(e) => or_else(e).expect("database multiget error"), } } - Ok(ret) + ret } - pub fn insert(&self, key: &Key, value: &Val) -> Result<()> { - let write_options = &self.write_options; - self.db - .db - .put_cf_opt(&self.cf(), key, value, write_options) - .or_else(or_else)?; - - if !self.db.corked() { - self.db.flush()?; - } - - self.watchers.wake(key); - - Ok(()) - } - - pub fn insert_batch<'a, I>(&'a self, iter: I) -> Result<()> + #[inline] + pub fn watch_prefix<'a, K>(&'a self, prefix: &K) -> Pin + Send + 'a>> where - I: Iterator>, + K: AsRef<[u8]> + ?Sized + Debug, { - let mut batch = WriteBatchWithTransaction::::default(); - for KeyVal(key, value) in iter { - batch.put_cf(&self.cf(), key, value); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove(&self, key: &Key) -> Result<()> { - let write_options = &self.write_options; - let res = self.db.db.delete_cf_opt(&self.cf(), key, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator, - { - let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - batch.delete_cf(&self.cf(), key); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn iter(&self) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::Start; - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn iter_from(&self, from: &Key, reverse: bool) -> OwnedKeyValPairIter<'_> { - let direction = if reverse { - Direction::Reverse - } else { - Direction::Forward - }; - let mode = IteratorMode::From(from, direction); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn scan_prefix(&self, prefix: OwnedKey) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::From(&prefix, Direction::Forward); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode).take_while(move |(k, _)| k.starts_with(&prefix))) - } - - pub fn increment(&self, key: &Key) -> Result<[Byte; size_of::()]> { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - self.insert(key, &new)?; - - if !self.db.corked() { - self.db.flush()?; - } - - Ok(new) - } - - pub fn increment_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator, - { - let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - batch.put_cf(&self.cf(), key, new); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn watch_prefix<'a>(&'a self, prefix: &Key) -> Pin + Send + 'a>> { - self.watchers.watch(prefix) + self.watchers.watch(prefix.as_ref()) } + #[inline] pub fn property_integer(&self, name: &CStr) -> Result { self.db.property_integer(&self.cf(), name) } + #[inline] pub fn property(&self, name: &str) -> Result { 
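+ // (editorial, hypothetical) e.g. `map.property("rocksdb.estimate-num-keys")`
+ // reads a column-family statistic by its RocksDB property name.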
self.db.property(&self.cf(), name) } #[inline] @@ -199,12 +217,12 @@ impl Map { fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf } } -impl<'a> IntoIterator for &'a Map { - type IntoIter = Box + Send + 'a>; - type Item = OwnedKeyValPair; +impl Debug for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "Map {{name: {0}}}", self.name) } +} - #[inline] - fn into_iter(self) -> Self::IntoIter { self.iter() } +impl Display for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "{0}", self.name) } } fn open(db: &Arc, name: &str) -> Result> { diff --git a/src/database/map/count.rs b/src/database/map/count.rs new file mode 100644 index 00000000..4356b71f --- /dev/null +++ b/src/database/map/count.rs @@ -0,0 +1,36 @@ +use std::{fmt::Debug, future::Future}; + +use conduit::implement; +use futures::stream::StreamExt; +use serde::Serialize; + +use crate::de::Ignore; + +/// Count the total number of entries in the map. +#[implement(super::Map)] +#[inline] +pub fn count(&self) -> impl Future + Send + '_ { self.keys::().count() } + +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a structured key +#[implement(super::Map)] +#[inline] +pub fn count_from<'a, P>(&'a self, from: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_from::(from).count() +} + +/// Count the number of entries in the map matching a prefix. +/// +/// - Prefix is structured key +#[implement(super::Map)] +#[inline] +pub fn count_prefix<'a, P>(&'a self, prefix: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_prefix::(prefix).count() +} diff --git a/src/database/map/keys.rs b/src/database/map/keys.rs new file mode 100644 index 00000000..2396494c --- /dev/null +++ b/src/database/map/keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/keys_from.rs b/src/database/map/keys_from.rs new file mode 100644 index 00000000..1993750a --- /dev/null +++ b/src/database/map/keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/keys_prefix.rs b/src/database/map/keys_prefix.rs new file mode 100644 index 00000000..d6c0927b --- /dev/null +++ b/src/database/map/keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_prefix
<P>
    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_keys.rs b/src/database/map/rev_keys.rs new file mode 100644 index 00000000..449ccfff --- /dev/null +++ b/src/database/map/rev_keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_keys_from.rs b/src/database/map/rev_keys_from.rs new file mode 100644 index 00000000..e012e60a --- /dev/null +++ b/src/database/map/rev_keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_keys_prefix.rs b/src/database/map/rev_keys_prefix.rs new file mode 100644 index 00000000..162c4f9b --- /dev/null +++ b/src/database/map/rev_keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_prefix
<P>
    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_stream.rs b/src/database/map/rev_stream.rs new file mode 100644 index 00000000..de22fd5c --- /dev/null +++ b/src/database/map/rev_stream.rs @@ -0,0 +1,29 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream() + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs new file mode 100644 index 00000000..650cf038 --- /dev/null +++ b/src/database/map/rev_stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_stream_prefix.rs b/src/database/map/rev_stream_prefix.rs new file mode 100644 index 00000000..9ef89e9c --- /dev/null +++ b/src/database/map/rev_stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_prefix
<P>
    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/stream.rs b/src/database/map/stream.rs new file mode 100644 index 00000000..dfbea072 --- /dev/null +++ b/src/database/map/stream.rs @@ -0,0 +1,28 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream().map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs new file mode 100644 index 00000000..153d5bb6 --- /dev/null +++ b/src/database/map/stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_from
<P>
    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/stream_prefix.rs b/src/database/map/stream_prefix.rs new file mode 100644 index 00000000..56154a8b --- /dev/null +++ b/src/database/map/stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_prefix
<P>
    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6446624c..e66abf68 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,25 +1,35 @@ mod cork; mod database; +mod de; +mod deserialized; mod engine; mod handle; -mod iter; +pub mod keyval; mod map; pub mod maps; mod opts; -mod slice; +mod ser; +mod stream; mod util; mod watchers; +pub(crate) use self::{ + engine::Engine, + util::{or_else, result}, +}; + extern crate conduit_core as conduit; extern crate rust_rocksdb as rocksdb; -pub use database::Database; -pub(crate) use engine::Engine; -pub use handle::Handle; -pub use iter::Iter; -pub use map::Map; -pub use slice::{Key, KeyVal, OwnedKey, OwnedKeyVal, OwnedVal, Val}; -pub(crate) use util::{or_else, result}; +pub use self::{ + database::Database, + de::Ignore, + deserialized::Deserialized, + handle::Handle, + keyval::{KeyVal, Slice}, + map::Map, + ser::{Interfix, Separator}, +}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/database/ser.rs b/src/database/ser.rs new file mode 100644 index 00000000..bd4bbd9a --- /dev/null +++ b/src/database/ser.rs @@ -0,0 +1,315 @@ +use std::io::Write; + +use conduit::{err, result::DebugInspect, utils::exchange, Error, Result}; +use serde::{ser, Serialize}; + +#[inline] +pub(crate) fn serialize_to_vec(val: &T) -> Result> +where + T: Serialize + ?Sized, +{ + let mut buf = Vec::with_capacity(64); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub(crate) fn serialize<'a, W, T>(out: &'a mut W, val: &'a T) -> Result<&'a [u8]> +where + W: Write + AsRef<[u8]>, + T: Serialize + ?Sized, +{ + let mut serializer = Serializer { + out, + depth: 0, + sep: false, + fin: false, + }; + + val.serialize(&mut serializer) + .map_err(|error| err!(SerdeSer("{error}"))) + .debug_inspect(|()| { + debug_assert_eq!(serializer.depth, 0, "Serialization completed at non-zero recursion level"); + })?; + + Ok((*out).as_ref()) +} + +pub(crate) struct Serializer<'a, W: Write> { + out: &'a mut W, + depth: u32, + sep: bool, + fin: bool, +} + +/// Directive to force separator serialization specifically for prefix keying +/// use. This is a quirk of the database schema and prefix iterations. 
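+/// (editorial, hypothetical) e.g. serializing the tuple `(user_id, Interfix)`
+/// emits the user-id bytes followed by one trailing 0xFF separator, which is
+/// exactly the byte prefix the *_prefix key/stream queries match against.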
+#[derive(Debug, Serialize)] +pub struct Interfix; + +/// Directive to force separator serialization. Separators are usually +/// serialized automatically. +#[derive(Debug, Serialize)] +pub struct Separator; + +impl Serializer<'_, W> { + const SEP: &'static [u8] = b"\xFF"; + + fn sequence_start(&mut self) { + debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); + debug_assert!(!self.sep, "Sequence start with separator set"); + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_add(1); + } + } + + fn sequence_end(&mut self) { + self.sep = false; + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_sub(1); + } + } + + fn record_start(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); + exchange(&mut self.sep, true) + .then(|| self.separator()) + .unwrap_or(Ok(())) + } + + fn separator(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); + self.out.write_all(Self::SEP).map_err(Into::into) + } + + fn set_finalized(&mut self) { + debug_assert!(!self.is_finalized(), "Finalization already set"); + if cfg!(debug_assertions) { + self.fin = true; + } + } + + fn is_finalized(&self) -> bool { self.fin } +} + +impl ser::Serializer for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = Self; + type SerializeSeq = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + + fn serialize_map(self, _len: Option) -> Result { + unimplemented!("serialize Map not implemented") + } + + fn serialize_seq(self, _len: Option) -> Result { + self.sequence_start(); + self.record_start()?; + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result { + unimplemented!("serialize New Type Struct not implemented") + } + + fn serialize_newtype_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, + ) -> Result { + unimplemented!("serialize New Type Variant not implemented") + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + match name { + "Interfix" => { + self.set_finalized(); + }, + "Separator" => { + self.separator()?; + }, + _ => unimplemented!("Unrecognized serialization directive: {name:?}"), + }; + + Ok(()) + } + + fn serialize_unit_variant(self, _name: &'static str, _idx: u32, _var: &'static str) -> Result { + unimplemented!("serialize Unit Variant not implemented") + } + + fn serialize_some(self, val: &T) -> Result { val.serialize(self) } + + fn serialize_none(self) -> Result { Ok(()) } + + fn serialize_char(self, v: char) -> Result { + let mut buf: [u8; 4] = [0; 4]; + self.serialize_str(v.encode_utf8(&mut buf)) + } + + fn serialize_str(self, v: &str) -> 
Result { self.serialize_bytes(v.as_bytes()) } + + fn serialize_bytes(self, v: &[u8]) -> Result { self.out.write_all(v).map_err(Error::Io) } + + fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } + + fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } + + fn serialize_i64(self, v: i64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_i32(self, _v: i32) -> Result { unimplemented!("serialize i32 not implemented") } + + fn serialize_i16(self, _v: i16) -> Result { unimplemented!("serialize i16 not implemented") } + + fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } + + fn serialize_u64(self, v: u64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_u32(self, _v: u32) -> Result { unimplemented!("serialize u32 not implemented") } + + fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") } + + fn serialize_u8(self, v: u8) -> Result { self.out.write_all(&[v]).map_err(Error::Io) } + + fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } + + fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } +} + +impl ser::SerializeMap for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_key(&mut self, _key: &T) -> Result { + unimplemented!("serialize Map Key not implemented") + } + + fn serialize_value(&mut self, _val: &T) -> Result { + unimplemented!("serialize Map Val not implemented") + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeSeq for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStructVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTuple for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} diff --git a/src/database/slice.rs b/src/database/slice.rs deleted file mode 100644 index 448d969d..00000000 --- a/src/database/slice.rs +++ /dev/null @@ -1,57 
+0,0 @@ -pub struct OwnedKeyVal(pub OwnedKey, pub OwnedVal); -pub(crate) type OwnedKeyValPair = (OwnedKey, OwnedVal); -pub type OwnedVal = Vec; -pub type OwnedKey = Vec; - -pub struct KeyVal<'item>(pub &'item Key, pub &'item Val); -pub(crate) type KeyValPair<'item> = (&'item Key, &'item Val); -pub type Val = [Byte]; -pub type Key = [Byte]; - -pub(crate) type Byte = u8; - -impl OwnedKeyVal { - #[must_use] - pub fn as_slice(&self) -> KeyVal<'_> { KeyVal(&self.0, &self.1) } - - #[must_use] - pub fn to_tuple(self) -> OwnedKeyValPair { (self.0, self.1) } -} - -impl From for OwnedKeyVal { - fn from((key, val): OwnedKeyValPair) -> Self { Self(key, val) } -} - -impl From<&KeyVal<'_>> for OwnedKeyVal { - #[inline] - fn from(slice: &KeyVal<'_>) -> Self { slice.to_owned() } -} - -impl From> for OwnedKeyVal { - fn from((key, val): KeyValPair<'_>) -> Self { Self(Vec::from(key), Vec::from(val)) } -} - -impl From for OwnedKeyValPair { - fn from(val: OwnedKeyVal) -> Self { val.to_tuple() } -} - -impl KeyVal<'_> { - #[inline] - #[must_use] - pub fn to_owned(&self) -> OwnedKeyVal { OwnedKeyVal::from(self) } - - #[must_use] - pub fn as_tuple(&self) -> KeyValPair<'_> { (self.0, self.1) } -} - -impl<'a> From<&'a OwnedKeyVal> for KeyVal<'a> { - fn from(owned: &'a OwnedKeyVal) -> Self { owned.as_slice() } -} - -impl<'a> From<&'a OwnedKeyValPair> for KeyVal<'a> { - fn from((key, val): &'a OwnedKeyValPair) -> Self { KeyVal(key.as_slice(), val.as_slice()) } -} - -impl<'a> From> for KeyVal<'a> { - fn from((key, val): KeyValPair<'a>) -> Self { KeyVal(key, val) } -} diff --git a/src/database/stream.rs b/src/database/stream.rs new file mode 100644 index 00000000..d9b74215 --- /dev/null +++ b/src/database/stream.rs @@ -0,0 +1,122 @@ +mod items; +mod items_rev; +mod keys; +mod keys_rev; + +use std::sync::Arc; + +use conduit::{utils::exchange, Error, Result}; +use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, ReadOptions}; + +pub(crate) use self::{items::Items, items_rev::ItemsRev, keys::Keys, keys_rev::KeysRev}; +use crate::{ + engine::Db, + keyval::{Key, KeyVal, Val}, + util::map_err, + Engine, Slice, +}; + +struct State<'a> { + inner: Inner<'a>, + seek: bool, + init: bool, +} + +trait Cursor<'a, T> { + fn state(&self) -> &State<'a>; + + fn fetch(&self) -> Option; + + fn seek(&mut self); + + fn get(&self) -> Option> { + self.fetch() + .map(Ok) + .or_else(|| self.state().status().map(Err)) + } + + fn seek_and_get(&mut self) -> Option> { + self.seek(); + self.get() + } +} + +type Inner<'a> = DBRawIteratorWithThreadMode<'a, Db>; +type From<'a> = Option>; + +impl<'a> State<'a> { + fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions) -> Self { + Self { + inner: db.db.raw_iterator_cf_opt(&**cf, opts), + init: true, + seek: false, + } + } + + fn init_fwd(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek(key); + self.seek = true; + } + + self + } + + fn init_rev(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek_for_prev(key); + self.seek = true; + } + + self + } + + fn seek_fwd(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.next(); + } else if !self.seek { + self.inner.seek_to_first(); + } + } + + fn seek_rev(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.prev(); + } else if !self.seek { + self.inner.seek_to_last(); + } + } + + fn fetch_key(&self) -> Option> { self.inner.key().map(Key::from) } + + fn _fetch_val(&self) -> Option> { self.inner.value().map(Val::from) } + + fn fetch(&self) -> Option> { 
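+ // (editorial) the item yielded here is only valid until the cursor moves
+ // again; see slice_longevity() below for the unsafe lifetime extension.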
self.inner.item().map(KeyVal::from) } + + fn status(&self) -> Option { self.inner.status().map_err(map_err).err() } + + fn valid(&self) -> bool { self.inner.valid() } +} + +fn keyval_longevity<'a, 'b: 'a>(item: KeyVal<'a>) -> KeyVal<'b> { + (slice_longevity::<'a, 'b>(item.0), slice_longevity::<'a, 'b>(item.1)) +} + +fn slice_longevity<'a, 'b: 'a>(item: &'a Slice) -> &'b Slice { + // SAFETY: The lifetime of the data returned by the rocksdb cursor is only valid + // between each movement of the cursor. It is hereby unsafely extended to match + // the lifetime of the cursor itself. This is due to the limitation of the + // Stream trait where the Item is incapable of conveying a lifetime; this is due + // to GAT's being unstable during its development. This unsafety can be removed + // as soon as this limitation is addressed by an upcoming version. + // + // We have done our best to mitigate the implications of this in conjunction + // with the deserialization API such that borrows being held across movements of + // the cursor do not happen accidentally. The compiler will still error when + // values herein produced try to leave a closure passed to a StreamExt API. But + // escapes can happen if you explicitly and intentionally attempt it, and there + // will be no compiler error or warning. This is primarily the case with + // calling collect() without a preceding map(ToOwned::to_owned). A collection + // of references here is illegal, but this will not be enforced by the compiler. + unsafe { std::mem::transmute(item) } +} diff --git a/src/database/stream/items.rs b/src/database/stream/items.rs new file mode 100644 index 00000000..31d5e9e8 --- /dev/null +++ b/src/database/stream/items.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct Items<'a> { + state: State<'a>, +} + +impl<'a> Items<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Items<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Items<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/items_rev.rs b/src/database/stream/items_rev.rs new file mode 100644 index 00000000..ab57a250 --- /dev/null +++ b/src/database/stream/items_rev.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct ItemsRev<'a> { + state: State<'a>, +} + +impl<'a> ItemsRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> { + fn state(&self) 
-> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for ItemsRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for ItemsRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys.rs b/src/database/stream/keys.rs new file mode 100644 index 00000000..1c5d12e3 --- /dev/null +++ b/src/database/stream/keys.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct Keys<'a> { + state: State<'a>, +} + +impl<'a> Keys<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for Keys<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Keys<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Keys<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys_rev.rs b/src/database/stream/keys_rev.rs new file mode 100644 index 00000000..26707483 --- /dev/null +++ b/src/database/stream/keys_rev.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct KeysRev<'a> { + state: State<'a>, +} + +impl<'a> KeysRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for KeysRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for KeysRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/util.rs b/src/database/util.rs index f0ccbcbe..d36e183f 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,16 @@ use conduit::{err, Result}; +use rocksdb::{Direction, IteratorMode}; + +#[inline] +pub(crate) fn _into_direction(mode: &IteratorMode<'_>) -> Direction { + use Direction::{Forward, Reverse}; + use IteratorMode::{End, From, Start}; + + match mode { + Start | From(_, Forward) => Forward, + End | From(_, Reverse) => Reverse, + } +} #[inline] pub(crate) fn result(r: std::result::Result) -> Result { diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index cfed5a0e..737a7039 100644 
--- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -46,7 +46,7 @@ bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hickory-resolver.workspace = true http.workspace = true image.workspace = true diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs deleted file mode 100644 index 53a0e953..00000000 --- a/src/service/account_data/data.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use conduit::{Error, Result}; -use database::Map; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - roomuserdataid_accountdata: Arc, - roomusertype_roomuserdataid: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - /// Places one event in the account data of the user and removes the - /// previous entry. - pub(super) fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); - - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } - - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; - - let prev = self.roomusertype_roomuserdataid.get(&key)?; - - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } - - Ok(()) - } - - /// Searches the account data for a specific kind. - pub(super) fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType, - ) -> Result>> { - let mut key = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); - - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? - .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) - .transpose() - } - - /// Returns all changes to the account data that happened after `since`. 
- pub(super) fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); - - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - k, - match room_id { - None => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Global) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - Some(_) => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Room) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - }, - )) - }) { - let (kind, data) = r?; - userdata.insert(kind, data); - } - - Ok(userdata.into_values().collect()) - } -} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index eaa53641..b4eb143d 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,52 +1,158 @@ -mod data; +use std::{collections::HashMap, sync::Arc}; -use std::sync::Arc; - -use conduit::Result; -use data::Data; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ - events::{AnyRawAccountDataEvent, RoomAccountDataEventType}, + events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, + serde::Raw, RoomId, UserId, }; +use serde_json::value::RawValue; + +use crate::{globals, Dep}; pub struct Service { + services: Services, db: Data, } +struct Data { + roomuserdataid_accountdata: Arc, + roomusertype_roomuserdataid: Arc, +} + +struct Services { + globals: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + services: Services { + globals: args.depend::("globals"), + }, + db: Data { + roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(), + roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Places one event in the account data of the user and removes the - /// previous entry. - #[allow(clippy::needless_pass_by_value)] - pub fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - self.db.update(room_id, user_id, &event_type, data) +/// Places one event in the account data of the user and removes the +/// previous entry. 
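Both versions of changes_since scan the same way: seek to prefix + (since + 1) in big-endian, then keep taking entries while the key still starts with the prefix; ready_take_while in the new code is conduwuit's wrapper that feeds take_while an already-ready future. The same shape with stock futures combinators over an ordered key set (sketch; byte vectors stand in for database keys):

    use futures::{future::ready, stream, StreamExt};

    fn main() {
        let prefix: &[u8] = b"!room:x\xFF@alice:x\xFF";
        let since: u64 = 5;

        // Ordered key space, standing in for raw_stream_from(&first_possible).
        let keys: Vec<Vec<u8>> = (4..8u64)
            .map(|n| [prefix, &n.to_be_bytes()].concat())
            .chain(std::iter::once(b"!room:y\xFF@alice:x\xFF".to_vec()))
            .collect();

        let mut first_possible = prefix.to_vec();
        first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());

        let newer = futures::executor::block_on(
            stream::iter(keys)
                // Stand-in for seeking the cursor to `first_possible`.
                .skip_while(move |k| ready(k < &first_possible))
                // Stop at the first key outside our prefix.
                .take_while(|k| ready(k.starts_with(prefix)))
                .collect::<Vec<_>>(),
        );

        assert_eq!(newer.len(), 2); // the entries at counts 6 and 7
    }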
+#[allow(clippy::needless_pass_by_value)]
+#[implement(Service)]
+pub async fn update(
+	&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value,
+) -> Result<()> {
+	let event_type = event_type.to_string();
+	let count = self.services.globals.next_count()?;
+
+	let mut prefix = room_id
+		.map(ToString::to_string)
+		.unwrap_or_default()
+		.as_bytes()
+		.to_vec();
+	prefix.push(0xFF);
+	prefix.extend_from_slice(user_id.as_bytes());
+	prefix.push(0xFF);
+
+	let mut roomuserdataid = prefix.clone();
+	roomuserdataid.extend_from_slice(&count.to_be_bytes());
+	roomuserdataid.push(0xFF);
+	roomuserdataid.extend_from_slice(event_type.as_bytes());
+
+	let mut key = prefix;
+	key.extend_from_slice(event_type.as_bytes());
+
+	if data.get("type").is_none() || data.get("content").is_none() {
+		return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
 	}
 
-	/// Searches the account data for a specific kind.
-	#[allow(clippy::needless_pass_by_value)]
-	pub fn get(
-		&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
-	) -> Result<Option<Box<RawValue>>> {
-		self.db.get(room_id, user_id, &event_type)
+	self.db.roomuserdataid_accountdata.insert(
+		&roomuserdataid,
+		&serde_json::to_vec(&data).expect("to_vec always works on json values"),
+	);
+
+	let prev_key = (room_id, user_id, &event_type);
+	let prev = self.db.roomusertype_roomuserdataid.qry(&prev_key).await;
+
+	self.db
+		.roomusertype_roomuserdataid
+		.insert(&key, &roomuserdataid);
+
+	// Remove old entry
+	if let Ok(prev) = prev {
+		self.db.roomuserdataid_accountdata.remove(&prev);
 	}
 
-	/// Returns all changes to the account data that happened after `since`.
-	#[tracing::instrument(skip_all, name = "since", level = "debug")]
-	pub fn changes_since(
-		&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
-	) -> Result<Vec<AnyRawAccountDataEvent>> {
-		self.db.changes_since(room_id, user_id, since)
-	}
+	Ok(())
+}
+
+/// Searches the account data for a specific kind.
+#[implement(Service)]
+pub async fn get(
+	&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
+) -> Result<Box<RawValue>> {
+	let key = (room_id, user_id, kind.to_string());
+	self.db
+		.roomusertype_roomuserdataid
+		.qry(&key)
+		.and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid))
+		.await
+		.deserialized_json()
+}
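get() above now chains two point lookups: the (room_id, user_id, type) index resolves to a roomuserdataid, which then keys the stored event; TryFutureExt::and_then runs the second query only if the first succeeded, and deserialized_json parses the final handle. The chaining idiom in isolation, against plain async functions (hypothetical stand-ins, futures crate only):

    use futures::TryFutureExt;

    // Stand-in for roomusertype_roomuserdataid.qry(&key)
    async fn index_lookup(key: &str) -> Result<String, &'static str> {
        if key == "known" { Ok("data-key".to_owned()) } else { Err("not found") }
    }

    // Stand-in for roomuserdataid_accountdata.qry(&roomuserdataid)
    async fn data_lookup(key: String) -> Result<String, &'static str> {
        Ok(format!("value for {key}"))
    }

    fn main() {
        let value = futures::executor::block_on(
            index_lookup("known").and_then(data_lookup), // second hop only on a hit
        );
        assert_eq!(value.as_deref(), Ok("value for data-key"));
    }

+/// Returns all changes to the account data that happened after `since`.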
+#[implement(Service)] +pub async fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, +) -> Result> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); + + self.db + .roomuserdataid_accountdata + .raw_stream_from(&first_possible) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + let v = match room_id { + None => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Global) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + Some(_) => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Room) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + }; + + Ok((k.to_owned(), v)) + }) + .ignore_err() + .ready_for_each(|(kind, data)| { + userdata.insert(kind, data); + }) + .await; + + Ok(userdata.into_values().collect()) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 55bae365..0f5016e1 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, defer, error, log, Server}; -use futures_util::future::{AbortHandle, Abortable}; +use futures::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use termimad::MadSkin; diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index 4e2b831c..7b090aa0 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -30,7 +30,7 @@ use crate::Services; pub async fn create_admin_room(services: &Services) -> Result<()> { let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id); let state_lock = services.rooms.state.mutex.lock(&room_id).await; diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index b4589ebc..4b3ebb88 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -17,108 +17,108 @@ use serde_json::value::to_raw_value; use crate::pdu::PduBuilder; -impl super::Service { - /// Invite the user to the conduit admin room. - /// - /// In conduit, this is equivalent to granting admin privileges. - pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { - let Some(room_id) = self.get_admin_room()? else { - return Ok(()); - }; +/// Invite the user to the conduit admin room. +/// +/// In conduit, this is equivalent to granting admin privileges. 
+#[implement(super::Service)] +pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { + let Ok(room_id) = self.get_admin_room().await else { + return Ok(()); + }; - let state_lock = self.services.state.mutex.lock(&room_id).await; + let state_lock = self.services.state.mutex.lock(&room_id).await; - // Use the server user to grant the new admin's power level - let server_user = &self.services.globals.server_user; + // Use the server user to grant the new admin's power level + let server_user = &self.services.globals.server_user; - // Invite and join the real user - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - user_id, - &room_id, - &state_lock, - ) - .await?; + // Invite and join the real user + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + user_id, + &room_id, + &state_lock, + ) + .await?; - // Set power level - let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); + // Set power level + let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - 
&state_lock, - ) - .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(String::new()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; - // Set room tag - let room_tag = &self.services.server.config.admin_room_tag; - if !room_tag.is_empty() { - if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag) { - error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); - } + // Set room tag + let room_tag = &self.services.server.config.admin_room_tag; + if !room_tag.is_empty() { + if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await { + error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); } + } - // Send welcome message - self.services.timeline.build_and_append_pdu( + // Send welcome message + self.services.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_markdown( @@ -135,19 +135,18 @@ impl super::Service { &state_lock, ).await?; - Ok(()) - } + Ok(()) } #[implement(super::Service)] -fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { +async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { let mut event = self .services .account_data - .get(Some(room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| serde_json::from_str(event.get())) - .and_then(Result::ok) - .unwrap_or_else(|| TagEvent { + .get(Some(room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| serde_json::from_str(event.get()).map_err(Into::into)) + .unwrap_or_else(|_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -158,12 +157,15 @@ fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result< .tags .insert(tag.to_owned().into(), TagInfo::new()); - self.services.account_data.update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(event)?, - )?; + self.services + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(event)?, + ) + .await?; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 3274249e..12eacc8f 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -12,6 +12,7 @@ use std::{ use async_trait::async_trait; use conduit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; pub use create::create_admin_room; +use futures::{FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ events::{ @@ -142,17 +143,18 @@ impl Service { /// admin room as the admin user. pub async fn send_text(&self, body: &str) { self.send_message(RoomMessageEventContent::text_markdown(body)) - .await; + .await + .ok(); } /// Sends a message to the admin room as the admin user (see send_text() for /// convenience). 
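set_room_tag above also shows the new error model in miniature: the fetch and the parse are both fallible, so they chain with Result::and_then, and any failure (no account data yet, or undecodable JSON) falls through to a fresh default event. The idiom in isolation (sketch; assumes serde with the derive feature plus serde_json):

    use serde::Deserialize;

    #[derive(Deserialize, Debug, PartialEq, Default)]
    struct Tags {
        tags: Vec<String>,
    }

    // Stand-in for get(..).await.and_then(parse).unwrap_or_else(default).
    fn load_or_default(fetched: Result<&str, &'static str>) -> Tags {
        fetched
            .and_then(|raw| serde_json::from_str(raw).map_err(|_| "bad json"))
            .unwrap_or_else(|_| Tags::default())
    }

    fn main() {
        assert_eq!(load_or_default(Ok(r#"{"tags":["fav"]}"#)).tags, vec!["fav"]);
        assert_eq!(load_or_default(Err("not found")), Tags::default());
        assert_eq!(load_or_default(Ok("not json")), Tags::default());
    }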
- pub async fn send_message(&self, message_content: RoomMessageEventContent) { - if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.services.globals.server_user; - self.respond_to_room(message_content, &room_id, user_id) - .await; - } + pub async fn send_message(&self, message_content: RoomMessageEventContent) -> Result<()> { + let user_id = &self.services.globals.server_user; + let room_id = self.get_admin_room().await?; + self.respond_to_room(message_content, &room_id, user_id) + .boxed() + .await } /// Posts a command to the command processor queue and returns. Processing @@ -193,8 +195,12 @@ impl Service { async fn handle_command(&self, command: CommandInput) { match self.process_command(command).await { - Ok(Some(output)) | Err(output) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), + Ok(Some(output)) | Err(output) => self + .handle_response(output) + .boxed() + .await + .unwrap_or_else(default_log), } } @@ -218,71 +224,67 @@ impl Service { } /// Checks whether a given user is an admin of this server - pub async fn user_is_admin(&self, user_id: &UserId) -> Result { - if let Ok(Some(admin_room)) = self.get_admin_room() { - self.services.state_cache.is_joined(user_id, &admin_room) - } else { - Ok(false) - } + pub async fn user_is_admin(&self, user_id: &UserId) -> bool { + let Ok(admin_room) = self.get_admin_room().await else { + return false; + }; + + self.services + .state_cache + .is_joined(user_id, &admin_room) + .await } /// Gets the room ID of the admin room /// /// Errors are propagated from the database, and will have None if there is /// no admin room - pub fn get_admin_room(&self) -> Result> { - if let Some(room_id) = self + pub async fn get_admin_room(&self) -> Result { + let room_id = self .services .alias - .resolve_local_alias(&self.services.globals.admin_alias)? - { - if self - .services - .state_cache - .is_joined(&self.services.globals.server_user, &room_id)? 
- { - return Ok(Some(room_id)); - } - } + .resolve_local_alias(&self.services.globals.admin_alias) + .await?; - Ok(None) + self.services + .state_cache + .is_joined(&self.services.globals.server_user, &room_id) + .await + .then_some(room_id) + .ok_or_else(|| err!(Request(NotFound("Admin user not joined to admin room")))) } - async fn handle_response(&self, content: RoomMessageEventContent) { + async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> { let Some(Relation::Reply { in_reply_to, }) = content.relates_to.as_ref() else { - return; + return Ok(()); }; - let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else { error!( event_id = ?in_reply_to.event_id, "Missing admin command in_reply_to event" ); - return; + return Ok(()); }; - let response_sender = if self.is_admin_room(&pdu.room_id) { + let response_sender = if self.is_admin_room(&pdu.room_id).await { &self.services.globals.server_user } else { &pdu.sender }; self.respond_to_room(content, &pdu.room_id, response_sender) - .await; + .await } - async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) { - assert!( - self.user_is_admin(user_id) - .await - .expect("checked user is admin"), - "sender is not admin" - ); + async fn respond_to_room( + &self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId, + ) -> Result<()> { + assert!(self.user_is_admin(user_id).await, "sender is not admin"); - let state_lock = self.services.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -292,6 +294,7 @@ impl Service { timestamp: None, }; + let state_lock = self.services.state.mutex.lock(room_id).await; if let Err(e) = self .services .timeline @@ -302,6 +305,8 @@ impl Service { .await .unwrap_or_else(default_log); } + + Ok(()) } async fn handle_response_error( @@ -355,12 +360,12 @@ impl Service { } // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !self.is_admin_room(&pdu.room_id) { + if is_public_prefix && !self.is_admin_room(&pdu.room_id).await { return false; } // Only senders who are admin can proceed - if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) { + if !self.user_is_admin(&pdu.sender).await { return false; } @@ -368,7 +373,7 @@ impl Service { // the administrator can execute commands as conduit let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; - if from_server && self.is_admin_room(&pdu.room_id) { + if from_server && self.is_admin_room(&pdu.room_id).await { return false; } @@ -377,12 +382,11 @@ impl Service { } #[must_use] - pub fn is_admin_room(&self, room_id: &RoomId) -> bool { - if let Ok(Some(admin_room_id)) = self.get_admin_room() { - admin_room_id == room_id - } else { - false - } + pub async fn is_admin_room(&self, room_id_: &RoomId) -> bool { + self.get_admin_room() + .map_ok(|room_id| room_id == room_id_) + .await + .unwrap_or(false) } /// Sets the self-reference to crate::Services which will provide context to diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 40e641a1..d5fa5476 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use conduit::{utils, 
Error, Result}; -use database::{Database, Map}; +use conduit::{err, utils::stream::TryIgnore, Result}; +use database::{Database, Deserialized, Map}; +use futures::Stream; use ruma::api::appservice::Registration; pub struct Data { @@ -19,7 +20,7 @@ impl Data { pub(super) fn register_appservice(&self, yaml: &Registration) -> Result { let id = yaml.id.as_str(); self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes()); Ok(id.to_owned()) } @@ -31,24 +32,19 @@ impl Data { /// * `service_name` - the name you send to register the service previously pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations - .remove(service_name.as_bytes())?; + .remove(service_name.as_bytes()); Ok(()) } - pub fn get_registration(&self, id: &str) -> Result> { + pub async fn get_registration(&self, id: &str) -> Result { self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes) - .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) - }) - .transpose() + .qry(id) + .await + .deserialized_json() + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } - pub(super) fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - }))) + pub(super) fn iter_ids(&self) -> impl Stream + Send + '_ { + self.id_appserviceregistrations.keys().ignore_err() } } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index c0752d56..7e2dc738 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -2,9 +2,10 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; +use async_trait::async_trait; use conduit::{err, Result}; use data::Data; -use futures_util::Future; +use futures::{Future, StreamExt, TryStreamExt}; use regex::RegexSet; use ruma::{ api::appservice::{Namespace, Registration}, @@ -126,13 +127,22 @@ struct Services { sending: Dep, } +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { - let mut registration_info = BTreeMap::new(); - let db = Data::new(args.db); + Ok(Arc::new(Self { + db: Data::new(args.db), + services: Services { + sending: args.depend::("sending"), + }, + registration_info: RwLock::new(BTreeMap::new()), + })) + } + + async fn worker(self: Arc) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&db)? { - registration_info.insert( + for appservice in iter_ids(&self.db).await? 
{ + self.registration_info.write().await.insert( appservice.0, appservice .1 @@ -141,13 +151,7 @@ impl crate::Service for Service { ); } - Ok(Arc::new(Self { - db, - services: Services { - sending: args.depend::("sending"), - }, - registration_info: RwLock::new(registration_info), - })) + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -155,7 +159,7 @@ impl crate::Service for Service { impl Service { #[inline] - pub fn all(&self) -> Result> { iter_ids(&self.db) } + pub async fn all(&self) -> Result> { iter_ids(&self.db).await } /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { @@ -188,7 +192,8 @@ impl Service { // sending to the URL self.services .sending - .cleanup_events(service_name.to_owned())?; + .cleanup_events(service_name.to_owned()) + .await; Ok(()) } @@ -251,15 +256,9 @@ impl Service { } } -fn iter_ids(db: &Data) -> Result> { - db.iter_ids()? - .filter_map(Result::ok) - .map(move |id| { - Ok(( - id.clone(), - db.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() +async fn iter_ids(db: &Data) -> Result> { + db.iter_ids() + .then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) }) + .try_collect() + .await } diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs index 1bb0843d..98020bc2 100644 --- a/src/service/emergency/mod.rs +++ b/src/service/emergency/mod.rs @@ -33,6 +33,7 @@ impl crate::Service for Service { async fn worker(self: Arc) -> Result<()> { self.set_emergency_access() + .await .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; Ok(()) @@ -44,7 +45,7 @@ impl crate::Service for Service { impl Service { /// Sets the emergency password and push rules for the @conduit account in /// case emergency password is set - fn set_emergency_access(&self) -> Result { + async fn set_emergency_access(&self) -> Result { let conduit_user = &self.services.globals.server_user; self.services @@ -56,17 +57,20 @@ impl Service { None => (Ruleset::new(), false), }; - self.services.account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; + self.services + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + ) + .await?; if pwd_set { warn!( @@ -75,7 +79,7 @@ impl Service { ); } else { // logs out any users still in the server service account and removes sessions - self.services.users.deactivate_account(conduit_user)?; + self.services.users.deactivate_account(conduit_user).await?; } Ok(pwd_set) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5b5d9f09..3286e40c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -4,8 +4,8 @@ use std::{ }; use conduit::{trace, utils, Error, Result, Server}; -use database::{Database, Map}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use database::{Database, Deserialized, Map}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, 
signatures::Ed25519KeyPair, @@ -83,7 +83,7 @@ impl Data { .checked_add(1) .expect("counter must not overflow u64"); - self.global.insert(COUNTER, &counter.to_be_bytes())?; + self.global.insert(COUNTER, &counter.to_be_bytes()); Ok(*counter) } @@ -102,7 +102,7 @@ impl Data { fn stored_count(global: &Arc) -> Result { global - .get(COUNTER)? + .get(COUNTER) .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } @@ -133,36 +133,18 @@ impl Data { futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - let short_roomid = self - .services - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { + continue; + }; let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); roomid_prefix.push(0xFF); - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push(Box::pin(async move { - let _result = self.services.typing.wait_for_update(&room_id).await; - })); - - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - // Key changes futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); @@ -174,6 +156,19 @@ impl Data { self.roomusertype_roomuserdataid .watch_prefix(&roomuser_prefix), ); + + // PDUs + let short_roomid = short_roomid.to_be_bytes().to_vec(); + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + let typing_room_id = room_id.to_owned(); + let typing_wait_for_update = async move { + self.services.typing.wait_for_update(&typing_room_id).await; + }; + + futures.push(typing_wait_for_update.boxed()); + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); } let mut globaluserdata_prefix = vec![0xFF]; @@ -190,12 +185,14 @@ impl Data { // One time keys futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - futures.push(Box::pin(async move { + // Server shutdown + let server_shutdown = async move { while self.services.server.running() { - let _result = self.services.server.signal.subscribe().recv().await; + self.services.server.signal.subscribe().recv().await.ok(); } - })); + }; + futures.push(server_shutdown.boxed()); if !self.services.server.running() { return Ok(()); } @@ -209,10 +206,10 @@ impl Data { } pub fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { + let keypair_bytes = self.global.get(b"keypair").map_or_else( + |_| { let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; + self.global.insert(b"keypair", &keypair); Ok::<_, Error>(keypair) }, |val| Ok(val.to_vec()), @@ -241,7 +238,10 @@ impl Data { } #[inline] - pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + pub fn remove_keypair(&self) -> Result<()> { + self.global.remove(b"keypair"); + Ok(()) + } /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored /// in room version > 4 @@ -250,15 +250,15 @@ impl Data { /// /// This doesn't actually check that the keys provided are newer than the /// old set. 
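watch() above funnels every wake-up source (map-prefix watches, the typing EDU waiter, the server shutdown signal) into one FuturesUnordered of boxed futures; the caller then resolves as soon as any single source fires. That select-any shape in miniature (futures crate only, illustrative values):

    use futures::{
        future::{pending, ready, BoxFuture},
        stream::{FuturesUnordered, StreamExt},
        FutureExt,
    };

    fn main() {
        let mut sources: FuturesUnordered<BoxFuture<'static, &'static str>> = FuturesUnordered::new();

        // Stand-ins for watch_prefix() futures and the shutdown listener.
        sources.push(pending().boxed()); // never fires
        sources.push(ready("typing").boxed()); // fires immediately
        sources.push(pending().boxed());

        // next() resolves as soon as ANY pushed future completes.
        let first = futures::executor::block_on(sources.next());
        assert_eq!(first, Some("typing"));
    }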
- pub fn add_signing_key( + pub async fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result> { + ) -> BTreeMap { // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + let signingkeys = self.server_signingkeys.qry(origin).await; let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { + .and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into)) + .unwrap_or_else(|_| { // Just insert "now", it doesn't matter ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) }); @@ -275,7 +275,7 @@ impl Data { self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; + ); let mut tree = keys.verify_keys; tree.extend( @@ -284,45 +284,38 @@ impl Data { .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - Ok(tree) + tree } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .signing_keys_for(origin)? - .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + self.signing_keys_for(origin).await.map_or_else( + |_| Ok(BTreeMap::new()), + |keys: ServerSigningKeys| { let mut tree = keys.verify_keys; tree.extend( keys.old_verify_keys .into_iter() .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - tree - }); - - Ok(signingkeys) + Ok(tree) + }, + ) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()); - - Ok(signingkeys) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.server_signingkeys + .qry(origin) + .await + .deserialized_json() } - pub fn database_version(&self) -> Result { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) - }) - } + pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) } #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; + self.global.insert(b"version", &new_version.to_be_bytes()); Ok(()) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 66917520..c7a73230 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -1,17 +1,15 @@ -use std::{ - collections::{HashMap, HashSet}, - fs::{self}, - io::Write, - mem::size_of, - sync::Arc, +use conduit::{ + debug_info, debug_warn, error, info, + result::NotFound, + utils::{stream::TryIgnore, IterStream, ReadyExt}, + warn, Err, Error, Result, }; - -use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Error, Result}; +use futures::{FutureExt, StreamExt}; use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, push::Ruleset, - EventId, OwnedRoomId, RoomId, UserId, + UserId, }; use crate::{media, Services}; @@ -33,12 +31,14 @@ pub(crate) const DATABASE_VERSION: u64 = 13; pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16; pub(crate) async fn migrations(services: &Services) -> Result<()> { + let users_count = 
services.users.count().await; + // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if services.users.count()? > 0 { + if users_count > 0 { let conduit_user = &services.globals.server_user; - if !services.users.exists(conduit_user)? { + if !services.users.exists(conduit_user).await { error!("The {} server user does not exist, and the database is not new.", conduit_user); return Err(Error::bad_database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", @@ -46,7 +46,7 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> { } } - if services.users.count()? > 0 { + if users_count > 0 { migrate(services).await } else { fresh(services).await @@ -62,9 +62,9 @@ async fn fresh(services: &Services) -> Result<()> { .db .bump_database_version(DATABASE_VERSION)?; - db["global"].insert(b"feat_sha256_media", &[])?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); // Create the admin room and server user on first run crate::admin::create_admin_room(services).await?; @@ -82,136 +82,109 @@ async fn migrate(services: &Services) -> Result<()> { let db = &services.db; let config = &services.server.config; - if services.globals.db.database_version()? < 1 { - db_lt_1(services).await?; + if services.globals.db.database_version().await < 11 { + return Err!(Database( + "Database schema version {} is no longer supported", + services.globals.db.database_version().await + )); } - if services.globals.db.database_version()? < 2 { - db_lt_2(services).await?; - } - - if services.globals.db.database_version()? < 3 { - db_lt_3(services).await?; - } - - if services.globals.db.database_version()? < 4 { - db_lt_4(services).await?; - } - - if services.globals.db.database_version()? < 5 { - db_lt_5(services).await?; - } - - if services.globals.db.database_version()? < 6 { - db_lt_6(services).await?; - } - - if services.globals.db.database_version()? < 7 { - db_lt_7(services).await?; - } - - if services.globals.db.database_version()? < 8 { - db_lt_8(services).await?; - } - - if services.globals.db.database_version()? < 9 { - db_lt_9(services).await?; - } - - if services.globals.db.database_version()? < 10 { - db_lt_10(services).await?; - } - - if services.globals.db.database_version()? < 11 { - db_lt_11(services).await?; - } - - if services.globals.db.database_version()? < 12 { + if services.globals.db.database_version().await < 12 { db_lt_12(services).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services.globals.db.database_version()? < 13 { + if services.globals.db.database_version().await < 13 { db_lt_13(services).await?; } - if db["global"].get(b"feat_sha256_media")?.is_none() { + if db["global"].qry("feat_sha256_media").await.is_not_found() { media::migrations::migrate_sha256_media(services).await?; } else if config.media_startup_check { media::migrations::checkup_sha256_media(services).await?; } if db["global"] - .get(b"fix_bad_double_separator_in_state_cache")? 
- .is_none() + .qry("fix_bad_double_separator_in_state_cache") + .await + .is_not_found() { fix_bad_double_separator_in_state_cache(services).await?; } if db["global"] - .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? - .is_none() + .qry("retroactively_fix_bad_data_from_roomuserid_joined") + .await + .is_not_found() { retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } - let version_match = services.globals.db.database_version().unwrap() == DATABASE_VERSION - || services.globals.db.database_version().unwrap() == CONDUIT_DATABASE_VERSION; + let version_match = services.globals.db.database_version().await == DATABASE_VERSION + || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; assert!( version_match, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services.globals.db.database_version().unwrap(), + services.globals.db.database_version().await, DATABASE_VERSION, ); { let patterns = services.globals.forbidden_usernames(); if !patterns.is_empty() { - for user_id in services + services .users - .iter() - .filter_map(Result::ok) - .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) - .filter(|user| user.server_name() == config.server_name) - { - let matches = patterns.matches(user_id.localpart()); - if matches.matched_any() { - warn!( - "User {} matches the following forbidden username patterns: {}", - user_id.to_string(), - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - - { - let patterns = services.globals.forbidden_alias_names(); - if !patterns.is_empty() { - for address in services.rooms.metadata.iter_ids() { - let room_id = address?; - let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); - for room_alias_result in room_aliases { - let room_alias = room_alias_result?; - let matches = patterns.matches(room_alias.alias()); + .stream() + .filter(|user_id| services.users.is_active_local(user_id)) + .ready_for_each(|user_id| { + let matches = patterns.matches(user_id.localpart()); if matches.matched_any() { warn!( - "Room with alias {} ({}) matches the following forbidden room name patterns: {}", - room_alias, - &room_id, + "User {} matches the following forbidden username patterns: {}", + user_id.to_string(), matches .into_iter() .map(|x| &patterns.patterns()[x]) .join(", ") ); } - } + }) + .await; + } + } + + { + let patterns = services.globals.forbidden_alias_names(); + if !patterns.is_empty() { + for room_id in services + .rooms + .metadata + .iter_ids() + .map(ToOwned::to_owned) + .collect::>() + .await + { + services + .rooms + .alias + .local_aliases_for_room(&room_id) + .ready_for_each(|room_alias| { + let matches = patterns.matches(room_alias.alias()); + if matches.matched_any() { + warn!( + "Room with alias {} ({}) matches the following forbidden room name patterns: {}", + room_alias, + &room_id, + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") + ); + } + }) + .await; } } } @@ -224,424 +197,17 @@ async fn migrate(services: &Services) -> Result<()> { Ok(()) } -async fn db_lt_1(services: &Services) -> Result<()> { - let db = &services.db; - - let roomserverids = &db["roomserverids"]; - let serverroomids = &db["serverroomids"]; - for (roomserverid, _) in roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xFF); - let room_id = parts.next().expect("split always returns one element"); - let Some(servername) = parts.next() else { - error!("Migration: 
Invalid roomserverid in db."); - continue; - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xFF); - serverroomid.extend_from_slice(room_id); - - serverroomids.insert(&serverroomid, &[])?; - } - - services.globals.db.bump_database_version(1)?; - info!("Migration: 0 -> 1 finished"); - Ok(()) -} - -async fn db_lt_2(services: &Services) -> Result<()> { - let db = &services.db; - - // We accidentally inserted hashed versions of "" into the db instead of just "" - let userid_password = &db["roomserverids"]; - for (userid, password) in userid_password.iter() { - let empty_pass = utils::hash::password("").expect("our own password to be properly hashed"); - let password = std::str::from_utf8(&password).expect("password is valid utf-8"); - let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok(); - if empty_hashed_password { - userid_password.insert(&userid, b"")?; - } - } - - services.globals.db.bump_database_version(2)?; - info!("Migration: 1 -> 2 finished"); - Ok(()) -} - -async fn db_lt_3(services: &Services) -> Result<()> { - let db = &services.db; - - // Move media to filesystem - let mediaid_file = &db["mediaid_file"]; - for (key, content) in mediaid_file.iter() { - if content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = services.media.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - mediaid_file.insert(&key, &[])?; - } - - services.globals.db.bump_database_version(3)?; - info!("Migration: 2 -> 3 finished"); - Ok(()) -} - -async fn db_lt_4(services: &Services) -> Result<()> { - let config = &services.server.config; - - // Add federated users to services as deactivated - for our_user in services.users.iter() { - let our_user = our_user?; - if services.users.is_deactivated(&our_user)? { - continue; - } - for room in services.rooms.state_cache.rooms_joined(&our_user) { - for user in services.rooms.state_cache.room_members(&room?) 
{ - let user = user?; - if user.server_name() != config.server_name { - info!(?user, "Migration: creating user"); - services.users.create(&user, None)?; - } - } - } - } - - services.globals.db.bump_database_version(4)?; - info!("Migration: 3 -> 4 finished"); - Ok(()) -} - -async fn db_lt_5(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade user data store - let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; - let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; - for (roomuserdataid, _) in roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id); - key.push(0xFF); - key.extend_from_slice(event_type); - - roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; - } - - services.globals.db.bump_database_version(5)?; - info!("Migration: 4 -> 5 finished"); - Ok(()) -} - -async fn db_lt_6(services: &Services) -> Result<()> { - let db = &services.db; - - // Set room member count - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - for (roomid, _) in roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services.rooms.state_cache.update_joined_count(room_id)?; - } - - services.globals.db.bump_database_version(6)?; - info!("Migration: 5 -> 6 finished"); - Ok(()) -} - -async fn db_lt_7(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade state store - let mut last_roomstates: HashMap = HashMap::new(); - let mut current_sstatehash: Option = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - - let handle_state = |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { - services - .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .copied() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services.rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services.rooms.load_shortstatehash_info(¤t_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - Ok::<_, Error>(()) - }; - - let stateid_shorteventid = &db["stateid_shorteventid"]; - let 
shorteventid_eventid = &db["shorteventid_eventid"]; - for (k, seventid) in stateid_shorteventid.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]).expect("number of bytes is correct"); - let sstatekey = k[size_of::()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); - } - } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services.globals.db.bump_database_version(7)?; - info!("Migration: 6 -> 7 finished"); - Ok(()) -} - -async fn db_lt_8(services: &Services) -> Result<()> { - let db = &services.db; - - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - let pduid_pdu = &db["pduid_pdu"]; - let eventid_pduid = &db["eventid_pduid"]; - - // Generate short room ids for all rooms - for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services.globals.next_count()?.to_be_bytes(); - roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let batch = pduid_pdu - .iter() - .filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(count); - - Some(database::OwnedKeyVal(new_key, v)) - }) - .collect::>(); - - pduid_pdu.insert_batch(batch.iter().map(database::KeyVal::from))?; - - let batch2 = eventid_pduid - .iter() - .filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id.to_vec(); - new_value.extend_from_slice(count); - - Some(database::OwnedKeyVal(k, new_value)) - }) - .collect::>(); - - eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - - services.globals.db.bump_database_version(8)?; - info!("Migration: 7 -> 8 finished"); - Ok(()) -} - -async fn db_lt_9(services: &Services) -> Result<()> { - let db = &services.db; - - let tokenids = &db["tokenids"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - - // Update tokenids db layout - let mut iter = tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - 
let mut parts = key.splitn(4, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(word); - new_key.push(0xFF); - new_key.extend_from_slice(pdu_id_count); - Some(database::OwnedKeyVal(new_key, Vec::::new())) - }) - .peekable(); - - while iter.peek().is_some() { - let batch = iter.by_ref().take(1000).collect::>(); - tokenids.insert_batch(batch.iter().map(database::KeyVal::from))?; - debug!("Inserted smaller batch"); - } - - info!("Deleting starts"); - - let batch2: Vec<_> = tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - tokenids.remove(&key)?; - } - - services.globals.db.bump_database_version(9)?; - info!("Migration: 8 -> 9 finished"); - Ok(()) -} - -async fn db_lt_10(services: &Services) -> Result<()> { - let db = &services.db; - - let statekey_shortstatekey = &db["statekey_shortstatekey"]; - let shortstatekey_statekey = &db["shortstatekey_statekey"]; - - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in statekey_shortstatekey.iter() { - shortstatekey_statekey.insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in services.users.iter().filter_map(Result::ok) { - services.users.mark_device_key_update(&user_id)?; - } - - services.globals.db.bump_database_version(10)?; - info!("Migration: 9 -> 10 finished"); - Ok(()) -} - -#[allow(unreachable_code)] -async fn db_lt_11(services: &Services) -> Result<()> { - error!("Dropping a column to clear data is not implemented yet."); - //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; - //userdevicesessionid_uiaarequest.clear()?; - - services.globals.db.bump_database_version(11)?; - info!("Migration: 10 -> 11 finished"); - Ok(()) -} - async fn db_lt_12(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? 
{ - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -652,7 +218,7 @@ async fn db_lt_12(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -694,12 +260,15 @@ async fn db_lt_12(services: &Services) -> Result<()> { } } - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(12)?; @@ -710,8 +279,14 @@ async fn db_lt_12(services: &Services) -> Result<()> { async fn db_lt_13(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? { - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -722,7 +297,7 @@ async fn db_lt_13(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -733,12 +308,15 @@ async fn db_lt_13(services: &Services) -> Result<()> { .global .update_with_server_default(user_default_rules); - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(13)?; @@ -754,32 +332,37 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result< let _cork = db.cork_and_sync(); let mut iter_count: usize = 0; - for (mut key, value) in roomuserid_joined.iter() { - iter_count = iter_count.saturating_add(1); - debug_info!(%iter_count); - let first_sep_index = key - .iter() - .position(|&i| i == 0xFF) - .expect("found 0xFF delim"); + roomuserid_joined + .raw_stream() + .ignore_err() + .ready_for_each(|(key, value)| { + let mut key = key.to_vec(); + iter_count = iter_count.saturating_add(1); + debug_info!(%iter_count); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); - if key - .iter() - .get(first_sep_index..=first_sep_index.saturating_add(1)) - .copied() - .collect_vec() - == vec![0xFF, 0xFF] - { - debug_warn!("Found bad key: 
{key:?}"); - roomuserid_joined.remove(&key)?; + if key + .iter() + .get(first_sep_index..=first_sep_index.saturating_add(1)) + .copied() + .collect_vec() + == vec![0xFF, 0xFF] + { + debug_warn!("Found bad key: {key:?}"); + roomuserid_joined.remove(&key); - key.remove(first_sep_index); - debug_warn!("Fixed key: {key:?}"); - roomuserid_joined.insert(&key, &value)?; - } - } + key.remove(first_sep_index); + debug_warn!("Fixed key: {key:?}"); + roomuserid_joined.insert(&key, value); + } + }) + .await; db.db.cleanup()?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); info!("Finished fixing"); Ok(()) @@ -795,69 +378,71 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .rooms .metadata .iter_ids() - .filter_map(Result::ok) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; - for room_id in room_ids.clone() { + for room_id in &room_ids { debug_info!("Fixing room {room_id}"); let users_in_room = services .rooms .state_cache - .room_members(&room_id) - .filter_map(Result::ok) - .collect_vec(); + .room_members(room_id) + .collect::>() + .await; let joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| membership.membership == MembershipState::Join) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; let non_joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| { - membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban - }) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; for user_id in joined_members { debug_info!("User is joined, marking as joined"); - services - .rooms - .state_cache - .mark_as_joined(user_id, &room_id)?; + services.rooms.state_cache.mark_as_joined(user_id, room_id); } for user_id in non_joined_members { debug_info!("User is left or banned, marking as left"); - services.rooms.state_cache.mark_as_left(user_id, &room_id)?; + services.rooms.state_cache.mark_as_left(user_id, room_id); } } - for room_id in room_ids { + for room_id in &room_ids { debug_info!( "Updating joined count for room {room_id} to fix servers in room after correcting membership states" ); - services.rooms.state_cache.update_joined_count(&room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; } db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); info!("Finished fixing"); Ok(()) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 87f8f492..f777901f 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -288,8 +288,8 @@ impl Service { /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. 
- pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let mut keys = self.db.verify_keys_for(origin)?; + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + let mut keys = self.db.verify_keys_for(origin).await?; if origin == self.server_name() { keys.insert( format!("ed25519:{}", self.keypair().version()) @@ -304,8 +304,8 @@ impl Service { Ok(keys) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - self.db.signing_keys_for(origin) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.db.signing_keys_for(origin).await } pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs deleted file mode 100644 index 30ac593b..00000000 --- a/src/service/key_backups/data.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{collections::BTreeMap, sync::Arc}; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - backupid_algorithm: Arc, - backupid_etag: Arc, - backupkeyid_backup: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - backupid_algorithm: db["backupid_algorithm"].clone(), - backupid_etag: db["backupid_etag"].clone(), - backupkeyid_backup: db["backupkeyid_backup"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - let version = self.services.globals.next_count()?.to_string(); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version) - } - - pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; - - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } - - pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move 
|(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } - - pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - Ok(( - version, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, - )) - }) - .transpose() - } - - pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } - - pub(super) fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; - - Ok(()) - } - - pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } - - pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? 
- .to_string()) - } - - pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - - let mut rooms = BTreeMap::::new(); - - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((room_id, session_id, key_data)) - }) { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } - - Ok(rooms) - } - - pub(super) fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(Result::ok) - .collect()) - } - - pub(super) fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .get(&key)? 
- .map(|value| { - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) - }) - .transpose() - } - - pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 65d3c065..12712e79 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,93 +1,319 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::Result; -use data::Data; +use conduit::{ + err, implement, utils, + utils::stream::{ReadyExt, TryIgnore}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::StreamExt; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +use crate::{globals, Dep}; + pub struct Service { db: Data, + services: Services, +} + +struct Data { + backupid_algorithm: Arc, + backupid_etag: Arc, + backupkeyid_backup: Arc, +} + +struct Services { + globals: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + backupid_algorithm: args.db["backupid_algorithm"].clone(), + backupid_etag: args.db["backupid_etag"].clone(), + backupkeyid_backup: args.db["backupkeyid_backup"].clone(), + }, + services: Services { + globals: args.depend::("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - self.db.create_backup(user_id, backup_metadata) - } +#[implement(Service)] +pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + let version = self.services.globals.next_count()?.to_string(); - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_backup(user_id, version) - } + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) - } + self.db.backupid_algorithm.insert( + &key, + 
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + ); - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - self.db.get_latest_backup_version(user_id) - } + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - pub fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - self.db.get_latest_backup(user_id) - } - - pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - self.db.get_backup(user_id, version) - } - - pub fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - self.db - .add_key(user_id, version, room_id, session_id, key_data) - } - - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } - - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { self.db.get_etag(user_id, version) } - - pub fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - self.db.get_all(user_id, version) - } - - pub fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - self.db.get_room(user_id, version, room_id) - } - - pub fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - self.db.get_session(user_id, version, room_id, session_id) - } - - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_all_keys(user_id, version) - } - - pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - self.db.delete_room_keys(user_id, version, room_id) - } - - pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - self.db - .delete_room_key(user_id, version, room_id, session_id) - } + Ok(version) +} + +#[implement(Service)] +pub async fn delete_backup(&self, user_id: &UserId, version: &str) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.db.backupid_algorithm.remove(&key); + self.db.backupid_etag.remove(&key); + + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn update_backup( + &self, user_id: &UserId, version: &str, backup_metadata: &Raw, +) -> Result { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.db + .backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes()); + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + + Ok(version.to_owned()) +} + +#[implement(Service)] +pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.db + .backupid_algorithm + .rev_raw_keys_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup versions 
found")))) + .and_then(|key| { + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) +} + +#[implement(Service)] +pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.db + .backupid_algorithm + .rev_raw_stream_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup found")))) + .and_then(|(key, val)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + + let algorithm = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?; + + Ok((version, algorithm)) + }) +} + +#[implement(Service)] +pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result> { + let key = (user_id, version); + self.db + .backupid_algorithm + .qry(&key) + .await + .deserialized_json() +} + +#[implement(Service)] +pub async fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, +) -> Result<()> { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + self.db + .backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes()); + + Ok(()) +} + +#[implement(Service)] +pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize { + let prefix = (user_id, version); + self.db + .backupkeyid_backup + .keys_raw_prefix(&prefix) + .count() + .await +} + +#[implement(Service)] +pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String { + let key = (user_id, version); + self.db + .backupid_etag + .qry(&key) + .await + .deserialized::() + .as_ref() + .map(ToString::to_string) + .expect("Backup has no etag.") +} + +#[implement(Service)] +pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a RoomId, &'a str), &'a [u8]); + + let mut rooms = BTreeMap::::new(); + let default = || RoomKeyBackup { + sessions: BTreeMap::new(), + }; + + let prefix = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|((_, _, room_id, session_id), value): KeyVal<'_>| { + let key_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + rooms + .entry(room_id.into()) + .or_insert_with(default) + .sessions + .insert(session_id.into(), key_data); + }) + .await; + + rooms +} + +#[implement(Service)] +pub async fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, +) -> BTreeMap> { + type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), &'a [u8]); + + let prefix = (user_id, 
version, room_id, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .map(|((.., session_id), value): KeyVal<'_>| { + let session_id = session_id.to_owned(); + let key_backup_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + (session_id, key_backup_data) + }) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, +) -> Result> { + let key = (user_id, version, room_id, session_id); + + self.db + .backupkeyid_backup + .qry(&key) + .await + .deserialized_json() +} + +#[implement(Service)] +pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) { + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) { + let key = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) { + let key = (user_id, version, room_id, session_id); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; } diff --git a/src/service/manager.rs b/src/service/manager.rs index 42260bb3..21e0ed7c 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -1,7 +1,7 @@ use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; -use futures_util::FutureExt; +use futures::FutureExt; use tokio::{ sync::{Mutex, MutexGuard}, task::{JoinHandle, JoinSet}, diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5d6d20b..29d562cc 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use conduit::{ debug, debug_info, trace, - utils::{str_from_bytes, string_from_bytes}, + utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt}, Err, Error, Result, }; use database::{Database, Map}; +use futures::StreamExt; use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -59,7 +60,7 @@ impl Data { .unwrap_or_default(), ); - self.mediaid_file.insert(&key, &[])?; + self.mediaid_file.insert(&key, &[]); if let Some(user) = user { let mut key: Vec = Vec::new(); @@ -68,13 +69,13 @@ impl Data { key.extend_from_slice(b"/"); key.extend_from_slice(mxc.media_id.as_bytes()); let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user)?; + self.mediaid_user.insert(&key, &user); } Ok(key) } - pub(super) fn delete_file_mxc(&self, mxc: &Mxc<'_>) -> Result<()> { + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -85,25 +86,31 @@ impl Data { prefix.push(0xFF); trace!("MXC db prefix: {prefix:?}"); - for (key, _) in self.mediaid_file.scan_prefix(prefix.clone()) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } + self.mediaid_file + 
.raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| { + debug!("Deleting key: {:?}", key); + self.mediaid_file.remove(key); + }) + .await; - for (key, value) in self.mediaid_user.scan_prefix(prefix.clone()) { - if key.starts_with(&prefix) { - let user = str_from_bytes(&value).unwrap_or_default(); + self.mediaid_user + .raw_stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|(key, val)| { + if key.starts_with(&prefix) { + let user = str_from_bytes(val).unwrap_or_default(); + debug_info!("Deleting key {key:?} which was uploaded by user {user}"); - debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); - self.mediaid_user.remove(&key)?; - } - } - - Ok(()) + self.mediaid_user.remove(key); + } + }) + .await; } /// Searches for all files with the given MXC - pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { + pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -115,9 +122,10 @@ impl Data { let keys: Vec> = self .mediaid_file - .scan_prefix(prefix) - .map(|(key, _)| key) - .collect(); + .keys_prefix_raw(&prefix) + .ignore_err() + .collect() + .await; if keys.is_empty() { return Err!(Database("Failed to find any keys in database for `{mxc}`",)); @@ -128,7 +136,7 @@ impl Data { Ok(keys) } - pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { + pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { let mut prefix: Vec = Vec::new(); prefix.extend_from_slice(b"mxc://"); prefix.extend_from_slice(mxc.server_name.as_bytes()); @@ -139,10 +147,13 @@ impl Data { prefix.extend_from_slice(&dim.height.to_be_bytes()); prefix.push(0xFF); - let (key, _) = self + let key = self .mediaid_file - .scan_prefix(prefix) + .raw_keys_prefix(&prefix) + .ignore_err() + .map(ToOwned::to_owned) .next() + .await .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; let mut parts = key.rsplit(|&b| b == 0xFF); @@ -177,28 +188,31 @@ impl Data { } /// Gets all the MXCs associated with a user - pub(super) fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { - let user_id = user_id.as_bytes().to_vec(); - + pub(super) async fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { self.mediaid_user - .iter() - .filter_map(|(key, user)| { - if *user == user_id { - let mxc_s = string_from_bytes(&key).ok()?; - Some(OwnedMxcUri::from(mxc_s)) - } else { - None - } - }) + .stream() + .ignore_err() + .ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into())) .collect() + .await } /// Gets all the media keys in our database (this includes all the metadata /// associated with it such as width, height, content-type, etc) - pub(crate) fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + pub(crate) async fn get_all_media_keys(&self) -> Vec> { + self.mediaid_file + .raw_keys() + .ignore_err() + .map(<[u8]>::to_vec) + .collect() + .await + } #[inline] - pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } + pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { + self.url_previews.remove(url.as_bytes()); + Ok(()) + } pub(super) fn set_url_preview( &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration, @@ -233,11 +247,13 @@ impl Data { value.push(0xFF); value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); - 
self.url_previews.insert(url.as_bytes(), &value) + self.url_previews.insert(url.as_bytes(), &value); + + Ok(()) } - pub(super) fn get_url_preview(&self, url: &str) -> Option { - let values = self.url_previews.get(url.as_bytes()).ok()??; + pub(super) async fn get_url_preview(&self, url: &str) -> Result { + let values = self.url_previews.qry(url).await?; let mut values = values.split(|&b| b == 0xFF); @@ -291,7 +307,7 @@ impl Data { x => x, }; - Some(UrlPreviewData { + Ok(UrlPreviewData { title, description, image, diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs index 9968d25b..2d1b39f9 100644 --- a/src/service/media/migrations.rs +++ b/src/service/media/migrations.rs @@ -7,7 +7,11 @@ use std::{ time::Instant, }; -use conduit::{debug, debug_info, debug_warn, error, info, warn, Config, Result}; +use conduit::{ + debug, debug_info, debug_warn, error, info, + utils::{stream::TryIgnore, ReadyExt}, + warn, Config, Result, +}; use crate::{globals, Services}; @@ -23,12 +27,17 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); - for (key, _) in mediaid_file.iter() { - let old = services.media.get_media_file_b64(&key); - let new = services.media.get_media_file_sha256(&key); - debug!(?key, ?old, ?new, num = changes.len(), "change"); - changes.push((old, new)); - } + mediaid_file + .raw_keys() + .ignore_err() + .ready_for_each(|key| { + let old = services.media.get_media_file_b64(key); + let new = services.media.get_media_file_sha256(key); + debug!(?key, ?old, ?new, num = changes.len(), "change"); + changes.push((old, new)); + }) + .await; + // move the file to the new location for (old_path, path) in changes { if old_path.exists() { @@ -41,11 +50,11 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services.globals.db.database_version()? == 14 && globals::migrations::DATABASE_VERSION == 13 { + if services.globals.db.database_version().await == 14 && globals::migrations::DATABASE_VERSION == 13 { services.globals.db.bump_database_version(13)?; } - db["global"].insert(b"feat_sha256_media", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); info!("Finished applying sha256_media"); Ok(()) } @@ -71,7 +80,7 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> { .filter_map(|ent| ent.map_or(None, |ent| Some(ent.path().into_os_string()))) .collect(); - for key in media.db.get_all_media_keys() { + for key in media.db.get_all_media_keys().await { let new_path = media.get_media_file_sha256(&key).into_os_string(); let old_path = media.get_media_file_b64(&key).into_os_string(); if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await { @@ -112,8 +121,8 @@ async fn handle_media_check( "Media is missing at all paths. Removing from database..." 
); - mediaid_file.remove(key)?; - mediaid_user.remove(key)?; + mediaid_file.remove(key); + mediaid_user.remove(key); } if config.media_compat_file_link && !old_exists && new_exists { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index d3765a17..c0b15726 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -97,7 +97,7 @@ impl Service { /// Deletes a file in the database and from the media directory via an MXC pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { - if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc) { + if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc).await { for key in keys { trace!(?mxc, "MXC Key: {key:?}"); debug_info!(?mxc, "Deleting from filesystem"); @@ -107,7 +107,7 @@ impl Service { } debug_info!(?mxc, "Deleting from database"); - _ = self.db.delete_file_mxc(mxc); + self.db.delete_file_mxc(mxc).await; } Ok(()) @@ -120,7 +120,7 @@ impl Service { /// /// currently, this is only practical for local users pub async fn delete_from_user(&self, user: &UserId) -> Result { - let mxcs = self.db.get_all_user_mxcs(user); + let mxcs = self.db.get_all_user_mxcs(user).await; let mut deletion_count: usize = 0; for mxc in mxcs { @@ -150,7 +150,7 @@ impl Service { content_disposition, content_type, key, - }) = self.db.search_file_metadata(mxc, &Dim::default()) + }) = self.db.search_file_metadata(mxc, &Dim::default()).await { let mut content = Vec::new(); let path = self.get_media_file(&key); @@ -170,7 +170,7 @@ impl Service { /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut mxcs = Vec::with_capacity(all_keys.len()); @@ -209,7 +209,7 @@ impl Service { pub async fn delete_all_remote_media_at_after_time( &self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool, ) -> Result { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut remote_mxcs = Vec::with_capacity(all_keys.len()); for key in all_keys { @@ -343,9 +343,10 @@ impl Service { } #[inline] - pub fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { + pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { self.db .search_file_metadata(mxc, &Dim::default()) + .await .map(|metadata| FileMeta { content_disposition: metadata.content_disposition, content_type: metadata.content_type, diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 5704075e..6b147383 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -71,16 +71,16 @@ pub async fn download_image(&self, url: &str) -> Result { #[implement(Service)] pub async fn get_url_preview(&self, url: &str) -> Result { - if let Some(preview) = self.db.get_url_preview(url) { + if let Ok(preview) = self.db.get_url_preview(url).await { return Ok(preview); } // ensure that only one request is made per URL let _request_lock = self.url_preview_mutex.lock(url).await; - match self.db.get_url_preview(url) { - Some(preview) => Ok(preview), - None => self.request_url_preview(url).await, + match self.db.get_url_preview(url).await { + Ok(preview) => Ok(preview), + Err(_) => self.request_url_preview(url).await, } } diff --git a/src/service/media/thumbnail.rs b/src/service/media/thumbnail.rs index 630f7b3b..04ec0303 100644 --- a/src/service/media/thumbnail.rs +++ b/src/service/media/thumbnail.rs @@ -54,9 +54,9 @@ impl super::Service { // 0, 0 because that's the 
original file let dim = dim.normalized(); - if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim) { + if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim).await { self.get_thumbnail_saved(metadata).await - } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()) { + } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()).await { self.get_thumbnail_generate(mxc, &dim, metadata).await } else { Ok(None) diff --git a/src/service/mod.rs b/src/service/mod.rs index f588a542..cb8bfcd9 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod resolver; pub mod rooms; pub mod sending; pub mod server_keys; +pub mod sync; pub mod transaction_ids; pub mod uiaa; pub mod updates; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index ec036b3d..0c3f3d31 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,7 +1,12 @@ use std::sync::Arc; -use conduit::{debug_warn, utils, Error, Result}; -use database::Map; +use conduit::{ + debug_warn, utils, + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Deserialized, Map}; +use futures::Stream; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use super::Presence; @@ -31,39 +36,35 @@ impl Data { } } - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; + pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> { + let count = self + .userid_presenceid + .qry(user_id) + .await + .deserialized::()?; - let key = presenceid_key(count, user_id); - self.presenceid_presence - .get(&key)? - .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok(( - count, - Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, - )) - }) - .transpose() - } else { - Ok(None) - } + let key = presenceid_key(count, user_id); + let bytes = self.presenceid_presence.qry(&key).await?; + let event = Presence::from_json_bytes(&bytes)? + .to_presence_event(user_id, &self.services.users) + .await; + + Ok((count, event)) } - pub(super) fn set_presence( + pub(super) async fn set_presence( &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { - let last_presence = self.get_presence(user_id)?; + let last_presence = self.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some(ref presence) => presence.1.content.presence != *presence_state, + Err(_) => true, + Ok(ref presence) => presence.1.content.presence != *presence_state, }; let status_msg_changed = match last_presence { - None => true, - Some(ref last_presence) => { + Err(_) => true, + Ok(ref last_presence) => { let old_msg = last_presence .1 .content @@ -79,8 +80,8 @@ impl Data { let now = utils::millis_since_unix_epoch(); let last_last_active_ts = match last_presence { - None => 0, - Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), + Err(_) => 0, + Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), }; let last_active_ts = match last_active_ago { @@ -90,12 +91,7 @@ impl Data { // TODO: tighten for state flicker? 
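A convention repeated across the services in this patch (key backups, media, and the presence hunks here): point lookups go through `qry(&key).await`, which returns a `Result` whose `Err` takes over the role the old code gave `None`, and `deserialized::<T>()` decodes the stored bytes. A rough, synchronous model of that contract; the `Map` shim and error strings are invented stand-ins for the real async database handle:

use std::collections::BTreeMap;

struct Map(BTreeMap<Vec<u8>, Vec<u8>>);

impl Map {
	// qry(): point lookup where absence is an Err, not an Option
	fn qry(&self, key: &[u8]) -> Result<&[u8], &'static str> {
		self.0.get(key).map(Vec::as_slice).ok_or("not found")
	}
}

// deserialized::<u64>() analogue: decode big-endian bytes into a count
fn deserialized_u64(bytes: &[u8]) -> Result<u64, &'static str> {
	bytes.try_into().map(u64::from_be_bytes).map_err(|_| "bad length")
}

fn main() -> Result<(), &'static str> {
	let mut db = Map(BTreeMap::new());
	db.0.insert(b"@alice:example.org".to_vec(), 7u64.to_be_bytes().to_vec());

	// Err(_) now plays the part the old code gave to None.
	let count = db.qry(b"@alice:example.org").and_then(deserialized_u64)?;
	assert_eq!(count, 7);
	Ok(())
}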
if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts { - debug_warn!( - "presence spam {:?} last_active_ts:{:?} < {:?}", - user_id, - last_active_ts, - last_last_active_ts - ); + debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",); return Ok(()); } @@ -115,41 +111,42 @@ impl Data { let key = presenceid_key(count, user_id); self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?)?; + .insert(&key, &presence.to_json_bytes()?); self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes())?; + .insert(user_id.as_bytes(), &count.to_be_bytes()); - if let Some((last_count, _)) = last_presence { + if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); - self.presenceid_presence.remove(&key)?; + self.presenceid_presence.remove(&key); } Ok(()) } - pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - let key = presenceid_key(count, user_id); - self.presenceid_presence.remove(&key)?; - self.userid_presenceid.remove(user_id.as_bytes())?; - } + pub(super) async fn remove_presence(&self, user_id: &UserId) { + let Ok(count) = self + .userid_presenceid + .qry(user_id) + .await + .deserialized::() + else { + return; + }; - Ok(()) + let key = presenceid_key(count, user_id); + self.presenceid_presence.remove(&key); + self.userid_presenceid.remove(user_id.as_bytes()); } - pub fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { - Box::new( - self.presenceid_presence - .iter() - .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec)> { - let (count, user_id) = presenceid_parse(&key)?; - Ok((user_id.to_owned(), count, presence_bytes)) - }) - .filter(move |(_, count, _)| *count > since), - ) + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { + self.presenceid_presence + .raw_stream() + .ignore_err() + .ready_filter_map(move |(key, presence_bytes)| { + let (count, user_id) = presenceid_parse(key).expect("invalid presenceid_parse"); + (count > since).then(|| (user_id.to_owned(), count, presence_bytes.to_vec())) + }) } } @@ -162,7 +159,7 @@ fn presenceid_key(count: u64, user_id: &UserId) -> Vec { fn presenceid_parse(key: &[u8]) -> Result<(u64, &UserId)> { let (count, user_id) = key.split_at(8); let user_id = user_id_from_bytes(user_id)?; - let count = utils::u64_from_bytes(count).unwrap(); + let count = utils::u64_from_u8(count); Ok((count, user_id)) } diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index a54a6d7c..3b5c4caf 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -4,8 +4,8 @@ mod presence; use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, Error, Result, Server}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{checked, debug, error, result::LogErr, Error, Result, Server}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use tokio::{sync::Mutex, time::sleep}; @@ -58,7 +58,9 @@ impl crate::Service for Service { loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! 
{ - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => { + self.process_presence_timer(&user_id).await.log_err().ok(); + }, event = receiver.recv_async() => match event { Err(_e) => return Ok(()), Ok((user_id, timeout)) => { @@ -82,28 +84,27 @@ impl crate::Service for Service { impl Service { /// Returns the latest presence event for the given user. #[inline] - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some((_, presence)) = self.db.get_presence(user_id)? { - Ok(Some(presence)) - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result { + self.db + .get_presence(user_id) + .map_ok(|(_, presence)| presence) + .await } /// Pings the presence of the given user in the given room, setting the /// specified state. - pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { + pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000; - let last_presence = self.db.get_presence(user_id)?; + let last_presence = self.db.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some((_, ref presence)) => presence.content.presence != *new_state, + Err(_) => true, + Ok((_, ref presence)) => presence.content.presence != *new_state, }; let last_last_active_ago = match last_presence { - None => 0_u64, - Some((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), + Err(_) => 0_u64, + Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), }; if !state_changed && last_last_active_ago < REFRESH_TIMEOUT { @@ -111,17 +112,18 @@ impl Service { } let status_msg = match last_presence { - Some((_, ref presence)) => presence.content.status_msg.clone(), - None => Some(String::new()), + Ok((_, ref presence)) => presence.content.status_msg.clone(), + Err(_) => Some(String::new()), }; let last_active_ago = UInt::new(0); let currently_active = *new_state == PresenceState::Online; self.set_presence(user_id, new_state, Some(currently_active), last_active_ago, status_msg) + .await } /// Adds a presence event which will be saved until a new event replaces it. - pub fn set_presence( + pub async fn set_presence( &self, user_id: &UserId, state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { @@ -131,7 +133,8 @@ impl Service { }; self.db - .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; + .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg) + .await?; if self.timeout_remote_users || self.services.globals.user_is_local(user_id) { let timeout = match presence_state { @@ -154,28 +157,33 @@ impl Service { /// /// TODO: Why is this not used? #[allow(dead_code)] - pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } + pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await } /// Returns the most recent presence updates that happened after the event /// with id `since`. 
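The `presence_since` change just below is representative of how iterator-returning getters are converted: a boxed `dyn Iterator` becomes an unboxed `impl Stream + Send + '_` borrowed from `&self`. A toy model of only the signature change, with made-up types:

use futures::{stream, Stream, StreamExt};

struct Presences(Vec<(String, u64)>);

impl Presences {
	// Before: Box<dyn Iterator<Item = (String, u64)> + '_>
	// After: a lazily evaluated stream tied to &self by the '_ lifetime.
	fn presence_since(&self, since: u64) -> impl Stream<Item = &(String, u64)> + Send + '_ {
		stream::iter(self.0.iter().filter(move |item| item.1 > since))
	}
}

fn main() {
	let p = Presences(vec![("@a:example.org".into(), 3), ("@b:example.org".into(), 9)]);
	let newer: Vec<_> = futures::executor::block_on(p.presence_since(5).collect());
	assert_eq!(newer.len(), 1);
}

Returning `impl Stream` rather than a boxed trait object avoids an allocation and lets callers keep combinator adapters statically dispatched.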
#[inline] - pub fn presence_since(&self, since: u64) -> Box)> + '_> { + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { self.db.presence_since(since) } - pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { + #[inline] + pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { let presence = Presence::from_json_bytes(bytes)?; - presence.to_presence_event(user_id, &self.services.users) + let event = presence + .to_presence_event(user_id, &self.services.users) + .await; + + Ok(event) } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + async fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; let mut status_msg = None; - let presence_event = self.get_presence(user_id)?; + let presence_event = self.get_presence(user_id).await; - if let Some(presence_event) = presence_event { + if let Ok(presence_event) = presence_event { presence_state = presence_event.content.presence; last_active_ago = presence_event.content.last_active_ago; status_msg = presence_event.content.status_msg; @@ -192,7 +200,8 @@ impl Service { ); if let Some(new_state) = new_state { - self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg) + .await?; } Ok(()) diff --git a/src/service/presence/presence.rs b/src/service/presence/presence.rs index 570008f2..0d5c226b 100644 --- a/src/service/presence/presence.rs +++ b/src/service/presence/presence.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use conduit::{utils, Error, Result}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -42,7 +40,7 @@ impl Presence { } /// Creates a PresenceEvent from available data. 
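In the next hunk, `to_presence_event` stops being fallible: the awaited profile getters return `Result`s that are folded into the event's optional fields with `.ok()`, so a missing displayname becomes `None` rather than an error. A hedged sketch of that folding, with hypothetical names:

struct PresenceEvent {
	displayname: Option<String>,
}

// Stand-in for users.displayname(): absence is an Err under the new convention.
async fn displayname(user: &str) -> Result<String, &'static str> {
	if user == "@alice:example.org" { Ok("Alice".into()) } else { Err("no displayname set") }
}

async fn to_presence_event(user: &str) -> PresenceEvent {
	PresenceEvent {
		// .ok() collapses the lookup error into an absent optional field
		displayname: displayname(user).await.ok(),
	}
}

fn main() {
	let ev = futures::executor::block_on(to_presence_event("@bob:example.org"));
	assert_eq!(ev.displayname, None);
}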
- pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc) -> Result { + pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -50,16 +48,16 @@ impl Presence { Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts))) }; - Ok(PresenceEvent { + PresenceEvent { sender: user_id.to_owned(), content: PresenceEventContent { presence: self.state.clone(), status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: users.displayname(user_id)?, - avatar_url: users.avatar_url(user_id)?, + displayname: users.displayname(user_id).await.ok(), + avatar_url: users.avatar_url(user_id).await.ok(), }, - }) + } } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs deleted file mode 100644 index f9734334..00000000 --- a/src/service/pusher/data.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, -}; - -pub(super) struct Data { - senderkey_pusher: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - senderkey_pusher: db["senderkey_pusher"].clone(), - } - } - - pub(super) fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - match pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher - .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"))?; - Ok(()) - }, - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher.remove(&key).map_err(Into::into) - }, - } - } - - pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xFF); - senderkey.extend_from_slice(pushkey.as_bytes()); - - self.senderkey_pusher - .get(&senderkey)? 
- .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .transpose() - } - - pub(super) fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .collect() - } - - pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xFF); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - - Ok(push_key_string) - })) - } -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index de87264c..44ff1945 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,9 +1,13 @@ -mod data; - use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_error, err, trace, utils::string_from_bytes, warn, Err, PduEvent, Result}; +use conduit::{ + debug_error, err, trace, + utils::{stream::TryIgnore, string_from_bytes}, + Err, PduEvent, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -22,12 +26,11 @@ use ruma::{ uint, RoomId, UInt, UserId, }; -use self::data::Data; use crate::{client, globals, rooms, users, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { @@ -38,9 +41,16 @@ struct Services { users: Dep, } +struct Data { + senderkey_pusher: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + db: Data { + senderkey_pusher: args.db["senderkey_pusher"].clone(), + }, services: Services { globals: args.depend::("globals"), client: args.depend::("client"), @@ -48,7 +58,6 @@ impl crate::Service for Service { state_cache: args.depend::("rooms::state_cache"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } @@ -56,19 +65,52 @@ impl crate::Service for Service { } impl Service { - pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - self.db.set_pusher(sender, pusher) + pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) { + match pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.db + .senderkey_pusher + .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value")); + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(ids.pushkey.as_bytes()); + self.db.senderkey_pusher.remove(&key); + }, + } } - pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - self.db.get_pusher(sender, pushkey) + pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result { + let senderkey = (sender, pushkey); + self.db + .senderkey_pusher + .qry(&senderkey) + .await + .deserialized_json() } - pub fn get_pushers(&self, sender: &UserId) -> Result> { 
self.db.get_pushers(sender) } + pub async fn get_pushers(&self, sender: &UserId) -> Vec { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, val): (Ignore, &[u8])| serde_json::from_slice(val).expect("Invalid Pusher in db.")) + .collect() + .await + } - #[must_use] - pub fn get_pushkeys(&self, sender: &UserId) -> Box> + '_> { - self.db.get_pushkeys(sender) + pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, pushkey): (Ignore, &str)| pushkey) } #[tracing::instrument(skip(self, dest, request))] @@ -161,15 +203,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .and_then(|ev| { serde_json::from_str(ev.content.get()) - .map_err(|e| err!(Database("invalid m.room.power_levels event: {e:?}"))) + .map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}")))) }) - .transpose()? .unwrap_or_default(); - for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? { + for action in self + .get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id) + .await? + { let n = match action { Action::Notify => true, Action::SetTweak(tweak) => { @@ -197,7 +242,7 @@ impl Service { } #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] - pub fn get_actions<'a>( + pub async fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, ) -> Result<&'a [Action]> { @@ -207,21 +252,27 @@ impl Service { notifications: power_levels.notifications.clone(), }; + let room_joined_count = self + .services + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(1) + .try_into() + .unwrap_or_else(|_| uint!(0)); + + let user_display_name = self + .services + .users + .displayname(user) + .await + .unwrap_or_else(|_| user.localpart().to_owned()); + let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), - member_count: UInt::try_from( - self.services - .state_cache - .room_joined_count(room_id)? - .unwrap_or(1), - ) - .unwrap_or_else(|_| uint!(0)), + member_count: room_joined_count, user_id: user.to_owned(), - user_display_name: self - .services - .users - .displayname(user)? 
- .unwrap_or_else(|| user.localpart().to_owned()), + user_display_name, power_levels: Some(power_levels), }; @@ -278,9 +329,14 @@ impl Service { notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = self.services.users.displayname(&event.sender)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok(); - notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; + notifi.room_name = self + .services + .state_accessor + .get_name(&event.room_id) + .await + .ok(); self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 07d9a0fa..ea4b1100 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -193,7 +193,7 @@ impl super::Service { .send() .await; - trace!("response: {:?}", response); + trace!("response: {response:?}"); if let Err(e) = &response { debug!("error: {e:?}"); return Ok(None); @@ -206,7 +206,7 @@ impl super::Service { } let text = response.text().await?; - trace!("response text: {:?}", text); + trace!("response text: {text:?}"); if text.len() >= 12288 { debug_warn!("response contains junk"); return Ok(None); @@ -225,7 +225,7 @@ impl super::Service { return Ok(None); } - debug_info!("{:?} found at {:?}", dest, m_server); + debug_info!("{dest:?} found at {m_server:?}"); Ok(Some(m_server.to_owned())) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs deleted file mode 100644 index efd2b5b7..00000000 --- a/src/service/rooms/alias/data.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - alias_userid: Arc, - alias_roomid: Arc, - aliasid_alias: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - alias_userid: db["alias_userid"].clone(), - alias_roomid: db["alias_roomid"].clone(), - aliasid_alias: db["aliasid_alias"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - // Comes first as we don't want a stuck alias - self.alias_userid - .insert(alias.alias().as_bytes(), user_id.as_bytes())?; - - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xFF); - aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - - Ok(()) - } - - pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? 
{ - let mut prefix = room_id.to_vec(); - prefix.push(0xFF); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - - self.alias_roomid.remove(alias.alias().as_bytes())?; - - self.alias_userid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid.")); - } - - Ok(()) - } - - pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_userid - .get(alias.alias().as_bytes())? - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } - - pub(super) fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? 
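// Illustrative contrast, assuming the qry()/deserialized() helpers this patch
// introduces elsewhere: the manual bytes -> UTF-8 -> ID parsing being deleted
// above collapses into a typed query, with parse failures surfacing as the
// Result's error instead of hand-rolled bad_database branches.
async fn resolve_alias(db: &Data, alias: &RoomAliasId) -> Result<OwnedRoomId> {
	// qry() serializes the key and fetches the raw handle; deserialized()
	// converts the stored value into the requested type.
	db.alias_roomid.qry(alias.alias()).await.deserialized()
}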
- .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - - Ok((room_id, room_alias_localpart)) - }), - ) - } -} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f2e01ab5..6b81a221 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,19 +1,23 @@ -mod data; mod remote; use std::sync::Arc; -use conduit::{err, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, }, - OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use self::data::Data; use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep}; pub struct Service { @@ -21,6 +25,12 @@ pub struct Service { services: Services, } +struct Data { + alias_userid: Arc, + alias_roomid: Arc, + aliasid_alias: Arc, +} + struct Services { admin: Dep, appservice: Dep, @@ -32,7 +42,11 @@ struct Services { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + alias_userid: args.db["alias_userid"].clone(), + alias_roomid: args.db["alias_roomid"].clone(), + aliasid_alias: args.db["aliasid_alias"].clone(), + }, services: Services { admin: args.depend::("admin"), appservice: args.depend::("appservice"), @@ -50,25 +64,52 @@ impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { - Err(Error::BadRequest( + return Err(Error::BadRequest( ErrorKind::forbidden(), "Only the server user can set this alias", - )) - } else { - self.db.set_alias(alias, room_id, user_id) + )); } + + // Comes first as we don't want a stuck alias + self.db + .alias_userid + .insert(alias.alias().as_bytes(), user_id.as_bytes()); + + self.db + .alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes()); + + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + self.db.aliasid_alias.insert(&aliasid, alias.as_bytes()); + + Ok(()) } #[tracing::instrument(skip(self))] pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> { - if self.user_can_remove_alias(alias, user_id).await? { - self.db.remove_alias(alias) - } else { - Err(Error::BadRequest( - ErrorKind::forbidden(), - "User is not permitted to remove this alias.", - )) + if !self.user_can_remove_alias(alias, user_id).await? 
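// Illustrative sketch of the construction convention this patch converges on:
// the data.rs indirection is replaced by a Data struct of raw column handles
// cloned out of the database by name (assuming, per the code above, that the
// Database indexes by column name and hands back cheap Arc<Map> clones).
fn build_data(db: &database::Database) -> Data {
	Data {
		alias_userid: db["alias_userid"].clone(),
		alias_roomid: db["alias_roomid"].clone(),
		aliasid_alias: db["aliasid_alias"].clone(),
	}
}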
{ + return Err!(Request(Forbidden("User is not permitted to remove this alias."))); } + + let alias = alias.alias(); + let Ok(room_id) = self.db.alias_roomid.qry(&alias).await else { + return Err!(Request(NotFound("Alias does not exist or is invalid."))); + }; + + let prefix = (&room_id, Interfix); + self.db + .aliasid_alias + .keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key)) + .await; + + self.db.alias_roomid.remove(alias.as_bytes()); + self.db.alias_userid.remove(alias.as_bytes()); + + Ok(()) } pub async fn resolve(&self, room: &RoomOrAliasId) -> Result { @@ -97,9 +138,9 @@ impl Service { return self.remote_resolve(room_alias, servers).await; } - let room_id: Option = match self.resolve_local_alias(room_alias)? { - Some(r) => Some(r), - None => self.resolve_appservice_alias(room_alias).await?, + let room_id: Option = match self.resolve_local_alias(room_alias).await { + Ok(r) => Some(r), + Err(_) => self.resolve_appservice_alias(room_alias).await?, }; room_id.map_or_else( @@ -109,46 +150,54 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.db.resolve_local_alias(alias) + pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_roomid.qry(alias.alias()).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - self.db.local_aliases_for_room(room_id) + pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .aliasid_alias + .stream_prefix(&prefix) + .ignore_err() + .map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias) } #[tracing::instrument(skip(self), level = "debug")] - pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - self.db.all_local_aliases() + pub fn all_local_aliases<'a>(&'a self) -> impl Stream + Send + 'a { + self.db + .alias_roomid + .stream() + .ignore_err() + .map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart)) } async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result { - let Some(room_id) = self.resolve_local_alias(alias)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); - }; + let room_id = self + .resolve_local_alias(alias) + .await + .map_err(|_| err!(Request(NotFound("Alias not found."))))?; let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self - .db - .who_created_alias(alias)? - .is_some_and(|user| user == user_id) + .who_created_alias(alias).await + .is_ok_and(|user| user == user_id) // Server admins can remove any local alias - || self.services.admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { Ok(true) // Checking whether the user is able to change canonical aliases of the // room - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? 
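// Illustrative note on the (key, Interfix) idiom used by remove_alias above,
// assuming the patch's query key serializer: tuple keys serialize component by
// component, and Interfix truncates serialization after the separator, so the
// resulting key matches every entry sharing that prefix.
async fn count_alias_entries(db: &Data, room_id: &RoomId) -> usize {
	let prefix = (room_id, Interfix); // matches room_id ++ 0xFF ++ <anything>
	db.aliasid_alias.keys_prefix(&prefix).ignore_err().count().await
}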
+ } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) @@ -157,10 +206,11 @@ impl Service { }) // If there is no power levels event, only the room creator can change // canonical aliases - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -168,6 +218,10 @@ impl Service { } } + async fn who_created_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_userid.qry(alias.alias()).await.deserialized() + } + async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result> { use ruma::api::appservice::query::query_room_alias; @@ -185,10 +239,11 @@ impl Service { .await, Ok(Some(_opt_result)) ) { - return Ok(Some( - self.resolve_local_alias(room_alias)? - .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, - )); + return self + .resolve_local_alias(room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room does not exist.")))) + .map(Some); } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 6e7c7835..3d00374e 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -24,7 +24,7 @@ impl Data { } } - pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { return Ok(Some(Arc::clone(result))); @@ -33,17 +33,14 @@ impl Data { // We only save auth chains for single events in the db if key.len() == 1 { // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(utils::u64_from_u8) - .collect::>() - }); + let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| { + chain + .chunks_exact(size_of::()) + .map(utils::u64_from_u8) + .collect::>() + }); - if let Some(chain) = chain { + if let Ok(chain) = chain { // Cache in RAM self.auth_chain_cache .lock() @@ -66,7 +63,7 @@ impl Data { .iter() .flat_map(|s| s.to_be_bytes().to_vec()) .collect::>(), - )?; + ); } // Cache in RAM diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index d0bc425f..7bc239d7 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,7 +5,8 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, trace, validated, warn, Err, Result}; +use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; +use futures::{FutureExt, Stream, StreamExt}; use ruma::{EventId, RoomId}; use self::data::Data; @@ -38,7 +39,7 @@ impl crate::Service for Service { impl Service { pub async fn event_ids_iter<'a>( &'a self, room_id: &RoomId, starting_events_: Vec>, - ) -> Result> + 'a> { + ) -> Result> + Send + 'a> { let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); for starting_event in &starting_events_ { starting_events.push(starting_event); @@ -48,7 +49,13 @@ impl Service { .get_auth_chain(room_id, &starting_events) .await? 
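// Illustrative sketch of the value format behind the DB-cache branch above:
// shorteventid_authchain stores a flat array of big-endian u64 short event
// ids, decoded with chunks_exact (assumes std::mem::size_of in scope).
fn decode_auth_chain(bytes: &[u8]) -> Arc<[u64]> {
	bytes
		.chunks_exact(size_of::<u64>()) // each short id is exactly 8 bytes
		.map(utils::u64_from_u8)
		.collect()
}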
.into_iter() - .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) + .stream() + .filter_map(|sid| { + self.services + .short + .get_eventid_from_short(sid) + .map(Result::ok) + })) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -61,7 +68,8 @@ impl Service { for (i, &short) in self .services .short - .multi_get_or_create_shorteventid(starting_events)? + .multi_get_or_create_shorteventid(starting_events) + .await .iter() .enumerate() { @@ -85,7 +93,7 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? { + if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -96,12 +104,12 @@ impl Service { let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? { + if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? { trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { - let auth_chain = self.get_auth_chain_inner(room_id, event_id)?; + let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; self.cache_auth_chain(vec![sevent_id], &auth_chain)?; chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); @@ -143,15 +151,16 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match self.services.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { + match self.services.timeline.get_pdu(&event_id).await { + Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"), + Ok(pdu) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( "auth event {event_id:?} for incorrect room {} which is not {}", @@ -160,7 +169,11 @@ impl Service { ))); } for auth_event in &pdu.auth_events { - let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; + let sauthevent = self + .services + .short + .get_or_create_shorteventid(auth_event) + .await; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); @@ -168,20 +181,14 @@ impl Service { } } }, - Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); - }, - Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); - }, } } Ok(found) } - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { - self.db.get_cached_eventid_authchain(key) + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs deleted file mode 100644 index 713ee057..00000000 --- a/src/service/rooms/directory/data.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; 
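// Illustrative sketch of the IterStream conversion used in event_ids_iter
// above, assuming the patch's utils::IterStream extension plus futures'
// FutureExt/StreamExt: an ordinary Vec becomes a Stream so each short-id
// lookup can be awaited per item and failed lookups silently dropped.
async fn shorts_to_event_ids(short: &short::Service, shorts: Vec<u64>) -> Vec<Arc<EventId>> {
	shorts
		.into_iter()
		.stream() // Iterator -> Stream
		.filter_map(|sid| short.get_eventid_from_short(sid).map(Result::ok))
		.collect()
		.await
}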
-use ruma::{OwnedRoomId, RoomId};
-
-pub(super) struct Data {
-	publicroomids: Arc<Map>,
-}
-
-impl Data {
-	pub(super) fn new(db: &Arc<Database>) -> Self {
-		Self {
-			publicroomids: db["publicroomids"].clone(),
-		}
-	}
-
-	pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> {
-		self.publicroomids.insert(room_id.as_bytes(), &[])
-	}
-
-	pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
-		self.publicroomids.remove(room_id.as_bytes())
-	}
-
-	pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
-		Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
-	}
-
-	pub(super) fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
-		Box::new(self.publicroomids.iter().map(|(bytes, _)| {
-			RoomId::parse(
-				utils::string_from_bytes(&bytes)
-					.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
-			)
-			.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
-		}))
-	}
-}
diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs
index 706e6c2e..3585205d 100644
--- a/src/service/rooms/directory/mod.rs
+++ b/src/service/rooms/directory/mod.rs
@@ -1,36 +1,44 @@
-mod data;
-
 use std::sync::Arc;

-use conduit::Result;
-use ruma::{OwnedRoomId, RoomId};
-
-use self::data::Data;
+use conduit::{implement, utils::stream::TryIgnore, Result};
+use database::{Ignore, Map};
+use futures::{Stream, StreamExt};
+use ruma::RoomId;

 pub struct Service {
 	db: Data,
 }

+struct Data {
+	publicroomids: Arc<Map>,
+}
+
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			db: Data::new(args.db),
+			db: Data {
+				publicroomids: args.db["publicroomids"].clone(),
+			},
 		}))
 	}

 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }

-impl Service {
-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
+#[implement(Service)]
+pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); }

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
+#[implement(Service)]
+pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); }

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
+#[implement(Service)]
+pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.qry(room_id).await.is_ok() }

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
+#[implement(Service)]
+pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
+	self.db
+		.publicroomids
+		.keys()
+		.ignore_err()
+		.map(|(room_id, _): (&RoomId, Ignore)| room_id)
 }
diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs
index bee986de..07d6e4db 100644
--- a/src/service/rooms/event_handler/mod.rs
+++ b/src/service/rooms/event_handler/mod.rs
@@ -3,17 +3,18 @@ mod parse_incoming_pdu;

 use std::{
 	collections::{hash_map, BTreeMap, HashMap, HashSet},
 	fmt::Write,
-	pin::Pin,
 	sync::{Arc, RwLock as StdRwLock},
 	time::Instant,
 };

 use conduit::{
-	debug, debug_error, debug_info, err, error, info, pdu, trace,
-	utils::{math::continue_exponential_backoff_secs, MutexMap},
-	warn, Error, PduEvent, Result,
+	debug, debug_error, debug_info, debug_warn, err,
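// Illustrative sketch of consuming the reworked directory service, assuming
// futures' StreamExt: is_public_room() now answers by key presence, and
// public_rooms() streams borrowed RoomIds that are copied out here so they
// can outlive the stream.
async fn snapshot_public_rooms(directory: &Service) -> Vec<OwnedRoomId> {
	directory.public_rooms().map(ToOwned::to_owned).collect().await
}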
info, pdu, + result::LogErr, + trace, + utils::{math::continue_exponential_backoff_secs, IterStream, MutexMap}, + warn, Err, Error, PduEvent, Result, }; -use futures_util::Future; +use futures::{future, future::ready, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::{ client::error::ErrorKind, @@ -27,7 +28,7 @@ use ruma::{ }, int, serde::Base64, - state_res::{self, RoomVersion, StateMap}, + state_res::{self, EventTypeExt, RoomVersion, StateMap}, uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, ServerName, }; @@ -60,14 +61,6 @@ struct Services { type RoomMutexMap = MutexMap; type HandleTimeMap = HashMap; -// We use some AsyncRecursiveType hacks here so we can call async funtion -// recursively. -type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -type AsyncRecursiveCanonicalJsonVec<'a> = - AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; -type AsyncRecursiveCanonicalJsonResult<'a> = - AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -142,17 +135,17 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> Result>> { // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { + if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { return Ok(Some(pdu_id.to_vec())); } // 1.1 Check the server is in the room - if !self.services.metadata.exists(room_id)? { + if !self.services.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } // 1.2 Check if the room is disabled - if self.services.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "Federation of this room is currently disabled on this server.", @@ -160,7 +153,7 @@ impl Service { } // 1.3.1 Check room ACL on origin field/server - self.acl_check(origin, room_id)?; + self.acl_check(origin, room_id).await?; // 1.3.2 Check room ACL on sender's server name let sender: OwnedUserId = serde_json::from_value( @@ -172,26 +165,23 @@ impl Service { ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid"))?; - self.acl_check(sender.server_name(), room_id)?; + self.acl_check(sender.server_name(), room_id).await?; // Fetch create event let create_event = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await?; // Procure the room version let room_version_id = Self::get_room_version_id(&create_event)?; - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? 
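// Illustrative summary of the admission gate above (helper name and error
// shapes are ours, not the patch's): after duplicate suppression, the handler
// checks room existence, the per-room federation kill-switch, then ACLs for
// both the origin server and the sender's server.
async fn admit(svc: &Service, origin: &ServerName, room_id: &RoomId, sender: &UserId) -> Result<()> {
	if !svc.services.metadata.exists(room_id).await {
		return Err!(Request(NotFound("Room is unknown to this server")));
	}
	if svc.services.metadata.is_disabled(room_id).await {
		return Err!(Request(Forbidden("Federation of this room is currently disabled.")));
	}
	svc.acl_check(origin, room_id).await?; // 1.3.1: origin field
	svc.acl_check(sender.server_name(), room_id).await // 1.3.2: sender's server
}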
- .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let (incoming_pdu, val) = self .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) + .boxed() .await?; Self::check_room_id(room_id, &incoming_pdu)?; @@ -235,7 +225,7 @@ impl Service { { Ok(()) => continue, Err(e) => { - warn!("Prev event {} failed: {}", prev_id, e); + warn!("Prev event {prev_id} failed: {e}"); match self .services .globals @@ -287,7 +277,7 @@ impl Service { create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, ) -> Result<()> { // Check for disabled again because it might have changed - if self.services.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id).await { debug!( "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ event ID {event_id}" @@ -349,149 +339,153 @@ impl Service { } #[allow(clippy::too_many_arguments)] - fn handle_outlier_pdu<'a>( - &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + async fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, mut value: BTreeMap, auth_events_known: bool, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonResult<'a> { - Box::pin(async move { - // 1. Remove unsigned field - value.remove("unsigned"); + ) -> Result<(Arc, BTreeMap)> { + // 1. Remove unsigned field + value.remove("unsigned"); - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - // 2. Check signatures, otherwise drop - // 3. check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; + // 2. Check signatures, otherwise drop + // 3. 
check content hash, redact if doesn't match + let room_version_id = Self::get_room_version_id(create_event)?; - let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e,); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); - }, - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - debug_info!("Calculated hash does not match (redaction): {event_id}"); - let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")); - }; + let guard = pub_key_map.read().await; + let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { + Err(e) => { + // Drop + warn!("Dropping bad event {event_id}: {e}"); + return Err!(Request(InvalidParam("Signature verification failed"))); + }, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + debug_info!("Calculated hash does not match (redaction): {event_id}"); + let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { + return Err!(Request(InvalidParam("Redaction failed"))); + }; - // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.get_pdu_json(event_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event was redacted and we already knew about it", - )); - } + // Skip the PDU if it is redacted and we already have it as an outlier event + if self.services.timeline.get_pdu_json(event_id).await.is_ok() { + return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); + } - obj - }, - Ok(ruma::signatures::Verified::All) => value, - }; + obj + }, + Ok(ruma::signatures::Verified::All) => value, + }; - drop(guard); + drop(guard); - // Now that we have checked the signature and hashes we can add the eventID and - // convert to our PduEvent type - val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - Self::check_room_id(room_id, &incoming_pdu)?; + Self::check_room_id(room_id, &incoming_pdu)?; - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. - // These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of - // the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!("Fetching auth events"); + if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. 
Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!("Fetching auth events"); + Box::pin( self.fetch_and_handle_outliers( origin, &incoming_pdu .auth_events .iter() .map(|x| Arc::from(&**x)) - .collect::>(), + .collect::>>(), create_event, room_id, &room_version_id, pub_key_map, - ) - .await; - } - - // 6. Reject "due to auth events" if the event doesn't pass auth based on the - // auth events - debug!("Checking based on auth events"); - // Build map of auth events - let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); - for id in &incoming_pdu.auth_events { - let Some(auth_event) = self.services.timeline.get_pdu(id)? else { - warn!("Could not find auth event {}", id); - continue; - }; - - Self::check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - }, - } - } - - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, String::new())) - .map(AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event refers to wrong create event.", - )); - } - - if !state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), - &incoming_pdu, - None::, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), + ), ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"))? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); + .await; + } + + // 6. Reject "due to auth events" if the event doesn't pass auth based on the + // auth events + debug!("Checking based on auth events"); + // Build map of auth events + let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); + for id in &incoming_pdu.auth_events { + let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { + warn!("Could not find auth event {id}"); + continue; + }; + + Self::check_room_id(room_id, &auth_event)?; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + }, } + } - trace!("Validation successful."); + // The original create event must be in the auth events + if !matches!( + auth_events + .get(&(StateEventType::RoomCreate, String::new())) + .map(AsRef::as_ref), + Some(_) | None + ) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } - // 7. Persist the event as an outlier. 
- self.services - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; + let state_fetch = |ty: &'static StateEventType, sk: &str| { + let key = ty.with_state_key(sk); + ready(auth_events.get(&key)) + }; - trace!("Added pdu as outlier."); + let auth_check = state_res::event_auth::auth_check( + &Self::to_room_version(&room_version_id), + &incoming_pdu, + None, // TODO: third party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - Ok((Arc::new(incoming_pdu), val)) - }) + if !auth_check { + return Err!(Request(Forbidden("Auth check failed"))); + } + + trace!("Validation successful."); + + // 7. Persist the event as an outlier. + self.services + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val); + + trace!("Added pdu as outlier."); + + Ok((Arc::new(incoming_pdu), val)) } pub async fn upgrade_outlier_to_timeline_pdu( @@ -499,16 +493,22 @@ impl Service { origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(pduid) = self + .services + .timeline + .get_pdu_id(&incoming_pdu.event_id) + .await + { return Ok(Some(pduid.to_vec())); } if self .services .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id)? + .is_event_soft_failed(&incoming_pdu.event_id) + .await { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + return Err!(Request(InvalidParam("Event has been soft failed"))); } debug!("Upgrading to timeline pdu"); @@ -545,57 +545,69 @@ impl Service { debug!("Performing auth check"); // 11. Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( + let state_fetch_state = &state_at_incoming_event; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = state_fetch_state.get(&shortstatekey)?; + self.services.timeline.get_pdu(event_id).await.ok() + }; + + let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, - None::, // TODO: third party invite - |k, s| { - self.services - .short - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) - }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - if !check_result { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Event has failed auth check with state at the event.", - )); + if !auth_check { + return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } debug!("Gathering auth events"); - let auth_events = self.services.state.get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - )?; + let auth_events = self + .services + .state + .get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + ) + .await?; + + let state_fetch = |k: &'static StateEventType, s: &str| { + let key = k.with_state_key(s); + 
ready(auth_events.get(&key).cloned()) + }; + + let auth_check = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None, // third-party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res debug!("Performing soft-fail check"); let soft_fail = { use RoomVersionId::*; - !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? + !auth_check || incoming_pdu.kind == TimelineEventType::RoomRedaction && match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &incoming_pdu.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -605,12 +617,11 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -627,28 +638,52 @@ impl Service { // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) trace!("Calculating extremities"); - let mut extremities = self.services.state.get_forward_extremities(room_id)?; - trace!("Calculated {} extremities", extremities.len()); + let mut extremities: HashSet<_> = self + .services + .state + .get_forward_extremities(room_id) + .map(ToOwned::to_owned) + .collect() + .await; // Remove any forward extremities that are referenced by this incoming event's // prev_events + trace!( + "Calculated {} extremities; checking against {} prev_events", + extremities.len(), + incoming_pdu.prev_events.len() + ); for prev_event in &incoming_pdu.prev_events { - extremities.remove(prev_event); + extremities.remove(&(**prev_event)); } // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); + let mut retained = HashSet::new(); + for id in &extremities { + if !self + .services + .pdu_metadata + .is_event_referenced(room_id, id) + .await + { + retained.insert(id.clone()); + } + } + + extremities.retain(|id| retained.contains(id)); debug!("Retained {} extremities. Compressing state", extremities.len()); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - self.services - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); + + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + self.services + .state_compressor + .compress_state_event(*shortstatekey, id) + .await, + ); + } + + let state_ids_compressed = Arc::new(state_ids_compressed); if incoming_pdu.state_key.is_some() { debug!("Event is a state-event. 
Deriving new room state"); @@ -659,9 +694,11 @@ impl Service { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await; - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + let event_id = &incoming_pdu.event_id; + state_after.insert(shortstatekey, event_id.clone()); } let new_room_state = self @@ -673,7 +710,8 @@ impl Service { let (sstatehash, new, removed) = self .services .state_compressor - .save_state(room_id, new_room_state)?; + .save_state(room_id, new_room_state) + .await?; self.services .state @@ -698,16 +736,16 @@ impl Service { .await?; // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); + warn!("Event was soft failed: {incoming_pdu:?}"); self.services .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id)?; + .mark_event_soft_failed(&incoming_pdu.event_id); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } trace!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); + extremities.insert(incoming_pdu.event_id.clone().into()); // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately @@ -718,7 +756,7 @@ impl Service { .append_incoming_pdu( &incoming_pdu, val, - extremities.iter().map(|e| (**e).to_owned()).collect(), + extremities.into_iter().collect(), state_ids_compressed, soft_fail, &state_lock, @@ -742,8 +780,9 @@ impl Service { let current_sstatehash = self .services .state - .get_room_shortstatehash(room_id)? - .expect("every room has state"); + .get_room_shortstatehash(room_id) + .await + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; let current_state_ids = self .services @@ -752,7 +791,6 @@ impl Service { .await?; let fork_states = [current_state_ids, incoming_state]; - let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { auth_chain_sets.push( @@ -760,62 +798,59 @@ impl Service { .auth_chain .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .await? 
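// Illustrative generalization of the retain rewrite above: HashSet::retain
// only takes a synchronous predicate, so the async is_event_referenced check
// is materialized into a `retained` set first and applied afterwards.
async fn async_retain<T, F, Fut>(set: &mut HashSet<T>, keep: F)
where
	T: Clone + Eq + std::hash::Hash,
	F: Fn(&T) -> Fut,
	Fut: std::future::Future<Output = bool>,
{
	let mut retained = HashSet::new();
	for item in set.iter() {
		if keep(item).await {
			retained.insert(item.clone());
		}
	}
	set.retain(|item| retained.contains(item));
}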
- .collect(), + .collect::>>() + .await, ); } debug!("Loading fork states"); - let fork_states: Vec<_> = fork_states + let fork_states: Vec>> = fork_states .into_iter() - .map(|map| { - map.into_iter() + .stream() + .then(|fork_state| { + fork_state + .into_iter() + .stream() .filter_map(|(k, id)| { self.services .short .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .ok() + .map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id))) }) - .collect::>() + .collect() }) - .collect(); - - let lock = self.services.globals.stateres_mutex.lock(); + .collect() + .boxed() + .await; debug!("Resolving state"); - let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + let lock = self.services.globals.stateres_mutex.lock(); - let state = match state_resolve { - Ok(new_state) => new_state, - Err(e) => { - error!("State resolution failed: {}", e); - return Err(Error::bad_database( - "State resolution failed, either an event could not be found or deserialization", - )); - }, - }; + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?; drop(lock); debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - self.services - .state_compressor - .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await; + + let compressed = self + .services + .state_compressor + .compress_state_event(shortstatekey, &event_id) + .await; + + new_room_state.insert(compressed); + } Ok(Arc::new(new_room_state)) } @@ -827,46 +862,47 @@ impl Service { &self, incoming_pdu: &Arc, ) -> Result>>> { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = self + let Ok(prev_event_sstatehash) = self .services .state_accessor - .pdu_shortstatehash(prev_event)?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some( - self.services - .state_accessor - .state_full_ids(shortstatehash) - .await, - ) - } else { - None + .pdu_shortstatehash(prev_event) + .await + else { + return Ok(None); }; - if let Some(Ok(mut state)) = state { - debug!("Using cached state"); - let prev_pdu = self + let Ok(mut state) = self + .services + .state_accessor + .state_full_ids(prev_event_sstatehash) + .await + .log_err() + else { + return Ok(None); + }; + + debug!("Using cached state"); + let prev_pdu = self + .services + .timeline + .get_pdu(prev_event) + .await + .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; + + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = self .services - .timeline - .get_pdu(prev_event) - .ok() - .flatten() - .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; + 
.short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await; - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; - - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } - - return Ok(Some(state)); + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu } - Ok(None) + debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result"); + + Ok(Some(state)) } #[tracing::instrument(skip_all, name = "state")] @@ -878,15 +914,16 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { + let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { okay = false; break; }; - let Ok(Some(sstatehash)) = self + let Ok(sstatehash) = self .services .state_accessor .pdu_shortstatehash(prev_eventid) + .await else { okay = false; break; @@ -901,20 +938,25 @@ impl Service { let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = self + let Ok(mut leaf_state) = self .services .state_accessor .state_full_ids(sstatehash) - .await?; + .await + else { + continue; + }; if let Some(state_key) = &prev_event.state_key { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); // Now it's the state after the pdu } @@ -922,13 +964,18 @@ impl Service { let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = self + .services + .short + .get_statekey_from_short(k) + .await + .log_err() + { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); } + starting_events.push(id); } @@ -937,43 +984,40 @@ impl Service { .auth_chain .event_ids_iter(room_id, starting_events) .await? 
- .collect(), + .collect() + .await, ); fork_states.push(state); } let lock = self.services.globals.stateres_mutex.lock(); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); + drop(lock); - Ok(match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!( - "State resolution on prev events failed, either an event could not be found or deserialization: {}", - e - ); - None - }, - }) + let Ok(new_state) = result else { + return Ok(None); + }; + + new_state + .iter() + .stream() + .then(|((event_type, state_key), event_id)| { + self.services + .short + .get_or_create_shortstatekey(event_type, state_key) + .map(move |shortstatekey| (shortstatekey, event_id.clone())) + }) + .collect() + .map(Some) + .map(Ok) + .await } /// Call /state_ids to find out what the state at this pdu is. We trust the @@ -985,7 +1029,7 @@ impl Service { pub_key_map: &RwLock>>, event_id: &EventId, ) -> Result>>> { debug!("Fetching state ids"); - match self + let res = self .services .sending .send_federation_request( @@ -996,61 +1040,57 @@ impl Service { }, ) .await - { - Ok(res) => { - debug!("Fetching state events"); - let collect = res - .pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(); + .inspect_err(|e| warn!("Fetching state for event failed: {e}"))?; - let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) - .await; + debug!("Fetching state events"); + let collect = res + .pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(); - let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; + let state_vec = self + .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) + .boxed() + .await; - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; + let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::bad_database( - "State event's type and state_key combination exists multiple times.", - )) - }, - } - } + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await; - // The original create event must still be in the state - let create_shortstatekey = self - .services - .short - 
.get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); - - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { - return Err(Error::bad_database("Incoming event refers to wrong create event.")); - } - - Ok(Some(state)) - }, - Err(e) => { - warn!("Fetching state for event failed: {}", e); - Err(e) - }, + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key combination exists multiple times.", + )) + }, + } } + + // The original create event must still be in the state + let create_shortstatekey = self + .services + .short + .get_shortstatekey(&StateEventType::RoomCreate, "") + .await?; + + if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { + return Err!(Database("Incoming event refers to wrong create event.")); + } + + Ok(Some(state)) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) @@ -1062,191 +1102,196 @@ impl Service { /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? - pub fn fetch_and_handle_outliers<'a>( - &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, + pub async fn fetch_and_handle_outliers<'a>( + &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonVec<'a> { - Box::pin(async move { - let back_off = |id| async { - match self + ) -> Vec<(Arc, Option>)> { + let back_off = |id| match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + }; + + let mut events_with_auth_events = Vec::with_capacity(events.len()); + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); + let mut events_all = HashSet::with_capacity(todo_auth_events.len()); + let mut i: u64 = 0; + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = self .services .globals .bad_event_ratelimiter - .write() + .read() .expect("locked") - .entry(id) + .get(&*next_id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); + continue; + } } - }; - let mut events_with_auth_events = Vec::with_capacity(events.len()); - for id in events { - // a. 
Look in the main timeline (pduid_pdu tree)
-			// b. Look at outlier pdu tree
-			// (get_pdu_json checks both)
-			if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) {
-				trace!("Found {} in db", id);
-				events_with_auth_events.push((id, Some(local_pdu), vec![]));
+			if events_all.contains(&next_id) {
 				continue;
 			}
-			// c. Ask origin server over federation
-			// We also handle its auth chain here so we don't get a stack overflow in
-			// handle_outlier_pdu.
-			let mut todo_auth_events = vec![Arc::clone(id)];
-			let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len());
-			let mut events_all = HashSet::with_capacity(todo_auth_events.len());
-			let mut i: u64 = 0;
-			while let Some(next_id) = todo_auth_events.pop() {
-				if let Some((time, tries)) = self
-					.services
-					.globals
-					.bad_event_ratelimiter
-					.read()
-					.expect("locked")
-					.get(&*next_id)
-				{
-					// Exponential backoff
-					const MIN_DURATION: u64 = 5 * 60;
-					const MAX_DURATION: u64 = 60 * 60 * 24;
-					if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
-						info!("Backing off from {next_id}");
+			i = i.saturating_add(1);
+			if i % 100 == 0 {
+				tokio::task::yield_now().await;
+			}
+
+			if self.services.timeline.get_pdu(&next_id).await.is_ok() {
+				trace!("Found {next_id} in db");
+				continue;
+			}
+
+			debug!("Fetching {next_id} over federation.");
+			match self
+				.services
+				.sending
+				.send_federation_request(
+					origin,
+					get_event::v1::Request {
+						event_id: (*next_id).to_owned(),
+					},
+				)
+				.await
+			{
+				Ok(res) => {
+					debug!("Got {next_id} over federation");
+					let Ok((calculated_event_id, value)) =
+						pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
+					else {
+						back_off((*next_id).to_owned());
 						continue;
+					};
+
+					if calculated_event_id != *next_id {
+						warn!(
+							"Server didn't return event id we requested: requested: {next_id}, we got \
+							 {calculated_event_id}. Event: {:?}",
+							&res.pdu
+						);
 					}
-				}
-				if events_all.contains(&next_id) {
-					continue;
-				}
-
-				i = i.saturating_add(1);
-				if i % 100 == 0 {
-					tokio::task::yield_now().await;
-				}
-
-				if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) {
-					trace!("Found {} in db", next_id);
-					continue;
-				}
-
-				debug!("Fetching {} over federation.", next_id);
-				match self
-					.services
-					.sending
-					.send_federation_request(
-						origin,
-						get_event::v1::Request {
-							event_id: (*next_id).to_owned(),
-						},
-					)
-					.await
-				{
-					Ok(res) => {
-						debug!("Got {} over federation", next_id);
-						let Ok((calculated_event_id, value)) =
-							pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
-						else {
-							back_off((*next_id).to_owned()).await;
-							continue;
-						};
-
-						if calculated_event_id != *next_id {
-							warn!(
-								"Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
-								next_id, calculated_event_id, &res.pdu
-							);
-						}
-
-						if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) {
-							for auth_event in auth_events {
-								if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) {
-									let a: Arc<EventId> = auth_event;
-									todo_auth_events.push(a);
-								} else {
-									warn!("Auth event id is not valid");
-								}
+					if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) {
+						for auth_event in auth_events {
+							if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) {
+								let a: Arc<EventId> = auth_event;
+								todo_auth_events.push(a);
+							} else {
+								warn!("Auth event id is not valid");
 							}
 						}
-					} else {
-						warn!("Auth event list invalid");
+					} else {
+						warn!("Auth event list invalid");
 					}
-
-						events_in_reverse_order.push((next_id.clone(), value));
-						events_all.insert(next_id);
-					},
-					Err(e) => {
-						debug_error!("Failed to fetch event {next_id}: {e}");
-						back_off((*next_id).to_owned()).await;
-					},
-				}
-			}
-			events_with_auth_events.push((id, None, events_in_reverse_order));
-		}
-
-		// We go through all the signatures we see on the PDUs and their unresolved
-		// dependencies and fetch the corresponding signing keys
-		self.services
-			.server_keys
-			.fetch_required_signing_keys(
-				events_with_auth_events
-					.iter()
-					.flat_map(|(_id, _local_pdu, events)| events)
-					.map(|(_event_id, event)| event),
-				pub_key_map,
-			)
-			.await
-			.unwrap_or_else(|e| {
-				warn!("Could not fetch all signatures for PDUs from {}: {:?}", origin, e);
-			});
-
-		let mut pdus = Vec::with_capacity(events_with_auth_events.len());
-		for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
-			// a. Look in the main timeline (pduid_pdu tree)
-			// b. Look at outlier pdu tree
-			// (get_pdu_json checks both)
-			if let Some(local_pdu) = local_pdu {
-				trace!("Found {} in db", id);
-				pdus.push((local_pdu, None));
-			}
-			for (next_id, value) in events_in_reverse_order.iter().rev() {
-				if let Some((time, tries)) = self
-					.services
-					.globals
-					.bad_event_ratelimiter
-					.read()
-					.expect("locked")
-					.get(&**next_id)
-				{
-					// Exponential backoff
-					const MIN_DURATION: u64 = 5 * 60;
-					const MAX_DURATION: u64 = 60 * 60 * 24;
-					if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
-						debug!("Backing off from {next_id}");
-						continue;
-					}
-				}
-				match self
-					.handle_outlier_pdu(origin, create_event, next_id, room_id, value.clone(), true, pub_key_map)
-					.await
-				{
-					Ok((pdu, json)) => {
-						if next_id == id {
-							pdus.push((pdu, Some(json)));
-						}
-					},
-					Err(e) => {
-						warn!("Authentication of event {} failed: {:?}", next_id, e);
-						back_off((**next_id).to_owned()).await;
-					},
-				}
+					events_in_reverse_order.push((next_id.clone(), value));
+					events_all.insert(next_id);
+				},
+				Err(e) => {
+					debug_error!("Failed to fetch event {next_id}: {e}");
+					back_off((*next_id).to_owned());
+				},
 			}
 		}
-		pdus
-	})
+		events_with_auth_events.push((id, None, events_in_reverse_order));
+	}
+
+	// We go through all the signatures we see on the PDUs and their unresolved
+	// dependencies and fetch the corresponding signing keys
+	self.services
+		.server_keys
+		.fetch_required_signing_keys(
+			events_with_auth_events
+				.iter()
+				.flat_map(|(_id, _local_pdu, events)| events)
+				.map(|(_event_id, event)| event),
+			pub_key_map,
+		)
+		.await
+		.unwrap_or_else(|e| {
+			warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}");
+		});
+
+	let mut pdus = Vec::with_capacity(events_with_auth_events.len());
+	for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
+		// a. Look in the main timeline (pduid_pdu tree)
+		// b. Look at outlier pdu tree
+		// (get_pdu_json checks both)
+		if let Some(local_pdu) = local_pdu {
+			trace!("Found {id} in db");
+			pdus.push((local_pdu.clone(), None));
+		}
+
+		for (next_id, value) in events_in_reverse_order.into_iter().rev() {
+			if let Some((time, tries)) = self
+				.services
+				.globals
+				.bad_event_ratelimiter
+				.read()
+				.expect("locked")
+				.get(&*next_id)
+			{
+				// Exponential backoff
+				const MIN_DURATION: u64 = 5 * 60;
+				const MAX_DURATION: u64 = 60 * 60 * 24;
+				if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
+					debug!("Backing off from {next_id}");
+					continue;
+				}
+			}
+
+			match Box::pin(self.handle_outlier_pdu(
+				origin,
+				create_event,
+				&next_id,
+				room_id,
+				value.clone(),
+				true,
+				pub_key_map,
+			))
+			.await
+			{
+				Ok((pdu, json)) => {
+					if next_id == *id {
+						pdus.push((pdu, Some(json)));
+					}
+				},
+				Err(e) => {
+					warn!("Authentication of event {next_id} failed: {e:?}");
+					back_off(next_id.into());
+				},
+			}
+		}
+	}
+	pdus
 }
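Note: both loops above gate refetch attempts per bad event through `continue_exponential_backoff_secs`. A minimal sketch of the predicate's shape, using only std; the real helper lives in conduit's utils and may clamp or jitter differently:

use std::time::Duration;

/// Returns true while the caller should keep backing off, i.e. while the
/// time elapsed since the last failure is shorter than min(2^tries * min, max).
fn continue_backoff_secs(min_secs: u64, max_secs: u64, elapsed: Duration, tries: u32) -> bool {
	// Doubling per retry, saturating instead of overflowing, capped at max.
	let delay = min_secs.saturating_mul(2u64.saturating_pow(tries)).min(max_secs);
	elapsed < Duration::from_secs(delay)
}

With MIN_DURATION = 5 * 60 and MAX_DURATION = 60 * 60 * 24 as in the hunk above, retries stretch from five minutes out to a day.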

 #[allow(clippy::type_complexity)]
@@ -1262,16 +1307,12 @@ impl Service {
 		let mut eventid_info = HashMap::new();
 		let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set;

-		let first_pdu_in_room = self
-			.services
-			.timeline
-			.first_pdu_in_room(room_id)?
-			.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
+		let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?;

 		let mut amount = 0;

 		while let Some(prev_event_id) = todo_outlier_stack.pop() {
-			if let Some((pdu, json_opt)) = self
+			if let Some((pdu, mut json_opt)) = self
 				.fetch_and_handle_outliers(
 					origin,
 					&[prev_event_id.clone()],
@@ -1280,28 +1321,29 @@
 					room_version_id,
 					pub_key_map,
 				)
+				.boxed()
 				.await
 				.pop()
 			{
 				Self::check_room_id(room_id, &pdu)?;

-				if amount > self.services.globals.max_fetch_prev_events() {
-					// Max limit reached
-					debug!(
-						"Max prev event limit reached! Limit: {}",
-						self.services.globals.max_fetch_prev_events()
-					);
+				let limit = self.services.globals.max_fetch_prev_events();
+				if amount > limit {
+					debug_warn!("Max prev event limit reached! Limit: {limit}");
 					graph.insert(prev_event_id.clone(), HashSet::new());
 					continue;
 				}

-				if let Some(json) = json_opt.or_else(|| {
-					self.services
+				if json_opt.is_none() {
+					json_opt = self
+						.services
 						.outlier
 						.get_outlier_pdu_json(&prev_event_id)
-						.ok()
-						.flatten()
-				}) {
+						.await
+						.ok();
+				}
+
+				if let Some(json) = json_opt {
 					if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts {
 						amount = amount.saturating_add(1);
 						for prev_prev in &pdu.prev_events {
@@ -1327,56 +1369,42 @@
 			}
 		}

-		let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| {
+		let event_fetch = |event_id| {
+			let origin_server_ts = eventid_info
+				.get(&event_id)
+				.cloned()
+				.map_or_else(|| uint!(0), |info| info.0.origin_server_ts);
+
 			// This return value is the key used for sorting events,
 			// events are then sorted by power level, time,
 			// and lexically by event_id.
-			Ok((
-				int!(0),
-				MilliSecondsSinceUnixEpoch(
-					eventid_info
-						.get(event_id)
-						.map_or_else(|| uint!(0), |info| info.0.origin_server_ts),
-				),
-			))
-		})
-		.map_err(|e| {
-			error!("Error sorting prev events: {e}");
-			Error::bad_database("Error sorting prev events")
-		})?;
+			future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts)))
+		};
+
+		let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch)
+			.await
+			.map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?;

 		Ok((sorted, eventid_info))
 	}

 	/// Returns Ok if the acl allows the server
 	#[tracing::instrument(skip_all)]
-	pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
-		let acl_event = if let Some(acl) =
-			self.services
-				.state_accessor
-				.room_state_get(room_id, &StateEventType::RoomServerAcl, "")?
-		{
-			trace!("ACL event found: {acl:?}");
-			acl
-		} else {
-			trace!("No ACL event found");
+	pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
+		let Ok(acl_event_content) = self
+			.services
+			.state_accessor
+			.room_state_get_content(room_id, &StateEventType::RoomServerAcl, "")
+			.await
+			.map(|c: RoomServerAclEventContent| c)
+			.inspect(|acl| trace!("ACL content found: {acl:?}"))
+			.inspect_err(|e| trace!("No ACL content found: {e:?}"))
+		else {
 			return Ok(());
 		};

-		let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) {
-			Ok(content) => {
-				trace!("Found ACL event contents: {content:?}");
-				content
-			},
-			Err(e) => {
-				warn!("Invalid ACL event: {e}");
-				return Ok(());
-			},
-		};
-
 		if acl_event_content.allow.is_empty() {
 			warn!("Ignoring broken ACL event (allow key is empty)");
-			// Ignore broken acl events
 			return Ok(());
 		}

@@ -1384,16 +1412,18 @@ impl Service {
 			trace!("server {server_name} is allowed by ACL");
 			Ok(())
 		} else {
-			debug!("Server {} was denied by room ACL in {}", server_name, room_id);
-			Err(Error::BadRequest(ErrorKind::forbidden(), "Server was denied by room ACL"))
+			debug!("Server {server_name} was denied by room ACL in {room_id}");
+			Err!(Request(Forbidden("Server was denied by room ACL")))
 		}
 	}

 	fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> {
 		if pdu.room_id != room_id {
-			warn!("Found event from room {} in room {}", pdu.room_id, room_id);
-			return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has wrong room id"));
+			return Err!(Request(InvalidParam(
+				warn!(pdu_event_id = ?pdu.event_id, pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room")
+			)));
 		}
+
 		Ok(())
 	}

@@ -1408,4 +1438,10 @@ impl Service {
 	fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion {
 		RoomVersion::new(room_version_id).expect("room version is supported")
 	}
+
+	async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await }
+
+	async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> {
+		self.services.timeline.get_pdu(&event_id).await.ok()
+	}
 }
diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs
index a7ffe193..2de3e28e 100644
--- a/src/service/rooms/event_handler/parse_incoming_pdu.rs
+++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs
@@ -3,7 +3,9 @@ use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
 use serde_json::value::RawValue as RawJsonValue;

 impl super::Service {
-	pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
+	pub async fn parse_incoming_pdu(
+		&self, pdu: &RawJsonValue,
+	) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
 		let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
 			debug_warn!("Error parsing incoming event {pdu:#?}");
 			err!(BadServerResponse("Error parsing incoming event {e:?}"))
@@ -14,7 +16,7 @@ impl super::Service {
 			.and_then(|id| RoomId::parse(id.as_str()?).ok())
 			.ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?;

-		let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else {
+		let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else {
 			return Err!("Server is not in room {room_id}");
 		};
diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs
deleted file mode 100644
index 073d45f5..00000000
--- a/src/service/rooms/lazy_loading/data.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-use std::sync::Arc;
-
-use conduit::Result;
-use database::{Database, Map};
-use ruma::{DeviceId, RoomId, UserId};
-
-pub(super) struct Data {
-	lazyloadedids: Arc<Map>,
-}
-
-impl Data {
-	pub(super) fn new(db: &Arc<Database>) -> Self {
-		Self {
-			lazyloadedids: db["lazyloadedids"].clone(),
-		}
-	}
-
-	pub(super) fn lazy_load_was_sent_before(
-		&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
-	) -> Result<bool> {
-		let mut key = user_id.as_bytes().to_vec();
-		key.push(0xFF);
-		key.extend_from_slice(device_id.as_bytes());
-		key.push(0xFF);
-		key.extend_from_slice(room_id.as_bytes());
-		key.push(0xFF);
-		key.extend_from_slice(ll_user.as_bytes());
-		Ok(self.lazyloadedids.get(&key)?.is_some())
-	}
-
-	pub(super) fn lazy_load_confirm_delivery(
-		&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
-		confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
-	) -> Result<()> {
-		let mut prefix = user_id.as_bytes().to_vec();
-		prefix.push(0xFF);
-		prefix.extend_from_slice(device_id.as_bytes());
-		prefix.push(0xFF);
-		prefix.extend_from_slice(room_id.as_bytes());
-		prefix.push(0xFF);
-
-		for ll_id in confirmed_user_ids {
-			let mut key = prefix.clone();
-			key.extend_from_slice(ll_id.as_bytes());
-			self.lazyloadedids.insert(&key, &[])?;
-		}
-
-		Ok(())
-	}
-
-	pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
-		let mut prefix = user_id.as_bytes().to_vec();
-		prefix.push(0xFF);
-		prefix.extend_from_slice(device_id.as_bytes());
-		prefix.push(0xFF);
-		prefix.extend_from_slice(room_id.as_bytes());
-		prefix.push(0xFF);
-
-		for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
-			self.lazyloadedids.remove(&key)?;
-		}
-
-		Ok(())
-	}
-}
diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs
index 0a9d4cf2..e0816d3f 100644
--- a/src/service/rooms/lazy_loading/mod.rs
+++ b/src/service/rooms/lazy_loading/mod.rs
@@ -1,21 +1,26 @@
-mod data;
-
 use std::{
 	collections::{HashMap, HashSet},
 	fmt::Write,
 	sync::{Arc, Mutex},
 };

-use conduit::{PduCount, Result};
+use conduit::{
+	implement,
+	utils::{stream::TryIgnore, ReadyExt},
+	PduCount, Result,
+};
+use database::{Interfix, Map};
 use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};

-use self::data::Data;
-
 pub struct Service {
-	pub lazy_load_waiting: Mutex<LazyLoadWaiting>,
+	lazy_load_waiting: Mutex<LazyLoadWaiting>,
 	db: Data,
 }

+struct Data {
+	lazyloadedids: Arc<Map>,
+}
+
 type LazyLoadWaiting = HashMap<LazyLoadWaitingKey, LazyLoadWaitingVal>;
 type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount);
 type LazyLoadWaitingVal = HashSet<OwnedUserId>;
@@ -23,8 +28,10 @@ type LazyLoadWaitingVal = HashSet<OwnedUserId>;
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			lazy_load_waiting: Mutex::new(HashMap::new()),
-			db: Data::new(args.db),
+			lazy_load_waiting: LazyLoadWaiting::new().into(),
+			db: Data {
+				lazyloadedids: args.db["lazyloadedids"].clone(),
+			},
 		}))
 	}
@@ -40,47 +47,60 @@ impl crate::Service for Service {
 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }

-impl Service {
-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn lazy_load_was_sent_before(
-		&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
-	) -> Result<bool> {
-		self.db
-			.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
-	}
+#[implement(Service)]
+#[tracing::instrument(skip(self), level = "debug")]
+#[inline]
+pub async fn lazy_load_was_sent_before(
+	&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
+) -> bool {
+	let key = (user_id, device_id, room_id, ll_user);
+	self.db.lazyloadedids.qry(&key).await.is_ok()
+}

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub async fn lazy_load_mark_sent(
-		&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
-		count: PduCount,
-	) {
-		self.lazy_load_waiting
-			.lock()
-			.expect("locked")
-			.insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load);
-	}
+#[implement(Service)]
+#[tracing::instrument(skip(self), level = "debug")]
+pub fn lazy_load_mark_sent(
+	&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
+) {
+	let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub async fn lazy_load_confirm_delivery(
-		&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
-	) -> Result<()> {
-		if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&(
-			user_id.to_owned(),
-			device_id.to_owned(),
-			room_id.to_owned(),
-			since,
-		)) {
-			self.db
-				.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?;
-		} else {
-			// Ignore
-		}
+	self.lazy_load_waiting
+		.lock()
+		.expect("locked")
+		.insert(key, lazy_load);
+}

-		Ok(())
-	}
+#[implement(Service)]
+#[tracing::instrument(skip(self), level = "debug")]
+pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
+	let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);

-	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
-		self.db.lazy_load_reset(user_id, device_id, room_id)
+	let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {
+		return;
+	};
+
+	let mut prefix = user_id.as_bytes().to_vec();
+	prefix.push(0xFF);
+	prefix.extend_from_slice(device_id.as_bytes());
+	prefix.push(0xFF);
+	prefix.extend_from_slice(room_id.as_bytes());
+	prefix.push(0xFF);
+
+	for ll_id in &user_ids {
+		let mut key = prefix.clone();
+		key.extend_from_slice(ll_id.as_bytes());
+		self.db.lazyloadedids.insert(&key, &[]);
 	}
 }
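Note: the lazy-loading rows are keyed by 0xFF-separated tuples of (user_id, device_id, room_id, ll_user_id); the new `qry(&key)` path serializes the same layout through the database's query-key serializer. An illustrative encoding of that raw key shape (function name hypothetical, shown only to make the byte layout concrete):

// Each field is appended verbatim, separated by 0xFF; Matrix identifiers
// cannot contain 0xFF, so the separator is unambiguous.
fn tuple_key(parts: &[&str]) -> Vec<u8> {
	let mut key = Vec::new();
	for (i, part) in parts.iter().enumerate() {
		if i > 0 {
			key.push(0xFF); // field separator used throughout the schema
		}
		key.extend_from_slice(part.as_bytes());
	}
	key
}

Because keys under a shared (user, device, room) prefix sort together, both the delivery-confirmation writes above and the prefix reset below operate on one contiguous key range.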
+
+#[implement(Service)]
+#[tracing::instrument(skip(self), level = "debug")]
+pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) {
+	let prefix = (user_id, device_id, room_id, Interfix);
+	self.db
+		.lazyloadedids
+		.keys_raw_prefix(&prefix)
+		.ignore_err()
+		.ready_for_each(|key| self.db.lazyloadedids.remove(key))
+		.await;
+}
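Note: `lazy_load_reset` now expresses "delete everything under this (user, device, room) prefix" as a key stream: `keys_raw_prefix` walks the ordered column from the tuple prefix (`Interfix` marks the open tail), `ignore_err()` drops error items, and `ready_for_each` applies the delete without yielding per item. A rough synchronous analogue over an ordered map, to show what the combinators amount to (a sketch, not the database API):

use std::collections::BTreeMap;

// Delete every entry whose key starts with `prefix`. Keys are ordered, so
// the matching run is contiguous and the scan can stop at the first miss.
fn reset_by_prefix(store: &mut BTreeMap<Vec<u8>, Vec<u8>>, prefix: &[u8]) {
	let doomed: Vec<Vec<u8>> = store
		.range(prefix.to_vec()..)
		.take_while(|(k, _)| k.starts_with(prefix))
		.map(|(k, _)| k.clone())
		.collect();
	for k in doomed {
		store.remove(&k);
	}
}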
diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs
deleted file mode 100644
index efe681b1..00000000
--- a/src/service/rooms/metadata/data.rs
+++ /dev/null
@@ -1,110 +0,0 @@
-use std::sync::Arc;
-
-use conduit::{error, utils, Error, Result};
-use database::Map;
-use ruma::{OwnedRoomId, RoomId};
-
-use crate::{rooms, Dep};
-
-pub(super) struct Data {
-	disabledroomids: Arc<Map>,
-	bannedroomids: Arc<Map>,
-	roomid_shortroomid: Arc<Map>,
-	pduid_pdu: Arc<Map>,
-	services: Services,
-}
-
-struct Services {
-	short: Dep<rooms::short::Service>,
-}
-
-impl Data {
-	pub(super) fn new(args: &crate::Args<'_>) -> Self {
-		let db = &args.db;
-		Self {
-			disabledroomids: db["disabledroomids"].clone(),
-			bannedroomids: db["bannedroomids"].clone(),
-			roomid_shortroomid: db["roomid_shortroomid"].clone(),
-			pduid_pdu: db["pduid_pdu"].clone(),
-			services: Services {
-				short: args.depend::<rooms::short::Service>("rooms::short"),
-			},
-		}
-	}
-
-	pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
-		let prefix = match self.services.short.get_shortroomid(room_id)? {
-			Some(b) => b.to_be_bytes().to_vec(),
-			None => return Ok(false),
-		};
-
-		// Look for PDUs in that room.
-		Ok(self
-			.pduid_pdu
-			.iter_from(&prefix, false)
-			.next()
-			.filter(|(k, _)| k.starts_with(&prefix))
-			.is_some())
-	}
-
-	pub(super) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
-		Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
-			RoomId::parse(
-				utils::string_from_bytes(&bytes)
-					.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
-			)
-			.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
-		}))
-	}
-
-	#[inline]
-	pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
-		Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
-	}
-
-	#[inline]
-	pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
-		if disabled {
-			self.disabledroomids.insert(room_id.as_bytes(), &[])?;
-		} else {
-			self.disabledroomids.remove(room_id.as_bytes())?;
-		}
-
-		Ok(())
-	}
-
-	#[inline]
-	pub(super) fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
-		Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
-	}
-
-	#[inline]
-	pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
-		if banned {
-			self.bannedroomids.insert(room_id.as_bytes(), &[])?;
-		} else {
-			self.bannedroomids.remove(room_id.as_bytes())?;
-		}
-
-		Ok(())
-	}
-
-	pub(super) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
-		Box::new(self.bannedroomids.iter().map(
-			|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
-				let room_id = utils::string_from_bytes(&room_id_bytes)
-					.map_err(|e| {
-						error!("Invalid room_id bytes in bannedroomids: {e}");
-						Error::bad_database("Invalid room_id in bannedroomids.")
-					})?
-					.try_into()
-					.map_err(|e| {
-						error!("Invalid room_id in bannedroomids: {e}");
-						Error::bad_database("Invalid room_id in bannedroomids")
-					})?;
-
-				Ok(room_id)
-			},
-		))
-	}
-}
diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs
index 7415c53b..5d4a47c7 100644
--- a/src/service/rooms/metadata/mod.rs
+++ b/src/service/rooms/metadata/mod.rs
@@ -1,51 +1,92 @@
-mod data;
-
 use std::sync::Arc;

-use conduit::Result;
-use ruma::{OwnedRoomId, RoomId};
+use conduit::{implement, utils::stream::TryIgnore, Result};
+use database::Map;
+use futures::{Stream, StreamExt};
+use ruma::RoomId;

-use self::data::Data;
+use crate::{rooms, Dep};

 pub struct Service {
 	db: Data,
+	services: Services,
+}
+
+struct Data {
+	disabledroomids: Arc<Map>,
+	bannedroomids: Arc<Map>,
+	roomid_shortroomid: Arc<Map>,
+	pduid_pdu: Arc<Map>,
+}
+
+struct Services {
+	short: Dep<rooms::short::Service>,
 }

 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			db: Data::new(&args),
+			db: Data {
+				disabledroomids: args.db["disabledroomids"].clone(),
+				bannedroomids: args.db["bannedroomids"].clone(),
+				roomid_shortroomid: args.db["roomid_shortroomid"].clone(),
+				pduid_pdu: args.db["pduid_pdu"].clone(),
+			},
+			services: Services {
+				short: args.depend::<rooms::short::Service>("rooms::short"),
+			},
 		}))
 	}

 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }

-impl Service {
-	/// Checks if a room exists.
-	#[inline]
-	pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }
+#[implement(Service)]
+pub async fn exists(&self, room_id: &RoomId) -> bool {
+	let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else {
+		return false;
+	};

-	#[must_use]
-	pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() }
+	// Look for PDUs in that room.
+	self.db
+		.pduid_pdu
+		.keys_raw_prefix(&prefix)
+		.ignore_err()
+		.next()
+		.await
+		.is_some()
+}
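Note: `exists` above reduces room existence to one ordered seek: is there at least one `pduid_pdu` key under the room's shortroomid prefix? No value is read or deserialized. A sketch of the same check against an ordered map (names illustrative):

use std::collections::BTreeMap;

// True when at least one PDU id is stored under the room's short id prefix.
fn room_has_pdus(pduid_pdu: &BTreeMap<Vec<u8>, Vec<u8>>, shortroomid: u64) -> bool {
	let prefix = shortroomid.to_be_bytes();
	pduid_pdu
		.range(prefix.to_vec()..) // seek to the first key at or after the prefix
		.next()
		.is_some_and(|(k, _)| k.starts_with(&prefix))
}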
-	#[inline]
-	pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) }
+#[implement(Service)]
+pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }

-	#[inline]
-	pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
-		self.db.disable_room(room_id, disabled)
-	}
-
-	#[inline]
-	pub fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) }
-
-	#[inline]
-	pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) }
-
-	#[inline]
-	#[must_use]
-	pub fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
-		self.db.list_banned_rooms()
+#[implement(Service)]
+#[inline]
+pub fn disable_room(&self, room_id: &RoomId, disabled: bool) {
+	if disabled {
+		self.db.disabledroomids.insert(room_id.as_bytes(), &[]);
+	} else {
+		self.db.disabledroomids.remove(room_id.as_bytes());
 	}
 }
+
+#[implement(Service)]
+#[inline]
+pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
+	if banned {
+		self.db.bannedroomids.insert(room_id.as_bytes(), &[]);
+	} else {
+		self.db.bannedroomids.remove(room_id.as_bytes());
+	}
+}
+
+#[implement(Service)]
+pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
+
+#[implement(Service)]
+#[inline]
+pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() }
+
+#[implement(Service)]
+#[inline]
+pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() }
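Note: `disable_room`/`ban_room` store a unit value and let bare key presence carry the boolean, which is why `is_disabled`/`is_banned` collapse to `qry(..).is_ok()`. The pattern in miniature (a sketch over a plain map, not the database API):

use std::collections::HashMap;

// A set of flagged room ids: presence of the key is the whole payload.
struct Flags(HashMap<Vec<u8>, ()>);

impl Flags {
	fn set(&mut self, key: &[u8], on: bool) {
		if on {
			self.0.insert(key.to_vec(), ()); // presence == true
		} else {
			self.0.remove(key); // absence == false
		}
	}

	fn get(&self, key: &[u8]) -> bool { self.0.contains_key(key) }
}

Storing no value keeps the column minimal and makes the read path a single key lookup with no deserialization step.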
diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs
deleted file mode 100644
index aa804721..00000000
--- a/src/service/rooms/outlier/data.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use std::sync::Arc;
-
-use conduit::{Error, Result};
-use database::{Database, Map};
-use ruma::{CanonicalJsonObject, EventId};
-
-use crate::PduEvent;
-
-pub(super) struct Data {
-	eventid_outlierpdu: Arc<Map>,
-}
-
-impl Data {
-	pub(super) fn new(db: &Arc<Database>) -> Self {
-		Self {
-			eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
-		}
-	}
-
-	pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
-		self.eventid_outlierpdu
-			.get(event_id.as_bytes())?
-			.map_or(Ok(None), |pdu| {
-				serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
-			})
-	}
-
-	pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
-		self.eventid_outlierpdu
-			.get(event_id.as_bytes())?
-			.map_or(Ok(None), |pdu| {
-				serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
-			})
-	}
-
-	pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
-		self.eventid_outlierpdu.insert(
-			event_id.as_bytes(),
-			&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
-		)
-	}
-}
diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs
index 22bd2092..277b5982 100644
--- a/src/service/rooms/outlier/mod.rs
+++ b/src/service/rooms/outlier/mod.rs
@@ -1,9 +1,7 @@
-mod data;
-
 use std::sync::Arc;

-use conduit::Result;
-use data::Data;
+use conduit::{implement, Result};
+use database::{Deserialized, Map};
 use ruma::{CanonicalJsonObject, EventId};

 use crate::PduEvent;
@@ -12,31 +10,48 @@ pub struct Service {
 	db: Data,
 }

+struct Data {
+	eventid_outlierpdu: Arc<Map>,
+}
+
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			db: Data::new(args.db),
+			db: Data {
+				eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(),
+			},
 		}))
 	}

 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }

-impl Service {
-	/// Returns the pdu from the outlier tree.
-	pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
-		self.db.get_outlier_pdu_json(event_id)
-	}
-
-	/// Returns the pdu from the outlier tree.
-	///
-	/// TODO: use this?
-	#[allow(dead_code)]
-	pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) }
-
-	/// Append the PDU as an outlier.
-	#[tracing::instrument(skip(self, pdu), level = "debug")]
-	pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
-		self.db.add_pdu_outlier(event_id, pdu)
-	}
+/// Returns the pdu from the outlier tree.
+#[implement(Service)]
+pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
+	self.db
+		.eventid_outlierpdu
+		.qry(event_id)
+		.await
+		.deserialized_json()
+}
+
+/// Returns the pdu from the outlier tree.
+#[implement(Service)]
+pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result<PduEvent> {
+	self.db
+		.eventid_outlierpdu
+		.qry(event_id)
+		.await
+		.deserialized_json()
+}
+
+/// Append the PDU as an outlier.
+#[implement(Service)]
+#[tracing::instrument(skip(self, pdu), level = "debug")]
+pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) {
+	self.db.eventid_outlierpdu.insert(
+		event_id.as_bytes(),
+		&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
+	);
 }
diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs
index d1649da8..f2323475 100644
--- a/src/service/rooms/pdu_metadata/data.rs
+++ b/src/service/rooms/pdu_metadata/data.rs
@@ -1,7 +1,13 @@
 use std::{mem::size_of, sync::Arc};

-use conduit::{utils, Error, PduCount, PduEvent, Result};
+use conduit::{
+	result::LogErr,
+	utils,
+	utils::{stream::TryIgnore, ReadyExt},
+	PduCount, PduEvent,
+};
 use database::Map;
+use futures::{Stream, StreamExt};
 use ruma::{EventId, RoomId, UserId};

 use crate::{rooms, Dep};
@@ -17,8 +23,7 @@ struct Services {
 	timeline: Dep<rooms::timeline::Service>,
 }

-type PdusIterItem = Result<(PduCount, PduEvent)>;
-type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
+pub(super) type PdusIterItem = (PduCount, PduEvent);

 impl Data {
 	pub(super) fn new(args: &crate::Args<'_>) -> Self {
@@ -33,19 +38,17 @@ impl Data {
 		}
 	}

-	pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> {
+	pub(super) fn add_relation(&self, from: u64, to: u64) {
 		let mut key = to.to_be_bytes().to_vec();
 		key.extend_from_slice(&from.to_be_bytes());
-		self.tofrom_relation.insert(&key, &[])?;
-		Ok(())
+		self.tofrom_relation.insert(&key, &[]);
 	}

 	pub(super) fn relations_until<'a>(
 		&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
-	) -> Result<PdusIterator<'a>> {
+	) -> impl Stream<Item = PdusIterItem> + Send + 'a + '_ {
 		let prefix = target.to_be_bytes().to_vec();
 		let mut current = prefix.clone();
-
 		let count_raw = match until {
 			PduCount::Normal(x) => x.saturating_sub(1),
 			PduCount::Backfilled(x) => {
@@ -55,53 +58,42 @@ impl Data {
 		};
 		current.extend_from_slice(&count_raw.to_be_bytes());

-		Ok(Box::new(
-			self.tofrom_relation
-				.iter_from(&current, true)
-				.take_while(move |(k, _)| k.starts_with(&prefix))
-				.map(move |(tofrom, _data)| {
-					let from = utils::u64_from_bytes(&tofrom[(size_of::<u64>())..])
-						.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
+		self.tofrom_relation
+			.rev_raw_keys_from(&current)
+			.ignore_err()
+			.ready_take_while(move |key| key.starts_with(&prefix))
+			.map(|to_from| utils::u64_from_u8(&to_from[(size_of::<u64>())..]))
+			.filter_map(move |from| async move {
+				let mut pduid = shortroomid.to_be_bytes().to_vec();
+				pduid.extend_from_slice(&from.to_be_bytes());
+				let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?;

-					let mut pduid = shortroomid.to_be_bytes().to_vec();
-					pduid.extend_from_slice(&from.to_be_bytes());
+				if pdu.sender != user_id {
+					pdu.remove_transaction_id().log_err().ok();
+				}

-					let mut pdu = self
-						.services
-						.timeline
-						.get_pdu_from_id(&pduid)?
- .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) + Some((PduCount::Normal(from), pdu)) + }) } - pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; + self.referencedevents.insert(&key, &[]); } - - Ok(()) } - pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) + pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + let key = (room_id, event_id); + self.referencedevents.qry(&key).await.is_ok() } - pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) + pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { + self.softfailedeventids.insert(event_id.as_bytes(), &[]); } - pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.softfailedeventids.qry(event_id).await.is_ok() } } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index d9eaf324..dbaebfbf 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,8 +1,8 @@ mod data; - use std::sync::Arc; -use conduit::{PduCount, PduEvent, Result}; +use conduit::{utils::stream::IterStream, PduCount, Result}; +use futures::StreamExt; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -10,7 +10,7 @@ use ruma::{ }; use serde::Deserialize; -use self::data::Data; +use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; pub struct Service { @@ -51,21 +51,19 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self, from, to), level = "debug")] - pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub fn add_relation(&self, from: PduCount, to: PduCount) { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), _ => { // TODO: Relations with backfilled pdus - - Ok(()) }, } } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option, - filter_rel_type: &Option, from: &Option, to: &Option, limit: &Option, + pub async fn paginate_relations_with_filter( + &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, + filter_rel_type: Option, from: Option<&String>, to: Option<&String>, limit: Option, recurse: bool, dir: Direction, ) -> Result { let from = match from { @@ -76,7 +74,7 @@ impl Service { }, }; - let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = to.and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = limit @@ -92,30 +90,32 @@ impl Service { 1 }; - let relations_until = 
&self.relations_until(sender_user, room_id, target, from, depth)?; - let events: Vec<_> = relations_until // TODO: should be relations_after - .iter() - .filter(|(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::(pdu.content.get()) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - .take(limit) - .filter(|(_, pdu)| { - self.services - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to` - .collect(); + let relations_until: Vec = self + .relations_until(sender_user, room_id, target, from, depth) + .await?; + + // TODO: should be relations_after + let events: Vec<_> = relations_until + .into_iter() + .filter(move |(_, pdu): &PdusIterItem| { + if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) { + return false; + } + + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return false; + }; + + filter_rel_type + .as_ref() + .map_or(true, |r| *r == content.relates_to.rel_type) + }) + .take(limit) + .take_while(|(k, _)| Some(*k) != to) + .stream() + .filter_map(|item| self.visibility_filter(sender_user, item)) + .collect() + .await; let next_token = events.last().map(|(count, _)| count).copied(); @@ -125,9 +125,9 @@ impl Service { .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), Direction::Backward => events - .into_iter() - .rev() // relations are always most recent first - .map(|(_, pdu)| pdu.to_message_like_event()) + .into_iter() + .rev() // relations are always most recent first + .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), }; @@ -135,68 +135,85 @@ impl Service { chunk: events_chunk, next_batch: next_token.map(|t| t.stringify()), prev_batch: Some(from.stringify()), - recursion_depth: if recurse { - Some(depth.into()) - } else { - None - }, + recursion_depth: recurse.then_some(depth.into()), }) } - pub fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, - ) -> Result> { - let room_id = self.services.short.get_or_create_shortroomid(room_id)?; - #[allow(unknown_lints)] - #[allow(clippy::manual_unwrap_or_default)] - let target = match self.services.timeline.get_pdu_count(target)? 
{ - Some(PduCount::Normal(c)) => c, + async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; + + self.services + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) + } + + pub async fn relations_until( + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8, + ) -> Result> { + let room_id = self.services.short.get_or_create_shortroomid(room_id).await; + + let target = match self.services.timeline.get_pdu_count(target).await { + Ok(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator }; - self.db + let mut pdus: Vec = self + .db .relations_until(user_id, room_id, target, until) - .map(|mut relations| { - let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect(); - let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect(); + .collect() + .await; - while let Some(stack_pdu) = stack.pop() { - let target = match stack_pdu.0 .0 { - PduCount::Normal(c) => c, - // TODO: Support backfilled relations - PduCount::Backfilled(_) => 0, // This will result in an empty iterator - }; + let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect(); - if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { - for relation in relations.flatten() { - if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); - } + while let Some(stack_pdu) = stack.pop() { + let target = match stack_pdu.0 .0 { + PduCount::Normal(c) => c, + // TODO: Support backfilled relations + PduCount::Backfilled(_) => 0, // This will result in an empty iterator + }; - pdus.push(relation); - } - } + let relations: Vec = self + .db + .relations_until(user_id, room_id, target, until) + .collect() + .await; + + for relation in relations { + if stack_pdu.1 < max_depth { + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } - pdus.sort_by(|a, b| a.0.cmp(&b.0)); - pdus - }) + pdus.push(relation); + } + } + + pdus.sort_by(|a, b| a.0.cmp(&b.0)); + + Ok(pdus) } + #[inline] #[tracing::instrument(skip_all, level = "debug")] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - self.db.mark_as_referenced(room_id, event_ids) + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { + self.db.mark_as_referenced(room_id, event_ids); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - self.db.is_event_referenced(room_id, event_id) + pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + self.db.is_event_referenced(room_id, event_id).await } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } + pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } + pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.db.is_event_soft_failed(event_id).await + } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 
0c156df3..a2c0fabc 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,10 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{ + events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, + serde::Raw, + CanonicalJsonObject, OwnedUserId, RoomId, UserId, +}; -use super::AnySyncEphemeralRoomEventIter; use crate::{globals, Dep}; pub(super) struct Data { @@ -18,6 +26,8 @@ struct Services { globals: Dep, } +pub(super) type ReceiptItem = (OwnedUserId, u64, Raw); + impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; @@ -31,7 +41,9 @@ impl Data { } } - pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { + pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + type KeyVal<'a> = (&'a RoomId, u64, &'a UserId); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -39,108 +51,90 @@ impl Data { last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } + self.readreceiptid_readreceipt + .rev_keys_from_raw(&last_possible_key) + .ignore_err() + .ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id) + .ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u))) + .ready_for_each(|old: KeyVal<'_>| { + // This is the old room_latest + self.readreceiptid_readreceipt.del(&old); + }) + .await; let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); self.readreceiptid_readreceipt.insert( &room_latest_id, &serde_json::to_vec(event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) + ); } - pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> { + pub(super) fn readreceipts_since<'a>( + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream + Send + 'a { + let after_since = since.saturating_add(1); // +1 so we don't send the event at since + let first_possible_edu = (room_id, after_since); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); let prefix2 = prefix.clone(); - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since + self.readreceiptid_readreceipt + .stream_raw_from(&first_possible_edu) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count_offset = prefix.len().saturating_add(size_of::()); + let user_id_offset = count_offset.saturating_add(1); - Box::new( - 
self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count_offset = prefix.len().saturating_add(size_of::()); - let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id_offset = count_offset.saturating_add(1); - let user_id = UserId::parse( - utils::string_from_bytes(&k[user_id_offset..]) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, - ) + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + + let user_id_str = utils::string_from_bytes(&k[user_id_offset..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?; + + let user_id = UserId::parse(user_id_str) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - let mut json = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - json.remove("room_id"); + let mut json = serde_json::from_slice::(v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - Ok(( - user_id, - count, - Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), - )) - }), - ) + json.remove("room_id"); + + let event = Raw::from_json(serde_json::value::to_raw_value(&json)?); + + Ok((user_id, count, event)) + }) + .ignore_err() } - pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + .insert(&key, &count.to_be_bytes()); self.roomuserid_lastprivatereadupdate - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) + .insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) + pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.roomuserid_privateread.qry(&key).await.deserialized() } - pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? 
-			.unwrap_or(0))
+	pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
+		let key = (room_id, user_id);
+		self.roomuserid_lastprivatereadupdate
+			.qry(&key)
+			.await
+			.deserialized()
+			.unwrap_or(0)
 	}
 }
diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs
index da11e2a0..ec34361e 100644
--- a/src/service/rooms/read_receipt/mod.rs
+++ b/src/service/rooms/read_receipt/mod.rs
@@ -3,16 +3,17 @@ mod data;
 use std::{collections::BTreeMap, sync::Arc};

 use conduit::{debug, Result};
-use data::Data;
+use futures::Stream;
 use ruma::{
 	events::{
 		receipt::{ReceiptEvent, ReceiptEventContent},
-		AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent,
+		SyncEphemeralRoomEvent,
 	},
 	serde::Raw,
-	OwnedUserId, RoomId, UserId,
+	RoomId, UserId,
 };

+use self::data::{Data, ReceiptItem};
 use crate::{sending, Dep};

 pub struct Service {
@@ -24,9 +25,6 @@ struct Services {
 	sending: Dep<sending::Service>,
 }

-type AnySyncEphemeralRoomEventIter<'a> =
-	Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
-
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
@@ -42,44 +40,53 @@ impl crate::Service for Service {

 impl Service {
 	/// Replaces the previous read receipt.
-	pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
-		self.db.readreceipt_update(user_id, room_id, event)?;
-		self.services.sending.flush_room(room_id)?;
-
-		Ok(())
+	pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
+		self.db.readreceipt_update(user_id, room_id, event).await;
+		self.services
+			.sending
+			.flush_room(room_id)
+			.await
+			.expect("room flush failed");
 	}

 	/// Returns an iterator over the most recent read_receipts in a room that
 	/// happened after the event with id `since`.
+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
 	pub fn readreceipts_since<'a>(
-		&'a self, room_id: &RoomId, since: u64,
-	) -> impl Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a {
+		&'a self, room_id: &'a RoomId, since: u64,
+	) -> impl Stream<Item = ReceiptItem> + Send + 'a {
 		self.db.readreceipts_since(room_id, since)
 	}

 	/// Sets a private read marker at `count`.
+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
-		self.db.private_read_set(room_id, user_id, count)
+	pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) {
+		self.db.private_read_set(room_id, user_id, count);
 	}

 	/// Returns the private read marker.
+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
-		self.db.private_read_get(room_id, user_id)
+	pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
+		self.db.private_read_get(room_id, user_id).await
 	}
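Note: the private read marker is a plain u64 stored as big-endian bytes under `roomuserid_privateread`, with a sibling `roomuserid_lastprivatereadupdate` cell bumped from the global counter on every write. Big-endian matters here: it keeps byte order and numeric order in agreement, which the ordered range scans elsewhere in the schema rely on. A sketch of the encoding:

// Big-endian so that lexicographic key/value comparison matches numeric order.
fn encode_count(count: u64) -> [u8; 8] { count.to_be_bytes() }

fn decode_count(bytes: &[u8]) -> Option<u64> {
	bytes.try_into().ok().map(u64::from_be_bytes)
}

fn main() {
	assert_eq!(decode_count(&encode_count(42)), Some(42));
	// 9 < 10 holds byte-wise too, unlike with decimal strings or little-endian.
	assert!(encode_count(9) < encode_count(10));
}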

 	/// Returns the count of the last typing update in this room.
-	pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
-		self.db.last_privateread_update(user_id, room_id)
+	#[inline]
+	pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
+		self.db.last_privateread_update(user_id, room_id).await
 	}
 }

 #[must_use]
-pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>> {
+pub fn pack_receipts<I>(receipts: I) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>>
+where
+	I: Iterator<Item = ReceiptItem>,
+{
 	let mut json = BTreeMap::new();
-	for (_user, _count, value) in receipts.flatten() {
+	for (_, _, value) in receipts {
 		let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
 		if let Ok(value) = receipt {
 			for (event, receipt) in value.content {
diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs
index a0086095..de98beee 100644
--- a/src/service/rooms/search/data.rs
+++ b/src/service/rooms/search/data.rs
@@ -1,13 +1,12 @@
 use std::sync::Arc;

-use conduit::{utils, Result};
+use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt};
 use database::Map;
+use futures::StreamExt;
 use ruma::RoomId;

 use crate::{rooms, Dep};

-type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
-
 pub(super) struct Data {
 	tokenids: Arc<Map>,
 	services: Services,
@@ -28,7 +27,7 @@ impl Data {
 		}
 	}

-	pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+	pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
 		let batch = tokenize(message_body)
 			.map(|word| {
 				let mut key = shortroomid.to_be_bytes().to_vec();
@@ -39,11 +38,10 @@ impl Data {
 			})
 			.collect::<Vec<_>>();

-		self.tokenids
-			.insert_batch(batch.iter().map(database::KeyVal::from))
+		self.tokenids.insert_batch(batch.iter());
 	}

-	pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+	pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
 		let batch = tokenize(message_body).map(|word| {
 			let mut key = shortroomid.to_be_bytes().to_vec();
 			key.extend_from_slice(word.as_bytes());
@@ -53,46 +51,53 @@ impl Data {
 		});

 		for token in batch {
-			self.tokenids.remove(&token)?;
+			self.tokenids.remove(&token);
 		}
-
-		Ok(())
 	}

-	pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
+	pub(super) async fn search_pdus(
+		&self, room_id: &RoomId, search_string: &str,
+	) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
 		let prefix = self
 			.services
 			.short
-			.get_shortroomid(room_id)?
-			.expect("room exists")
+			.get_shortroomid(room_id)
+			.await
+			.ok()?
 			.to_be_bytes()
 			.to_vec();

 		let words: Vec<_> = tokenize(search_string).collect();

-		let iterators = words.clone().into_iter().map(move |word| {
-			let mut prefix2 = prefix.clone();
-			prefix2.extend_from_slice(word.as_bytes());
-			prefix2.push(0xFF);
-			let prefix3 = prefix2.clone();
+		let bufs: Vec<_> = words
+			.clone()
+			.into_iter()
+			.stream()
+			.then(move |word| {
+				let mut prefix2 = prefix.clone();
+				prefix2.extend_from_slice(word.as_bytes());
+				prefix2.push(0xFF);
+				let prefix3 = prefix2.clone();

-			let mut last_possible_id = prefix2.clone();
-			last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
+				let mut last_possible_id = prefix2.clone();
+				last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());

-			self.tokenids
-				.iter_from(&last_possible_id, true) // Newest pdus first
-				.take_while(move |(k, _)| k.starts_with(&prefix2))
-				.map(move |(key, _)| key[prefix3.len()..].to_vec())
-		});
+				self.tokenids
+					.rev_raw_keys_from(&last_possible_id) // Newest pdus first
+					.ignore_err()
+					.ready_take_while(move |key| key.starts_with(&prefix2))
+					.map(move |key| key[prefix3.len()..].to_vec())
+					.collect::<Vec<_>>()
+			})
+			.collect()
+			.await;

-		let Some(common_elements) = utils::common_elements(iterators, |a, b| {
-			// We compare b with a because we reversed the iterator earlier
-			b.cmp(a)
-		}) else {
-			return Ok(None);
-		};
-
-		Ok(Some((Box::new(common_elements), words)))
+		Some((
+			set::intersection(bufs.iter().map(|buf| buf.iter()))
+				.cloned()
+				.collect(),
+			words,
+		))
 	}
 }

@@ -100,7 +105,7 @@ impl Data {
 ///
 /// This may be used to tokenize both message bodies (for indexing) or search
 /// queries (for querying).
-fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
+fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
 	body.split_terminator(|c: char| !c.is_alphanumeric())
 		.filter(|s| !s.is_empty())
 		.filter(|word| word.len() <= 50)
diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs
index 8caa0ce3..80b58804 100644
--- a/src/service/rooms/search/mod.rs
+++ b/src/service/rooms/search/mod.rs
@@ -21,20 +21,21 @@ impl crate::Service for Service {
 }

 impl Service {
+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
-		self.db.index_pdu(shortroomid, pdu_id, message_body)
+	pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
+		self.db.index_pdu(shortroomid, pdu_id, message_body);
 	}

+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
-		self.db.deindex_pdu(shortroomid, pdu_id, message_body)
+	pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
+		self.db.deindex_pdu(shortroomid, pdu_id, message_body);
 	}

+	#[inline]
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn search_pdus<'a>(
-		&'a self, room_id: &RoomId, search_string: &str,
-	) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> {
-		self.db.search_pdus(room_id, search_string)
+	pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
+		self.db.search_pdus(room_id, search_string).await
 	}
 }
diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs
index 17fbb64e..f6a82488 100644
--- a/src/service/rooms/short/data.rs
+++ b/src/service/rooms/short/data.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;

-use conduit::{utils, warn, Error, Result};
-use database::Map;
+use conduit::{err, utils, Error, Result};
+use database::{Deserialized, Map};
 use ruma::{events::StateEventType, EventId, RoomId};

 use crate::{globals, Dep};
@@ -36,44 +36,46 @@ impl Data {
 		}
 	}

-	pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
-		let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
-			utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
-		} else {
-			let shorteventid = self.services.globals.next_count()?;
-			self.eventid_shorteventid
-				.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
-			self.shorteventid_eventid
-				.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
-			shorteventid
-		};
+	pub(super) async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 {
+		if let Ok(shorteventid) = self.eventid_shorteventid.qry(event_id).await.deserialized() {
+			return shorteventid;
+		}

-		Ok(short)
+		let shorteventid = self.services.globals.next_count().unwrap();
+		self.eventid_shorteventid
+			.insert(event_id.as_bytes(), &shorteventid.to_be_bytes());
+		self.shorteventid_eventid
+			.insert(&shorteventid.to_be_bytes(), event_id.as_bytes());
+
+		shorteventid
 	}

-	pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
+	pub(super) async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec<u64> {
 		let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
 		let keys = event_ids
 			.iter()
 			.map(|id| id.as_bytes())
 			.collect::<Vec<_>>();
+
 		for (i, short) in self
 			.eventid_shorteventid
-			.multi_get(&keys)?
+			.multi_get(keys.iter())
 			.iter()
 			.enumerate()
 		{
 			#[allow(clippy::single_match_else)]
 			match short {
 				Some(short) => ret.push(
-					utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
+					utils::u64_from_bytes(short)
+						.map_err(|_| Error::bad_database("Invalid shorteventid in db."))
+						.unwrap(),
 				),
 				None => {
-					let short = self.services.globals.next_count()?;
+					let short = self.services.globals.next_count().unwrap();
 					self.eventid_shorteventid
-						.insert(keys[i], &short.to_be_bytes())?;
+						.insert(keys[i], &short.to_be_bytes());
 					self.shorteventid_eventid
-						.insert(&short.to_be_bytes(), keys[i])?;
+						.insert(&short.to_be_bytes(), keys[i]);

 					debug_assert!(ret.len() == i, "position of result must match input");
 					ret.push(short);
@@ -81,115 +83,85 @@ impl Data {
 			}
 		}

-		Ok(ret)
+		ret
 	}
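Note: `get_or_create_shorteventid` above is a read-then-allocate: look the id up, and on a miss take the next value from the globally serialized counter and write both directions of the mapping. A sketch with in-memory maps; `next_count` stands in for `globals.next_count()`, and correctness leans on counter allocation being serialized under a single writer, which the sketch does not show:

use std::collections::HashMap;

fn get_or_create_short(
	forward: &mut HashMap<String, u64>,  // event_id -> shorteventid
	reverse: &mut HashMap<u64, String>,  // shorteventid -> event_id
	next_count: &mut u64,
	event_id: &str,
) -> u64 {
	if let Some(&short) = forward.get(event_id) {
		return short;
	}
	*next_count += 1;
	let short = *next_count;
	forward.insert(event_id.to_owned(), short);
	reverse.insert(short, event_id.to_owned());
	short
}

Replacing long event ids with fixed-width u64 shorts keeps the hot pduid/state keys small and cheaply comparable; the reverse map exists so shorts can be resolved back for responses.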
- .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + pub(super) async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + let key = (event_type.to_string(), state_key); + if let Ok(shortstatekey) = self.statekey_shortstatekey.qry(&key).await.deserialized() { + return shortstatekey; + } + + let mut key = event_type.to_string().as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(state_key.as_bytes()); + + let shortstatekey = self.services.globals.next_count().unwrap(); + self.statekey_shortstatekey + .insert(&key, &shortstatekey.to_be_bytes()); + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &key); + + shortstatekey + } + + pub(super) async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + self.shorteventid_eventid + .qry(&shorteventid) + .await + .deserialized() + .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) + } + + pub(super) async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.shortstatekey_statekey + .qry(&shortstatekey) + .await + .deserialized() + .map_err(|e| { + err!(Database( + "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" + )) }) - .transpose()?; - - Ok(short) - } - - pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? - } else { - let shortstatekey = self.services.globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - }; - - Ok(short) - } - - pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - Ok(event_id) - } - - pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? 
- .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xFF); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); - - let state_key = utils::string_from_bytes(statekey_bytes) - .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; - - let result = (event_type, state_key); - - Ok(result) } /// Returns (shortstatehash, already_existed) - pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { - ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ) - } else { - let shortstatehash = self.services.globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - }) + pub(super) async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + if let Ok(shortstatehash) = self + .statehash_shortstatehash + .qry(state_hash) + .await + .deserialized() + { + return (shortstatehash, true); + } + + let shortstatehash = self.services.globals.next_count().unwrap(); + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes()); + + (shortstatehash, false) } - pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + pub(super) async fn get_shortroomid(&self, room_id: &RoomId) -> Result { + self.roomid_shortroomid.qry(room_id).await.deserialized() + } + + pub(super) async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) - .transpose() - } - - pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? 
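Since the get-or-create variants above no longer return `Result` wrappers, their call sites shrink from unwrap-and-propagate to a plain await. A before/after sketch at a hypothetical call site:

    // Before: synchronous and fallible at every call.
    // let short = services.short.get_or_create_shorteventid(&event_id)?;

    // After: awaited; the infallible getters hand back the value directly.
    let short = services.short.get_or_create_shorteventid(&event_id).await;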
- } else { - let short = self.services.globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - }) + .qry(room_id) + .await + .deserialized() + .unwrap_or_else(|_| { + let short = self.services.globals.next_count().unwrap(); + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes()); + short + }) } } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index bfe0e9a0..00bb7cb1 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -22,38 +22,40 @@ impl crate::Service for Service { } impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) + pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + self.db.get_or_create_shorteventid(event_id).await } - pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { - self.db.multi_get_or_create_shorteventid(event_ids) + pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { + self.db.multi_get_or_create_shorteventid(event_ids).await } - pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - self.db.get_shortstatekey(event_type, state_key) + pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + self.db.get_shortstatekey(event_type, state_key).await } - pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) + pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + self.db + .get_or_create_shortstatekey(event_type, state_key) + .await } - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.db.get_eventid_from_short(shorteventid) + pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + self.db.get_eventid_from_short(shorteventid).await } - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey) + pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.db.get_statekey_from_short(shortstatekey).await } /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) + pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + self.db.get_or_create_shortstatehash(state_hash).await } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } + pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { self.db.get_shortroomid(room_id).await } - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) + pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { + self.db.get_or_create_shortroomid(room_id).await } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 24d612d8..17fbf0ef 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,12 @@ use std::{ sync::Arc, }; -use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{ + checked, debug, debug_info, err, + 
utils::{math::usize_from_f64, IterStream}, + Error, Result, +}; +use futures::{StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ api::{ @@ -211,12 +216,15 @@ impl Service { .as_ref() { return Ok(if let Some(cached) = cached { - if self.is_accessible_child( - current_room, - &cached.summary.join_rule, - &identifier, - &cached.summary.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &cached.summary.join_rule, + &identifier, + &cached.summary.allowed_room_ids, + ) + .await + { Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) } else { Some(SummaryAccessibility::Inaccessible) @@ -228,7 +236,9 @@ impl Service { Ok( if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self.get_room_summary(current_room, children_pdus, &identifier); + let summary = self + .get_room_summary(current_room, children_pdus, &identifier) + .await; if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -322,12 +332,15 @@ impl Service { ); } } - if self.is_accessible_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + ) + .await + { return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); } @@ -358,7 +371,7 @@ impl Service { } } - fn get_room_summary( + async fn get_room_summary( &self, current_room: &OwnedRoomId, children_state: Vec>, identifier: &Identifier<'_>, ) -> Result { @@ -367,48 +380,43 @@ impl Service { let join_rule = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await + .map_or(JoinRule::Invite, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? 
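The children re-collection a little further below shows the new pattern for running an async, fallible predicate over a plain iterator: lift it into a stream with the `IterStream` extension, then let `skip_while` drive one `TryFuture` per item, folding lookup failures into `false`. Condensed, with the surrounding bindings assumed from the hunk:

    use futures::{StreamExt, TryFutureExt};

    let children: Vec<_> = children
        .iter()
        .rev()
        .stream() // IterStream: Iterator -> Stream, so async combinators apply
        .skip_while(|(room, _)| {
            services.short
                .get_shortroomid(room)
                .map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
                .unwrap_or_else(|_| false) // a failed lookup stops the skipping
        })
        .map(Clone::clone)
        .collect()
        .await;

`skip_while` here takes a future-returning closure, which the plain iterator adapter cannot express; that is what the removed double-ended-iterator workaround was compensating for.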
- .unwrap_or(JoinRule::Invite); + .unwrap() + }); let allowed_room_ids = self .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { + if !self + .is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) + .await + { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); } - let join_rule = join_rule.into(); - Ok(SpaceHierarchyParentSummary { canonical_alias: self .services .state_accessor .get_canonical_alias(room_id) - .unwrap_or(None), - name: self - .services - .state_accessor - .get_name(room_id) - .unwrap_or(None), + .await + .ok(), + name: self.services.state_accessor.get_name(room_id).await.ok(), num_joined_members: self .services .state_cache .room_joined_count(room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) + .await + .unwrap_or(0) .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), @@ -416,18 +424,29 @@ impl Service { .services .state_accessor .get_room_topic(room_id) - .unwrap_or(None), - world_readable: self.services.state_accessor.is_world_readable(room_id)?, - guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + .await + .ok(), + world_readable: self + .services + .state_accessor + .is_world_readable(room_id) + .await, + guest_can_join: self.services.state_accessor.guest_can_join(room_id).await, avatar_url: self .services .state_accessor - .get_avatar(room_id)? + .get_avatar(room_id) + .await .into_option() .unwrap_or_default() .url, - join_rule, - room_type: self.services.state_accessor.get_room_type(room_id)?, + join_rule: join_rule.into(), + room_type: self + .services + .state_accessor + .get_room_type(room_id) + .await + .ok(), children_state, allowed_room_ids, }) @@ -474,21 +493,22 @@ impl Service { results.push(summary_to_chunk(*summary.clone())); } else { children = children - .into_iter() - .rev() - .skip_while(|(room, _)| { - if let Ok(short) = self.services.short.get_shortroomid(room) - { - short.as_ref() != short_room_ids.get(parents.len()) - } else { - false - } - }) - .collect::>() - // skip_while doesn't implement DoubleEndedIterator, which is needed for rev - .into_iter() - .rev() - .collect(); + .iter() + .rev() + .stream() + .skip_while(|(room, _)| { + self.services + .short + .get_shortroomid(room) + .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .unwrap_or_else(|_| false) + }) + .map(Clone::clone) + .collect::)>>() + .await + .into_iter() + .rev() + .collect(); if children.is_empty() { return Err(Error::BadRequest( @@ -531,7 +551,7 @@ impl Service { let mut short_room_ids = vec![]; for room in parents { - short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); + short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await); } Some( @@ -554,7 +574,7 @@ impl Service { async fn get_stripped_space_child_events( &self, room_id: &RoomId, ) -> Result>>, Error> { - let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
else { + let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else { return Ok(None); }; @@ -562,10 +582,13 @@ impl Service { .services .state_accessor .state_full_ids(current_shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State in space not found: {e}")))?; + let mut children_pdus = Vec::new(); for (key, id) in state { - let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?; + if event_type != StateEventType::SpaceChild { continue; } @@ -573,8 +596,9 @@ impl Service { let pdu = self .services .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + .get_pdu(&id) + .await + .map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?; if serde_json::from_str::(pdu.content.get()) .ok() @@ -593,7 +617,7 @@ impl Service { } /// With the given identifier, checks if a room is accessible - fn is_accessible_child( + async fn is_accessible_child( &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, allowed_room_ids: &Vec, ) -> bool { @@ -607,6 +631,7 @@ impl Service { .services .event_handler .acl_check(server_name, room_id) + .await .is_err() { return false; @@ -617,12 +642,11 @@ impl Service { .services .state_cache .is_joined(user_id, current_room) - .unwrap_or_default() - || self - .services - .state_cache - .is_invited(user_id, current_room) - .unwrap_or_default() + .await || self + .services + .state_cache + .is_invited(user_id, current_room) + .await { return true; } @@ -633,22 +657,12 @@ impl Service { for room in allowed_room_ids { match identifier { Identifier::UserId(user) => { - if self - .services - .state_cache - .is_joined(user, room) - .unwrap_or_default() - { + if self.services.state_cache.is_joined(user, room).await { return true; } }, Identifier::ServerName(server) => { - if self - .services - .state_cache - .server_in_room(server, room) - .unwrap_or_default() - { + if self.services.state_cache.server_in_room(server, room).await { return true; } }, diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 3c110afc..ccf7509a 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,34 +1,31 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{EventId, OwnedEventId, RoomId}; +use conduit::{ + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Database, Deserialized, Interfix, Map}; +use ruma::{OwnedEventId, RoomId}; use super::RoomMutexGuard; pub(super) struct Data { shorteventid_shortstatehash: Arc, - roomid_pduleaves: Arc, roomid_shortstatehash: Arc, + pub(super) roomid_pduleaves: Arc, } impl Data { pub(super) fn new(db: &Arc) -> Self { Self { shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), - roomid_pduleaves: db["roomid_pduleaves"].clone(), roomid_shortstatehash: db["roomid_shortstatehash"].clone(), + roomid_pduleaves: db["roomid_pduleaves"].clone(), } } - pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())?
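In the state/data.rs hunk just below, clearing the old forward extremities switches from scan-collect-delete to a streaming clear: serialize the room prefix once with `Interfix`, stream the raw keys under it, and remove each key as it arrives. The idiom in isolation (all names are the patch's own):

    let prefix = (room_id, Interfix); // tuple-key serializer: room_id + 0xFF as a prefix query
    self.roomid_pduleaves
        .keys_raw_prefix(&prefix)
        .ignore_err()                 // TryIgnore: drop Err items from the key stream
        .ready_for_each(|key| self.roomid_pduleaves.remove(key)) // ReadyExt: sync closure per key
        .await;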
- .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) + pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.roomid_shortstatehash.qry(room_id).await.deserialized() } #[inline] @@ -37,53 +34,35 @@ impl Data { room_id: &RoomId, new_shortstatehash: u64, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes()); } - pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes()); } - pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - pub(super) fn set_forward_extremities( + pub(super) async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { + let prefix = (room_id, Interfix); + self.roomid_pduleaves + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.roomid_pduleaves.remove(key)) + .await; + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - for event_id in event_ids { let mut key = prefix.clone(); key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + self.roomid_pduleaves.insert(&key, event_id.as_bytes()); } - - Ok(()) } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index cb219bc0..c7f6605c 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -7,12 +7,14 @@ use std::{ }; use conduit::{ - utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, PduEvent, Result, + err, + utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard}, + warn, PduEvent, Result, }; use data::Data; +use database::{Ignore, Interfix}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ - api::client::error::ErrorKind, events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, AnyStrippedStateEvent, StateEventType, TimelineEventType, @@ -81,14 +83,16 @@ impl Service { _statediffremoved: Arc>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { - for event_id in statediffnew.iter().filter_map(|new| { + let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor .parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let Some(pdu) = 
self.services.timeline.get_pdu_json(&event_id)? else { + .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) + }); + + pin_mut!(event_ids); + while let Some(event_id) = event_ids.next().await { + let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else { continue; }; @@ -113,15 +117,10 @@ impl Service { continue; }; - self.services.state_cache.update_membership( - room_id, - &user_id, - membership_event, - &pdu.sender, - None, - None, - false, - )?; + self.services + .state_cache + .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false) + .await?; }, TimelineEventType::SpaceChild => { self.services @@ -135,10 +134,9 @@ impl Service { } } - self.services.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id).await; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash, state_lock); Ok(()) } @@ -148,12 +146,16 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await; let state_hash = calculate_hash( &state_ids_compressed @@ -165,13 +167,18 @@ impl Service { let (shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if !already_existed { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed @@ -198,7 +205,7 @@ impl Service { )?; } - self.db.set_event_state(shorteventid, shortstatehash)?; + self.db.set_event_state(shorteventid, shortstatehash); Ok(shortstatehash) } @@ -208,34 +215,40 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
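`force_state` above also shows how a stream built from a borrowed iterator is consumed imperatively: `StreamExt::next` needs the stream pinned, so the local stream is stack-pinned with `pin_mut!` and drained with `while let`. The skeleton of that pattern, with the bindings assumed from the hunk:

    use futures::{pin_mut, StreamExt, TryFutureExt};

    // A stream of event ids resolved from compressed state, with errors filtered out.
    let event_ids = statediffnew.iter().stream().filter_map(|new| {
        services.state_compressor
            .parse_compressed_state_event(new)
            .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id))
    });

    pin_mut!(event_ids); // stack-pin; required before calling next() on a local stream
    while let Some(event_id) = event_ids.next().await {
        // each iteration may await further service calls, as the hunk does
    }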
#[tracing::instrument(skip(self, new_pdu), level = "debug")] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = self .services .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await; - if let Some(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p)?; + if let Ok(p) = previous_shortstatehash { + self.db.set_event_state(shorteventid, p); } if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - #[inline] - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let shortstatekey = self .services .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await; let new = self .services .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await; let replaces = states_parents .last() @@ -276,49 +289,55 @@ impl Service { } #[tracing::instrument(skip(self, invite_event), level = "debug")] - pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { + pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? 
{ + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "") + .await + { + state.push(e.to_stripped_state_event()); + } + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str()) + .await + { state.push(e.to_stripped_state_event()); } @@ -333,101 +352,108 @@ impl Service { room_id: &RoomId, shortstatehash: u64, mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) + ) { + self.db.set_room_state(room_id, shortstatehash, mutex_lock); } /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = self - .services + pub async fn get_room_version(&self, room_id: &RoomId) -> Result { + self.services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: RoomCreateEventContent = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; - - Ok(create_event_content.room_version) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.room_version) + .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) } #[inline] - pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.db.get_room_shortstatehash(room_id) + pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.db.get_room_shortstatehash(room_id).await } - pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - self.db.get_forward_extremities(room_id) + pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + '_ { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix(&prefix) + .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) + .ignore_err() } - pub fn set_forward_extremities( + pub async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.db .set_forward_extremities(room_id, event_ids, state_lock) + .await; } /// This fetches auth events from the current state. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_auth_events( + pub async fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { - let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? 
else { + let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { return Ok(HashMap::new()); }; - let auth_events = - state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?; - let mut sauthevents = auth_events - .into_iter() + let mut sauthevents: HashMap<_, _> = auth_events + .iter() + .stream() .filter_map(|(event_type, state_key)| { self.services .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) + .get_shortstatekey(event_type, state_key) + .map_ok(move |s| (s, (event_type, state_key))) + .map(Result::ok) }) - .collect::>(); + .collect() + .await; let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| { + err!(Database( + "Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}" + )) + })? .pop() .expect("there is always one layer") .1; - Ok(full_state - .iter() - .filter_map(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - }) - .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) - .filter_map(|(k, event_id)| { - self.services - .timeline - .get_pdu(&event_id) - .ok() - .flatten() - .map(|pdu| (k, pdu)) - }) - .collect()) + let mut ret = HashMap::new(); + for compressed in full_state.iter() { + let Ok((shortstatekey, event_id)) = self + .services + .state_compressor + .parse_compressed_state_event(compressed) + .await + else { + continue; + }; + + let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else { + continue; + }; + + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; + }; + + ret.insert((ty.to_owned(), state_key.to_owned()), pdu); + } + + Ok(ret) } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 4c85148d..79a98325 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{err, PduEvent, Result}; +use database::{Deserialized, Map}; +use futures::TryFutureExt; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{rooms, Dep}; @@ -39,17 +40,22 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database("Missing state IDs: {e}")))? .pop() .expect("there is always one layer") .1; + let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed)?; + .parse_compressed_state_event(compressed) + .await?; + result.insert(parsed.0, parsed.1); i = i.wrapping_add(1); @@ -57,6 +63,7 @@ impl Data { tokio::task::yield_now().await; } } + Ok(result) } @@ -67,7 +74,8 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await? 
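`get_auth_events` above is the template for resolving a batch of keys into a map: stream the wanted `(event_type, state_key)` pairs, attempt the short-key lookup per pair, drop the misses, and collect straight into a `HashMap`. Its core, isolated:

    use std::collections::HashMap;
    use futures::{FutureExt, StreamExt, TryFutureExt};

    let mut sauthevents: HashMap<_, _> = auth_events
        .iter()
        .stream()
        .filter_map(|(event_type, state_key)| {
            services.short
                .get_shortstatekey(event_type, state_key)
                .map_ok(move |s| (s, (event_type, state_key)))
                .map(Result::ok) // Err -> None, so unresolvable keys just vanish
        })
        .collect()
        .await;

Misses are silent by design here: auth selection tolerates absent state keys, and the map is then drained by `remove(&shortstatekey)` while walking the compressed full state.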
.pop() .expect("there is always one layer") .1; @@ -78,18 +86,13 @@ impl Data { let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); + .parse_compressed_state_event(compressed) + .await?; + + if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { + if let Some(state_key) = pdu.state_key.as_ref() { + result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu); + } } i = i.wrapping_add(1); @@ -101,61 +104,63 @@ impl Data { Ok(result) } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). #[allow(clippy::unused_self)] - pub(super) fn state_get_id( + pub(super) async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - let Some(shortstatekey) = self + ) -> Result> { + let shortstatekey = self .services .short - .get_shortstatekey(event_type, state_key)? - else { - return Ok(None); - }; + .get_shortstatekey(event_type, state_key) + .await?; + let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") .1; - Ok(full_state + + let compressed = full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + .ok_or(err!(Database("No shortstatekey in compressed state")))?; + + self.services + .state_compressor + .parse_compressed_state_event(compressed) + .map_ok(|(_, id)| id) + .map_err(|e| { + err!(Database(error!( + ?event_type, + ?state_key, + ?shortstatekey, + "Failed to parse compressed: {e:?}" + ))) + }) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) + ) -> Result> { + self.state_get_id(shortstatehash, event_type, state_key) + .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) + .await } /// Returns the state hash for this pdu. - pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") - }) - }) - .transpose() - }) + .qry(event_id) + .and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid)) + .await + .deserialized() } /// Returns the full room state. 
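`pdu_shortstatehash` above chains the two point lookups directly on the `TryFuture`, so the intermediate handle never needs an explicit match. The shape, reduced:

    use futures::TryFutureExt;

    // event_id -> shorteventid -> shortstatehash: one fallible chain, one await.
    self.eventid_shorteventid
        .qry(event_id)
        .and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid))
        .await
        .deserialized()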
@@ -163,34 +168,33 @@ impl Data { pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_full(shortstatehash)) + .map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get_id( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) + .await } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 58fa31b3..4c28483c 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,8 +6,13 @@ use std::{ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; -use data::Data; +use conduit::{ + err, error, + pdu::PduBuilder, + utils::{math::usize_from_f64, ReadyExt}, + Error, PduEvent, Result, +}; +use futures::StreamExt; use lru_cache::LruCache; use ruma::{ events::{ @@ -31,8 +36,10 @@ use ruma::{ EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +use serde::Deserialize; use serde_json::value::to_raw_value; +use self::data::Data; use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { @@ -99,54 +106,58 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn state_get_id( + pub async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get_id(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get_id(shortstatehash, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). 
#[inline] - pub fn state_get( + pub async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get(shortstatehash, event_type, state_key) + .await } /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(MembershipState::Leave), |s| { + async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { + self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_or(MembershipState::Leave, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomMemberEventContent| c.membership) .map_err(|_| Error::bad_database("Invalid room membership event in database.")) + .unwrap() }) } /// The user was a joined member at this state (potentially in the past) #[inline] - fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join) - // Return sensible default, i.e. - // false + async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) - // Return sensible default, i.e. false + async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + let s = self.user_membership(shortstatehash, user_id).await; + s == MembershipState::Join || s == MembershipState::Invite } /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { + pub async fn server_can_see_event( + &self, origin: &ServerName, room_id: &RoomId, event_id: &EventId, + ) -> Result { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { return Ok(true); }; @@ -160,8 +171,9 @@ impl Service { } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
- .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -171,25 +183,28 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); - let mut current_server_members = self + let current_server_members = self .services .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member.server_name() == origin); + .ready_filter(|member| member.server_name() == origin); let visibility = match history_visibility { HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_invited(shortstatehash, member)) + .await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_joined(shortstatehash, member)) + .await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -208,9 +223,9 @@ impl Service { /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id, event_id))] - pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { - return Ok(true); + pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; }; if let Some(visibility) = self @@ -219,14 +234,15 @@ impl Service { .unwrap() .get_mut(&(user_id.to_owned(), shortstatehash)) { - return Ok(*visibility); + return *visibility; } - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
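The `server_can_see_event` rework above pairs two stream extensions: `ready_filter` (from this refactor's `ReadyExt`) applies a synchronous predicate without wrapping it in a future, and `StreamExt::any` then drives one async membership probe per surviving member, short-circuiting on the first hit. In isolation:

    use futures::StreamExt;

    let allowed = services.state_cache
        .room_members(room_id)                                 // Stream of members
        .ready_filter(|member| member.server_name() == origin) // sync filter, no await per item
        .any(|member| self.user_was_invited(shortstatehash, member)) // async predicate
        .await;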
- .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -236,19 +252,19 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); let visibility = match history_visibility { HistoryVisibility::WorldReadable => true, HistoryVisibility::Shared => currently_member, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - self.user_was_invited(shortstatehash, user_id) + self.user_was_invited(shortstatehash, user_id).await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - self.user_was_joined(shortstatehash, user_id) + self.user_was_joined(shortstatehash, user_id).await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -261,17 +277,18 @@ impl Service { .unwrap() .insert((user_id.to_owned(), shortstatehash), visibility); - Ok(visibility) + visibility } /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] - pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "") + .await .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) @@ -285,11 +302,13 @@ impl Service { }) .unwrap_or(HistoryVisibility::Shared); - Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) + currently_member || history_visibility == HistoryVisibility::WorldReadable } /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } + pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + self.db.pdu_shortstatehash(event_id).await + } /// Returns the full room state. #[tracing::instrument(skip(self), level = "debug")] @@ -300,47 +319,61 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get_id( + pub async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get_id(room_id, event_type, state_key) + ) -> Result> { + self.db + .room_state_get_id(room_id, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). 
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get( + pub async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get(room_id, event_type, state_key) + ) -> Result> { + self.db.room_state_get(room_id, event_type, state_key).await } - pub fn get_name(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomName, "")? - .map_or(Ok(None), |s| { - Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) - }) + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub async fn room_state_get_content( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result + where + T: for<'de> Deserialize<'de> + Send, + { + use serde_json::from_str; + + self.room_state_get(room_id, event_type, state_key) + .await + .and_then(|event| from_str::(event.content.get()).map_err(Into::into)) } - pub fn get_avatar(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(ruma::JsOption::Undefined), |s| { + pub async fn get_name(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomName, "") + .await + .map(|c: RoomNameEventContent| c.name) + } + + pub async fn get_avatar(&self, room_id: &RoomId) -> ruma::JsOption { + self.room_state_get(room_id, &StateEventType::RoomAvatar, "") + .await + .map_or(ruma::JsOption::Undefined, |s| { serde_json::from_str(s.content.get()) .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + .unwrap() }) } - pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }) + pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await } - pub fn user_can_invite( + pub async fn user_can_invite( &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, - ) -> Result { + ) -> bool { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) .expect("Event content always serializes"); @@ -353,122 +386,101 @@ impl Service { timestamp: None, }; - Ok(self - .services + self.services .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) - .is_ok()) + .await + .is_ok() } /// Checks if guests are able to view room content without joining - pub fn is_world_readable(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? 
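`room_state_get_content` above is the generic accessor that lets each typed getter collapse to a fetch plus a field projection; the caller only picks the content type. A hypothetical call for a state type this patch does not touch, to show the shape generalizes:

    use ruma::events::{room::pinned_events::RoomPinnedEventsEventContent, StateEventType};

    // T = RoomPinnedEventsEventContent; a missing event or bad JSON surfaces as Err.
    let pinned: RoomPinnedEventsEventContent = services
        .state_accessor
        .room_state_get_content(room_id, &StateEventType::RoomPinnedEvents, "")
        .await?;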
- .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|e| { - error!( - "Invalid room history visibility event in database for room {room_id}, assuming not world \ - readable: {e} " - ); - Error::bad_database("Invalid room history visibility event in database.") - }) - }) + pub async fn is_world_readable(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") + .await + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) + .unwrap_or(false) } /// Checks if guests are able to join a given room - pub fn guest_can_join(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) - .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - }) + pub async fn guest_can_join(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "") + .await + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .unwrap_or(false) } /// Gets the primary alias from canonical alias event - pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "") + .await + .and_then(|c: RoomCanonicalAliasEventContent| { + c.alias + .ok_or_else(|| err!(Request(NotFound("No alias found in event content.")))) }) } /// Gets the room topic - pub fn get_room_topic(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomTopic, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|e| { - error!("Invalid room topic event in database for room {room_id}: {e}"); - Error::bad_database("Invalid room topic event in database.") - }) - }) + pub async fn get_room_topic(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomTopic, "") + .await + .map(|c: RoomTopicEventContent| c.topic) } /// Checks if a given user can redact a given event /// /// If federation is true, it allows redaction events from any user of the /// same server as the original event sender - pub fn user_can_redact( + pub async fn user_can_redact( &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) -> Result { - self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map_or_else( - || { - // Falling back on m.room.create to judge power level - if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? 
{ - Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - pdu.sender == sender - } else { - false - }) + if let Ok(event) = self + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await + { + let Ok(event) = serde_json::from_str(event.content.get()) + .map(|content: RoomPowerLevelsEventContent| content.into()) + .map(|event: RoomPowerLevels| event) + else { + return Ok(false); + }; + + Ok(event.user_can_redact_event_of_other(sender) + || event.user_can_redact_own_event(sender) + && if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + if federation { + pdu.sender.server_name() == sender.server_name() + } else { + pdu.sender == sender + } } else { - Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", - )) - } - }, - |event| { - serde_json::from_str(event.content.get()) - .map(|content: RoomPowerLevelsEventContent| content.into()) - .map(|event: RoomPowerLevels| { - event.user_can_redact_event_of_other(sender) - || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - if federation { - pdu.sender.server_name() == sender.server_name() - } else { - pdu.sender == sender - } - } else { - false - } - }) - .map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database")) - }, - ) + false + }) + } else { + // Falling back on m.room.create to judge power level + if let Ok(pdu) = self + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + { + Ok(pdu.sender == sender + || if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + pdu.sender == sender + } else { + false + }) + } else { + Err(Error::bad_database( + "No m.room.power_levels or m.room.create events in database for room", + )) + } + } } /// Returns the join rule (`SpaceRoomJoinRule`) for a given room - pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| { - (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) - }) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? - .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) + pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec)> { + self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))) + .or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![]))) } /// Returns an empty vec if not a restricted room @@ -487,25 +499,21 @@ impl Service { room_ids } - pub fn get_room_type(&self, room_id: &RoomId) -> Result> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::(s.content.get()) - .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) + pub async fn get_room_type(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .and_then(|content: RoomCreateEventContent| { + content + .room_type + .ok_or_else(|| err!(Request(NotFound("No type found in event content")))) }) - .transpose()? 
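`get_join_rule` above moves the default onto the `Result` itself: a missing event, a missing field, or malformed JSON all fall through `or_else` to the spec default of `Invite`, so callers never match on absence. A hypothetical call site relying on that guarantee:

    // No unwrap_or at the call site; the getter already supplies the default.
    let (join_rule, allowed_rooms) = services
        .state_accessor
        .get_join_rule(room_id)
        .await?;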
- .and_then(|e| e.room_type)) } /// Gets the room's encryption algorithm if `m.room.encryption` state event /// is found - pub fn get_room_encryption(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomEncryption, "")? - .map_or(Ok(None), |s| { - serde_json::from_str::(s.content.get()) - .map(|content| Some(content.algorithm)) - .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) - }) + pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "") + .await + .map(|content: RoomEncryptionEventContent| content.algorithm) } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 19c73ea1..38e504f6 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,43 +1,42 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{Arc, RwLock}, }; -use conduit::{utils, Error, Result}; -use database::Map; -use itertools::Itertools; +use conduit::{utils, utils::stream::TryIgnore, Error, Result}; +use database::{Deserialized, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, RoomId, UserId, }; -use crate::{appservice::RegistrationInfo, globals, users, Dep}; +use crate::{globals, Dep}; -type StrippedStateEventIter<'a> = Box>)>> + 'a>; -type AnySyncStateEventIter<'a> = Box>)>> + 'a>; type AppServiceInRoomCache = RwLock>>; +type StrippedStateEventItem = (OwnedRoomId, Vec>); +type SyncStateEventItem = (OwnedRoomId, Vec>); pub(super) struct Data { pub(super) appservice_in_room_cache: AppServiceInRoomCache, - roomid_invitedcount: Arc, - roomid_inviteviaservers: Arc, - roomid_joinedcount: Arc, - roomserverids: Arc, - roomuserid_invitecount: Arc, - roomuserid_joined: Arc, - roomuserid_leftcount: Arc, - roomuseroncejoinedids: Arc, - serverroomids: Arc, - userroomid_invitestate: Arc, - userroomid_joined: Arc, - userroomid_leftstate: Arc, + pub(super) roomid_invitedcount: Arc, + pub(super) roomid_inviteviaservers: Arc, + pub(super) roomid_joinedcount: Arc, + pub(super) roomserverids: Arc, + pub(super) roomuserid_invitecount: Arc, + pub(super) roomuserid_joined: Arc, + pub(super) roomuserid_leftcount: Arc, + pub(super) roomuseroncejoinedids: Arc, + pub(super) serverroomids: Arc, + pub(super) userroomid_invitestate: Arc, + pub(super) userroomid_joined: Arc, + pub(super) userroomid_leftstate: Arc, services: Services, } struct Services { globals: Dep, - users: Dep, } impl Data { @@ -59,19 +58,18 @@ impl Data { userroomid_leftstate: db["userroomid_leftstate"].clone(), services: Services { globals: args.depend::("globals"), - users: args.depend::("users"), }, } } - pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) + self.roomuseroncejoinedids.insert(&userroom_id, &[]); } - pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = 
roomid.clone(); @@ -82,64 +80,17 @@ impl Data { userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; + self.userroomid_joined.insert(&userroom_id, &[]); + self.roomuserid_joined.insert(&roomuser_id, &[]); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); + self.userroomid_leftstate.remove(&userroom_id); + self.roomuserid_leftcount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) + self.roomid_inviteviaservers.remove(&roomid); } - pub(super) fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, - invite_via: Option>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if let Some(servers) = invite_via { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - #[allow(clippy::redundant_clone)] // this is a necessary clone? 
- prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - } - - Ok(()) - } - - pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -153,115 +104,20 @@ impl Data { self.userroomid_leftstate.insert( &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO + ); // TODO self.roomuserid_leftcount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.userroomid_joined.remove(&userroom_id); + self.roomuserid_joined.remove(&roomuser_id); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) - } - - pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(Result::ok) { - joined_servers.insert(joined.server_name().to_owned()); - joinedcount = joinedcount.saturating_add(1); - } - - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { - invitedcount = invitedcount.saturating_add(1); - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied(); - - if let Some(b) = maybe { - Ok(b) 
- } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - self.services.globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self - .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); - - Ok(in_room) - } + self.roomid_inviteviaservers.remove(&roomid); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -270,397 +126,69 @@ impl Data { roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_servers<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - let mut key = server.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_rooms<'a>( - &'a self, server: &ServerName, - ) -> Box> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator of all joined members of a room. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + Send + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - /// Returns an iterator of all our local users in the room, even if they're - /// deactivated/guests - pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { - Box::new( - self.room_members(room_id) - .filter_map(Result::ok) - .filter(|user| self.services.globals.user_is_local(user)), - ) - } - - /// Returns an iterator of all our local joined users in a room who are - /// active (not deactivated, not guest) - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn active_local_users_in_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box + 'a> { - Box::new( - self.local_users_in_room(room_id) - .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), - ) - } - - /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_useroncejoined<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members_invited<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) + self.userroomid_leftstate.remove(&userroom_id); + self.roomuserid_leftcount.remove(&roomuser_id); } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn invite_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - + #[inline] + pub(super) fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); self.userroomid_invitestate - .get(&key)? 
- .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate.")) + .unwrap(); - Ok(state) + (room_id, state) }) - .transpose() } #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn left_state( + pub(super) async fn invite_state( &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); + ) -> Result>> { + let key = (user_id, room_id); + self.userroomid_invitestate + .qry(&key) + .await + .deserialized_json() + } + #[tracing::instrument(skip(self), level = "debug")] + pub(super) async fn left_state( + &self, user_id: &UserId, room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() + .qry(&key) + .await + .deserialized_json() } /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); + #[inline] + pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.userroomid_leftstate + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate.")) + .unwrap(); - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn 
is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn servers_invite_via<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let key = room_id.as_bytes().to_vec(); - - Box::new( - self.roomid_inviteviaservers - .scan_prefix(key) - .map(|(_, servers)| { - ServerName::parse( - utils::string_from_bytes( - servers - .rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - prev_servers.extend(servers.to_owned()); - prev_servers.sort_unstable(); - prev_servers.dedup(); - - let servers = prev_servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - - Ok(()) + (room_id, state) + }) } } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 71899ceb..ce5b024b 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,9 +1,15 @@ mod data; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use conduit::{err, error, warn, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + warn, Result, +}; use data::Data; +use database::{Deserialized, Ignore, Interfix}; +use futures::{Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -18,7 +24,7 @@ use ruma::{ }, int, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; @@ -55,7 +61,7 @@ impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, last_state: Option>>, invite_via: Option>, update_joined_count: bool, @@ -68,7 +74,7 @@ impl Service { // update #[allow(clippy::collapsible_if)] if !self.services.globals.user_is_local(user_id) { - if !self.services.users.exists(user_id)? { + if !self.services.users.exists(user_id).await { self.services.users.create(user_id, None)?; } @@ -100,17 +106,17 @@ impl Service { match &membership { MembershipState::Join => { // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? 
{ + if !self.once_joined(user_id, room_id).await { // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id)?; + self.db.mark_as_once_joined(user_id, room_id); // Check if the room has a predecessor - if let Some(predecessor) = self + if let Ok(Some(predecessor)) = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.predecessor) { // Copy user settings from predecessor to the current room: // - Push rules @@ -138,32 +144,33 @@ impl Service { // .ok(); // Copy old tags to new room - if let Some(tag_event) = self + if let Ok(tag_event) = self .services .account_data - .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| { + .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| { serde_json::from_str(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { self.services .account_data - .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) + .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event) + .await .ok(); }; // Copy direct chat flag - if let Some(direct_event) = self + if let Ok(mut direct_event) = self .services .account_data - .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? - .map(|event| { + .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into()) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { - let mut direct_event = direct_event?; let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { if room_ids.iter().any(|r| r == &predecessor.room_id) { room_ids.push(room_id.to_owned()); @@ -172,18 +179,21 @@ impl Service { } if room_ids_updated { - self.services.account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event).expect("to json always works"), - )?; + self.services + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event).expect("to json always works"), + ) + .await?; } }; } } - self.db.mark_as_joined(user_id, room_id)?; + self.db.mark_as_joined(user_id, room_id); }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver @@ -196,12 +206,12 @@ impl Service { GlobalAccountDataEventType::IgnoredUserList .to_string() .into(), - )? - .map(|event| { + ) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) - .transpose()? 
.map_or(false, |ignored| { ignored .content @@ -214,194 +224,282 @@ impl Service { return Ok(()); } - self.db - .mark_as_invited(user_id, room_id, last_state, invite_via)?; + self.mark_as_invited(user_id, room_id, last_state, invite_via) + .await; }, MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; + self.db.mark_as_left(user_id, room_id); }, _ => {}, } if update_joined_count { - self.update_joined_count(room_id)?; + self.update_joined_count(room_id).await; } Ok(()) } - #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - self.db.appservice_in_room(room_id, appservice) + pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { + let maybe = self + .db + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied(); + + if let Some(b) = maybe { + b + } else { + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + self.services.globals.server_name(), + ) + .ok(); + + let in_room = if let Some(id) = &bridge_user_id { + self.is_joined(id, room_id).await + } else { + false + }; + + let in_room = in_room + || self + .room_members(room_id) + .ready_any(|userid| appservice.users.is_match(userid.as_str())) + .await; + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + in_room + } } /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_left(user_id, room_id) - } + pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); } /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_joined(user_id, room_id) - } + pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); } /// Returns an iterator of all servers participating in this room. 
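// The getters below replace hand-rolled `0xFF`-separated key building with
// typed tuple keys handed to the new query-key serializer (e.g.
// `let key = (server, room_id); ...qry(&key)`). A minimal std-only sketch of
// the keying idea; `encode_key` and the literal separator are illustrative
// assumptions, not the actual serializer from src/database/ser.rs:

fn encode_key(parts: &[&[u8]]) -> Vec<u8> {
    // join the components with the 0xFF separator the schema uses
    parts.join(&[0xFF][..])
}

fn main() {
    let server = b"example.org".as_slice();
    let room_id = b"!abc:example.org".as_slice();
    // stands in for `let key = (server, room_id); map.qry(&key)`
    let key = encode_key(&[server, room_id]);
    assert_eq!(key[11], 0xFF); // the separator sits between the two parts
}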
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_servers(room_id) + pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomserverids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, server): (Ignore, &ServerName)| server) } #[tracing::instrument(skip(self), level = "debug")] - pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - self.db.server_in_room(server, room_id) + pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool { + let key = (server, room_id); + self.db.serverroomids.qry(&key).await.is_ok() } /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self), level = "debug")] - pub fn server_rooms(&self, server: &ServerName) -> impl Iterator> + '_ { - self.db.server_rooms(server) + pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream + Send + 'a { + let prefix = (server, Interfix); + self.db + .serverroomids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns true if server can see user by sharing at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { - Ok(self - .server_rooms(server) - .filter_map(Result::ok) - .any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false))) + pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { + self.server_rooms(server) + .any(|room_id| self.is_joined(user_id, room_id)) + .await } /// Returns true if user_a and user_b share at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { + pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { // Minimize number of point-queries by iterating user with least nr rooms - let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { + let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await { (user_a, user_b) } else { (user_b, user_a) }; - Ok(self - .rooms_joined(a) - .filter_map(Result::ok) - .any(|room_id| self.is_joined(b, &room_id).unwrap_or(false))) + self.rooms_joined(a) + .any(|room_id| self.is_joined(b, room_id)) + .await } - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. 
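// `server_sees_user` and `user_sees_user` above short-circuit with
// `StreamExt::any`, whose predicate can itself be async (each membership
// check is a point query). A self-contained illustration; `is_joined` here
// is a stand-in for the service method, not the real one:

use futures::{executor::block_on, stream, StreamExt};

async fn is_joined(room_id: &str) -> bool {
    room_id == "!shared:example.org" // stand-in for the DB point query
}

async fn sees_user(rooms: Vec<&'static str>) -> bool {
    stream::iter(rooms)
        .any(|room_id| async move { is_joined(room_id).await })
        .await
}

fn main() {
    assert!(block_on(sees_user(vec![
        "!other:example.org",
        "!shared:example.org",
    ])));
}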
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> + Send + '_ { - self.db.room_members(room_id) + pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_joined + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } + pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { + self.db.roomid_joinedcount.qry(room_id).await.deserialized() + } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests - pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.local_users_in_room(room_id) + pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.room_members(room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.active_local_users_in_room(room_id) + pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.local_users_in_room(room_id) + .filter(|user| self.services.users.is_active(user)) } /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } + pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { + self.db + .roomid_invitedcount + .qry(room_id) + .await + .deserialized() + } /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_useroncejoined(room_id) + pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuseroncejoinedids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns an iterator over all invited members of a room. 
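// The iterators above become `impl Stream`s over a serialized prefix:
// `Interfix` marks the open end of the prefix tuple and `Ignore` discards
// the prefix component when the key is deserialized back out. A rough
// futures-only sketch of that decode step; the raw key layout matches the
// schema, while the function names are illustrative:

use futures::{executor::block_on, stream, StreamExt};

async fn users_in_room() -> Vec<String> {
    let raw_keys: Vec<Vec<u8>> = vec![
        b"!room:example.org\xFF@alice:example.org".to_vec(),
        b"!room:example.org\xFF@bob:example.org".to_vec(),
    ];

    // stands in for `keys_prefix(&prefix).ignore_err().map(|(_, user_id)| ..)`
    stream::iter(raw_keys)
        .filter_map(|key| async move {
            // keep only the component after the last 0xFF separator
            let suffix = key.rsplit(|&b| b == 0xFF).next()?;
            String::from_utf8(suffix.to_vec()).ok()
        })
        .collect()
        .await
}

fn main() {
    assert_eq!(block_on(users_in_room()).len(), 2);
}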
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_members_invited(room_id) + pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_invitecount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_invite_count(room_id, user_id) + pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_invitecount + .qry(&key) + .await + .deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_left_count(room_id, user_id) + pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db.roomuserid_leftcount.qry(&key).await.deserialized() } /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> + '_ { - self.db.rooms_joined(user_id) + pub fn rooms_joined(&self, user_id: &UserId) -> impl Stream + Send { + self.db + .userroomid_joined + .keys_prefix(user_id) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_invited( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_invited(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.invite_state(user_id, room_id) + pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.invite_state(user_id, room_id).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.left_state(user_id, room_id) + pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.left_state(user_id, room_id).await } /// Returns an iterator over all rooms a user left. 
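// The count getters above and the membership predicates below share the
// commit's central de-wrap: a missing key is an `Err` from the point query
// rather than an inner `Option`, so callers collapse it with `is_ok()` or
// deserialize the hit directly. A std-only sketch of the shape change,
// with a hypothetical in-memory `qry`:

use std::collections::HashMap;

fn qry<'a>(map: &'a HashMap<Vec<u8>, Vec<u8>>, key: &[u8]) -> Result<&'a [u8], &'static str> {
    // absent keys become an error, mirroring the refactor's NotFound
    map.get(key).map(Vec::as_slice).ok_or("not found")
}

fn main() {
    let mut map = HashMap::new();
    map.insert(b"@alice\xFF!room".to_vec(), Vec::new());
    // old shape: `Ok(map.get(key)?.is_some())`; new shape:
    assert!(qry(&map, b"@alice\xFF!room").is_ok());
    assert!(qry(&map, b"@bob\xFF!room").is_err());
}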
#[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_left( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_left<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_left(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.once_joined(user_id, room_id) + pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_invited(user_id, room_id) + pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_joined.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_invitestate.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.servers_invite_via(room_id) + pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_leftstate.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "debug")] + pub fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> impl Stream + Send + 'a { + self.db + .roomid_inviteviaservers + .stream_prefix(room_id) + .ignore_err() + .map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername"))) } /// Gets up to three servers that are likely to be in the room in the @@ -409,37 +507,27 @@ impl Service { /// /// See #[tracing::instrument(skip(self))] - pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { + pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { let most_powerful_user_server = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|pdu| { - serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| { - conent - .users - .iter() - .max_by_key(|(_, power)| *power) - .and_then(|x| { - if x.1 >= &int!(50) { - Some(x) - } else { - None - } - }) - .map(|(user, _power)| user.server_name().to_owned()) - }) + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await + .map(|content: RoomPowerLevelsEventContent| { + content + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| (x.1 >= &int!(50)).then_some(x)) + .map(|(user, _power)| user.server_name().to_owned()) }) - .transpose() - .map_err(|e| { - error!("Invalid power levels event content in database: {e}"); - Error::bad_database("Invalid power levels event content in database") - })? 
- .flatten(); + .map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?; let mut servers: Vec = self .room_members(room_id) - .filter_map(Result::ok) + .collect::>() + .await + .iter() .counts_by(|user| user.server_name().to_owned()) .iter() .sorted_by_key(|(_, users)| *users) @@ -468,4 +556,139 @@ impl Service { .expect("locked") .clear(); } + + pub async fn update_joined_count(&self, room_id: &RoomId) { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + self.room_members(room_id) + .ready_for_each(|joined| { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount = joinedcount.saturating_add(1); + }) + .await; + + invitedcount = invitedcount.saturating_add( + self.room_members_invited(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); + + self.db + .roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes()); + + self.db + .roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes()); + + self.room_servers(room_id) + .ready_for_each(|old_joined_server| { + if !joined_servers.remove(old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.remove(&roomserver_id); + self.db.serverroomids.remove(&serverroom_id); + } + }) + .await; + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.insert(&roomserver_id, &[]); + self.db.serverroomids.insert(&serverroom_id, &[]); + } + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); + } + + pub async fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, + ) { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), + ); + self.db + .roomuserid_invitecount + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + if let Some(servers) = invite_via { + let mut prev_servers = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + #[allow(clippy::redundant_clone)] // this is a necessary clone? 
+ prev_servers.append(servers.clone().as_mut()); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } + } + + #[tracing::instrument(skip(self), level = "debug")] + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) { + let mut prev_servers = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + prev_servers.extend(servers.to_owned()); + prev_servers.sort_unstable(); + prev_servers.dedup(); + + let servers = prev_servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 33773001..9a9f70a2 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, Result}; +use conduit::{err, expected, utils, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -22,11 +22,13 @@ impl Data { } } - pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result { + pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + .qry(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); let parent = if parent != 0 { Some(parent) @@ -40,10 +42,10 @@ impl Data { let stride = size_of::(); let mut i = stride; - while let Some(v) = value.get(i..checked!(i + 2 * stride)?) 
{ + while let Some(v) = value.get(i..expected!(i + 2 * stride)) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i = checked!(i + stride)?; + i = expected!(i + stride); continue; } if add_mode { @@ -51,7 +53,7 @@ impl Data { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i = checked!(i + 2 * stride)?; + i = expected!(i + 2 * stride); } Ok(StateDiff { @@ -61,7 +63,7 @@ impl Data { }) } - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> { + pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { value.extend_from_slice(&new[..]); @@ -75,6 +77,6 @@ impl Data { } self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) + .insert(&shortstatehash.to_be_bytes(), &value); } } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 2550774e..cd3f2f73 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -27,14 +27,12 @@ type StateInfoLruCache = Mutex< >, >; -type ShortStateInfoResult = Result< - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, ->; +type ShortStateInfoResult = Vec<( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed +)>; type ParentStatesVec = Vec<( u64, // sstatehash @@ -43,7 +41,7 @@ type ParentStatesVec = Vec<( Arc>, // removed )>; -type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; +type HashSetCompressStateEvent = (u64, Arc>, Arc>); pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Service { @@ -86,12 +84,11 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. 
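// Context for the hunk below: `load_shortstatehash_info` now awaits a
// recursive call to itself, which is why that call is wrapped in
// `Box::pin`: a plain recursive `async fn` would have an infinitely sized
// future type. A minimal self-contained illustration of the same
// constraint (not the service code):

use std::{future::Future, pin::Pin};

fn depth(n: u64) -> Pin<Box<dyn Future<Output = u64>>> {
    Box::pin(async move {
        if n == 0 {
            0
        } else {
            // the recursive future is boxed, so its size is known
            depth(n - 1).await + 1
        }
    })
}

fn main() {
    assert_eq!(futures::executor::block_on(depth(3)), 3);
}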
- #[tracing::instrument(skip(self), level = "debug")] - pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { if let Some(r) = self .stateinfo_cache .lock() - .unwrap() + .expect("locked") .get_mut(&shortstatehash) { return Ok(r.clone()); @@ -101,11 +98,11 @@ impl Service { parent, added, removed, - } = self.db.get_statediff(shortstatehash)?; + } = self.db.get_statediff(shortstatehash).await?; if let Some(parent) = parent { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = (*response.last().unwrap().1).clone(); + let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; + let mut state = (*response.last().expect("at least one response").1).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { @@ -116,7 +113,7 @@ impl Service { self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); Ok(response) @@ -124,33 +121,42 @@ impl Service { let response = vec![(shortstatehash, added.clone(), added, removed)]; self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); + Ok(response) } } - pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { + pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( &self .services .short - .get_or_create_shorteventid(event_id)? + .get_or_create_shorteventid(event_id) + .await .to_be_bytes(), ); - Ok(v.try_into().expect("we checked the size above")) + + v.try_into().expect("we checked the size above") } /// Returns shortstatekey, event id #[inline] - pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), - self.services.short.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::()..]).expect("bytes have right length"), - )?, - )) + pub async fn parse_compressed_state_event( + &self, compressed_event: &CompressedStateEvent, + ) -> Result<(u64, Arc)> { + use utils::u64_from_u8; + + let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); + let event_id = self + .services + .short + .get_eventid_from_short(u64_from_u8(&compressed_event[size_of::()..])) + .await?; + + Ok((shortstatekey, event_id)) } /// Creates a new shortstatehash that often is just a diff to an already @@ -227,7 +233,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); return Ok(()); }; @@ -280,7 +286,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); } Ok(()) @@ -288,10 +294,15 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous /// room state - pub fn save_state( + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, - ) -> HashSetCompressStateEvent { - let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; + ) -> Result { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -303,14 +314,18 @@ impl Service { let (new_shortstatehash, already_existed) = self .services .short - 
.get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if Some(new_shortstatehash) == previous_shortstatehash { return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); } - let states_parents = - previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoResult::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index fb279a00..f50b812c 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,13 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{ + checked, + result::LogErr, + utils, + utils::{stream::TryIgnore, ReadyExt}, + PduEvent, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use crate::{rooms, Dep}; -type PduEventIterResult<'a> = Result> + 'a>>; - pub(super) struct Data { threadid_userids: Arc, services: Services, @@ -30,38 +35,37 @@ impl Data { } } - pub(super) fn threads_until<'a>( + pub(super) async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { + ) -> Result + Send + 'a> { let prefix = self .services .short - .get_shortroomid(room_id)? - .expect("room exists") + .get_shortroomid(room_id) + .await? .to_be_bytes() .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) + let stream = self + .threadid_userids + .rev_raw_keys_from(¤t) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|pduid| (utils::u64_from_u8(&pduid[(size_of::())..]), pduid)) + .filter_map(move |(count, pduid)| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((count, pdu)) + }); + + Ok(stream) } pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { @@ -71,28 +75,12 @@ impl Data { .collect::>() .join(&[0xFF][..]); - self.threadid_userids.insert(root_id, &users)?; + self.threadid_userids.insert(root_id, &users); Ok(()) } - pub(super) fn get_participants(&self, root_id: &[u8]) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? 
{ - Ok(Some( - users - .split(|b| *b == 0xFF) - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(Result::ok) - .collect(), - )) - } else { - Ok(None) - } + pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result> { + self.threadid_userids.qry(root_id).await.deserialized() } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index ae51cd0f..2eafe5d5 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,12 +2,12 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, PduEvent, Result}; +use conduit::{err, PduEvent, Result}; use data::Data; +use futures::Stream; use ruma::{ - api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, - events::relation::BundledThread, - uint, CanonicalJsonValue, EventId, RoomId, UserId, + api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue, + EventId, RoomId, UserId, }; use serde_json::json; @@ -36,30 +36,35 @@ impl crate::Service for Service { } impl Service { - pub fn threads_until<'a>( + pub async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result> + 'a> { - self.db.threads_until(user_id, room_id, until, include) + ) -> Result + Send + 'a> { + self.db + .threads_until(user_id, room_id, until, include) + .await } - pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { + pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services .timeline - .get_pdu_id(root_event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; + .get_pdu_id(root_event_id) + .await + .map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?; let root_pdu = self .services .timeline - .get_pdu_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?; let mut root_pdu_json = self .services .timeline - .get_pdu_json_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_json_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json .entry("unsigned".to_owned()) @@ -103,11 +108,12 @@ impl Service { self.services .timeline - .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu) + .await?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(&root_id)? 
{ + if let Ok(userids) = self.db.get_participants(&root_id).await { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 2f0c8f25..cd746be4 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,12 +1,20 @@ use std::{ collections::{hash_map, HashMap}, mem::size_of, - sync::{Arc, Mutex}, + sync::Arc, }; -use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; -use database::{Database, Map}; -use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::{ + err, expected, + result::{LogErr, NotFound}, + utils, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + Err, PduCount, PduEvent, Result, +}; +use database::{Database, Deserialized, KeyVal, Map}; +use futures::{FutureExt, Stream, StreamExt}; +use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tokio::sync::Mutex; use crate::{rooms, Dep}; @@ -25,8 +33,7 @@ struct Services { short: Dep, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box + 'a>; +pub type PdusIterItem = (PduCount, PduEvent); type LastTimelineCountCache = Mutex>; impl Data { @@ -46,23 +53,20 @@ impl Data { } } - pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() - .expect("locked") + .await .entry(room_id.to_owned()) { hash_map::Entry::Vacant(v) => { if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? - .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .next() + .await + { Ok(*v.insert(last_count.0)) } else { Ok(PduCount::Normal(0)) @@ -73,232 +77,215 @@ impl Data { } /// Returns the `count` of this pdu's id. - pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result> { + pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.eventid_pduid - .get(event_id.as_bytes())? + .qry(event_id) + .await .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() } /// Returns the json of a pdu. - pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - ) + pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result { + if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await { + return Ok(pdu); + } + + self.eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() } /// Returns the json of a pdu. - pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? 
- .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await.deserialized_json() } /// Returns the pdu's id. #[inline] - pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.eventid_pduid.qry(event_id).await } /// Returns the pdu directly from `eventid_pduid` only. - pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await.deserialized_json() + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.qry(&pduid).await?; + + Ok(()) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(super) fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - )? - .map(Arc::new) - { - Ok(Some(pdu)) - } else { - Ok(None) + pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result> { + if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await { + return Ok(Arc::new(pdu)); } + + self.eventid_outlierpdu + .qry(event_id) + .await + .deserialized_json() + .map(Arc::new) + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + self.eventid_outlierpdu.qry(event_id).await?; + + Ok(()) + } + + /// Like get_pdu(), but without the expense of fetching and parsing the data + pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { + let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok()); + let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok()); + + //TODO: parallelize + non_outlier.await || outlier.await } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.qry(pdu_id).await.deserialized_json() } /// Returns the pdu as a `BTreeMap`. 
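`pdu_exists` above awaits the two probes one after the other and carries a `//TODO: parallelize`; note that the current `a.await || b.await` form at least short-circuits, so the outlier probe only runs on a timeline miss. One way the TODO could be discharged is `futures::join!`, sketched here with stand-in async checks:

```rust
use futures::join;

// Stand-ins for the two existence probes against the database.
async fn non_outlier_pdu_exists(event_id: &str) -> bool { event_id == "$timeline" }
async fn outlier_pdu_exists(event_id: &str) -> bool { event_id == "$outlier" }

// Polls both probes concurrently; `a.await || b.await` only starts the
// second probe after the first completed, and only if it returned false.
async fn pdu_exists(event_id: &str) -> bool {
    let (non_outlier, outlier) = join!(
        non_outlier_pdu_exists(event_id),
        outlier_pdu_exists(event_id)
    );
    non_outlier || outlier
}

#[tokio::main]
async fn main() {
    assert!(pdu_exists("$timeline").await);
    assert!(pdu_exists("$outlier").await);
    assert!(!pdu_exists("$unknown").await);
}
```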
- pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.qry(pdu_id).await.deserialized_json() } - pub(super) fn append_pdu( - &self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, - ) -> Result<()> { + pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + ); self.lasttimelinecount_cache .lock() - .expect("locked") + .await .insert(pdu.room_id.clone(), PduCount::Normal(count)); - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - Ok(()) + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu( - &self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject, - ) -> Result<()> { + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + ); - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(event_id.as_bytes())?; - - Ok(()) + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(event_id.as_bytes()); } /// Removes a pdu and creates a new one with the same id. - pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + pub(super) async fn replace_pdu( + &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result<()> { + if self.pduid_pdu.qry(pdu_id).await.is_not_found() { + return Err!(Request(NotFound("PDU does not exist."))); } + let pdu = serde_json::to_vec(pdu_json)?; + self.pduid_pdu.insert(pdu_id, &pdu); + Ok(()) } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. 
- pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; + pub(super) async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?; + let stream = self + .pduid_pdu + .rev_raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + Ok(stream) } - pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result> { - let (prefix, current) = self.count_to_id(room_id, from, 1, false)?; + pub(super) async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?; + let stream = self + .pduid_pdu + .raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); - let user_id = user_id.to_owned(); + Ok(stream) + } - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + let mut pdu = + serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + pdu.add_age().log_err().ok(); + let count = pdu_count(pdu_id); + + (count, pdu) } pub(super) fn increment_notification_counts( &self, room_id: &RoomId, notifies: Vec, highlights: Vec, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); + ) { + let _cork = self.db.cork(); + for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); + increment(&self.userroomid_notificationcount, &userroom_id); } + for user in highlights { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); + increment(&self.userroomid_highlightcount, &userroom_id); } - - self.userroomid_notificationcount - .increment_batch(notifies_batch.iter().map(Vec::as_slice))?; - self.userroomid_highlightcount - .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; - Ok(()) } - pub(super) fn count_to_id( + pub(super) async fn count_to_id( &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, ) -> Result<(Vec, Vec)> { let prefix = self .services .short - 
.get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .get_shortroomid(room_id) + .await + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))? .to_be_bytes() .to_vec(); + let mut pdu_id = prefix.clone(); // +1 so we don't send the base event let count_raw = match count { @@ -326,17 +313,23 @@ impl Data { } /// Returns the `count` of this pdu's id. -pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { - let stride = size_of::(); - let pdu_id_len = pdu_id.len(); - let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = - utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); +pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { + const STRIDE: usize = size_of::(); - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) + let pdu_id_len = pdu_id.len(); + let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]); + let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]); + + if second_last_u64 == 0 { + PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)) } else { - Ok(PduCount::Normal(last_u64)) + PduCount::Normal(last_u64) } } + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, &new); +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 04d9559d..5360d2c9 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,19 +1,20 @@ mod data; use std::{ + cmp, collections::{BTreeMap, HashSet}, fmt::Write, sync::Arc, }; use conduit::{ - debug, error, info, + debug, err, error, info, pdu::{EventHash, PduBuilder, PduCount, PduEvent}, utils, - utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, Server, + utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, + validated, warn, Err, Error, Result, Server, }; -use itertools::Itertools; +use futures::{future, future::ready, Future, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -39,6 +40,7 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; use self::data::Data; +pub use self::data::PdusIterItem; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, rooms::state_compressor::CompressedStateEvent, sending, server_keys, Dep, @@ -129,6 +131,7 @@ impl crate::Service for Service { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + /* let lasttimelinecount_cache = self .db .lasttimelinecount_cache @@ -136,6 +139,7 @@ impl crate::Service for Service { .expect("locked") .len(); writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + */ let mutex_insert = self.mutex_insert.len(); writeln!(out, "insert_mutex: {mutex_insert}")?; @@ -144,11 +148,13 @@ impl crate::Service for Service { } fn clear_cache(&self) { + /* self.db .lasttimelinecount_cache .lock() .expect("locked") .clear(); + */ } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -156,28 +162,32 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self), level = "debug")] - pub fn first_pdu_in_room(&self, 
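The now-infallible `pdu_count` decodes the trailing sixteen bytes of a pdu id: normal ids are laid out as `shortroomid (8 BE bytes) ++ count (8 BE bytes)`, while backfilled ids are `shortroomid ++ 0u64 ++ (u64::MAX - count)` so they sort before every normal event in the same room (see `backfill_pdu` further down). A standalone sketch of the decode and of why the ordering works:

```rust
use std::mem::size_of;

#[derive(Debug, PartialEq)]
enum PduCount {
    Backfilled(u64),
    Normal(u64),
}

fn u64_from_be(bytes: &[u8]) -> u64 {
    u64::from_be_bytes(bytes.try_into().expect("expected exactly 8 bytes"))
}

// Mirrors pdu_count(): the last u64 is the count; a zero in the
// second-to-last u64 slot marks a backfilled (prepended) event.
fn pdu_count(pdu_id: &[u8]) -> PduCount {
    const STRIDE: usize = size_of::<u64>();
    let len = pdu_id.len();
    let last = u64_from_be(&pdu_id[len - STRIDE..]);
    let second_last = u64_from_be(&pdu_id[len - 2 * STRIDE..len - STRIDE]);

    if second_last == 0 {
        PduCount::Backfilled(u64::MAX.saturating_sub(last))
    } else {
        PduCount::Normal(last)
    }
}

fn main() {
    let shortroomid = 42_u64;

    // Normal id: shortroomid ++ count.
    let mut normal = shortroomid.to_be_bytes().to_vec();
    normal.extend_from_slice(&7_u64.to_be_bytes());
    assert_eq!(pdu_count(&normal), PduCount::Normal(7));

    // Backfilled id: shortroomid ++ 0 ++ (MAX - count).
    let mut backfilled = shortroomid.to_be_bytes().to_vec();
    backfilled.extend_from_slice(&0_u64.to_be_bytes());
    backfilled.extend_from_slice(&(u64::MAX - 7).to_be_bytes());
    assert_eq!(pdu_count(&backfilled), PduCount::Backfilled(7));

    // Backfilled ids sort lexicographically before normal ids.
    assert!(backfilled < normal);
}
```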
room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? - .last() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.pdus_until(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + .await? + .next() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - self.db.last_timeline_count(sender_user, room_id) + pub async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + self.db.last_timeline_count(sender_user, room_id).await } /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.db.get_pdu_count(event_id) } + pub async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.db.get_pdu_count(event_id).await } // TODO Is this the same as the function above? /* @@ -203,49 +213,56 @@ impl Service { */ /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_json(event_id) + pub async fn get_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_pdu_json(event_id).await } /// Returns the json of a pdu. #[inline] - pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu_json(event_id) + pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu_json(event_id).await } /// Returns the pdu's id. #[inline] - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.db.get_pdu_id(event_id) + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.db.get_pdu_id(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. #[inline] - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu(event_id) + pub async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } + pub async fn get_pdu(&self, event_id: &EventId) -> Result> { self.db.get_pdu(event_id).await } + + /// Checks if pdu exists + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future + Send + 'a { + self.db.pdu_exists(event_id) + } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } + pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap`. 
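`latest_pdu_in_room` now takes the head of the reverse-chronological `pdus_until(.., PduCount::max())` stream rather than calling `.last()` on a forward iterator, so a single item is polled instead of the whole timeline. The direction change in miniature, over an ordinary `futures` stream:

```rust
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in timeline, oldest first.
    let timeline = vec!["$event1", "$event2", "$event3"];

    // first_pdu_in_room(): head of the forward stream.
    let first = stream::iter(timeline.clone()).next().await;

    // latest_pdu_in_room(): head of the reverse stream; one item polled,
    // where the old `.last()` walked the entire forward iterator.
    let latest = stream::iter(timeline.into_iter().rev()).next().await;

    assert_eq!(first, Some("$event1"));
    assert_eq!(latest, Some("$event3"));
}
```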
- pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.db.get_pdu_json_from_id(pdu_id) + pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self), level = "debug")] - pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { - self.db.replace_pdu(pdu_id, pdu_json, pdu) + pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu_json, pdu).await } /// Creates a new persisted data unit and adds it to a room. @@ -268,8 +285,9 @@ impl Service { let shortroomid = self .services .short - .get_shortroomid(&pdu.room_id)? - .expect("room exists"); + .get_shortroomid(&pdu.room_id) + .await + .map_err(|_| err!(Database("Room does not exist")))?; // Make unsigned fields correct. This is not properly documented in the spec, // but state events need to have previous content in the unsigned field, so @@ -279,17 +297,17 @@ impl Service { .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = self + if let Ok(shortstatehash) = self .services .state_accessor .pdu_shortstatehash(&pdu.event_id) - .unwrap() + .await { - if let Some(prev_state) = self + if let Ok(prev_state) = self .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() + .await { unsigned.insert( "prev_content".to_owned(), @@ -318,10 +336,12 @@ impl Service { // We must keep track of all events that have been referenced. self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, leaves, state_lock) + .await; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; @@ -330,17 +350,17 @@ impl Service { // appending fails self.services .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1); self.services .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count()?; + let count2 = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; drop(insert_lock); @@ -348,12 +368,9 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? 
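For state events, `append_pdu` stashes the content of the event being replaced under `unsigned.prev_content` before persisting, as the surrounding hunk shows. A minimal `serde_json` sketch of that bookkeeping; the real code operates on `CanonicalJsonValue` and fetches the previous event through `state_accessor`:

```rust
use serde_json::{json, Value};

// Records the replaced state event's content as unsigned.prev_content,
// as append_pdu does for events that carry a state_key.
fn with_prev_content(mut unsigned: Value, prev_state_content: &Value) -> Value {
    unsigned
        .as_object_mut()
        .expect("unsigned is a JSON object")
        .insert("prev_content".to_owned(), prev_state_content.clone());
    unsigned
}

fn main() {
    let prev = json!({ "displayname": "Old Name" });
    let unsigned = with_prev_content(json!({}), &prev);
    assert_eq!(unsigned["prev_content"]["displayname"], "Old Name");
}
```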
+ .room_state_get_content(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("invalid m.room.power_levels event"))) .unwrap_or_default(); let sync_pdu = pdu.to_sync_room_event(); @@ -365,7 +382,9 @@ impl Service { .services .state_cache .active_local_users_in_room(&pdu.room_id) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { @@ -386,23 +405,20 @@ impl Service { let rules_for_user = self .services .account_data - .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid push rules event in db for user ID {user}: {e}"); - Error::bad_database("Invalid push rules event in db.") - }) - }) - .transpose()? - .map_or_else(|| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); + .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_err(|e| err!(Database(warn!(?user, ?e, "Invalid push rules event in db for user")))) + .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; let mut notify = false; - for action in - self.services - .pusher - .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? + for action in self + .services + .pusher + .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id) + .await? { match action { Action::Notify => notify = true, @@ -421,31 +437,36 @@ impl Service { highlights.push(user.clone()); } - for push_key in self.services.pusher.get_pushkeys(user) { - self.services - .sending - .send_pdu_push(&pdu_id, user, push_key?)?; - } + self.services + .pusher + .get_pushkeys(user) + .ready_for_each(|push_key| { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key.to_owned()) + .expect("TODO: replace with future"); + }) + .await; } self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + .increment_notification_counts(&pdu.room_id, notifies, highlights); match pdu.kind { TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? + { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -457,13 +478,13 @@ impl Service { })?; if let Some(redact_id) = &content.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? 
+ { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -492,7 +513,7 @@ impl Service { let invite_state = match content.membership { MembershipState::Invite => { - let state = self.services.state.calculate_invite_state(pdu)?; + let state = self.services.state.calculate_invite_state(pdu).await?; Some(state) }, _ => None, @@ -500,15 +521,18 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - self.services.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content, - &pdu.sender, - invite_state, - None, - true, - )?; + self.services + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content, + &pdu.sender, + invite_state, + None, + true, + ) + .await?; } }, TimelineEventType::RoomMessage => { @@ -516,9 +540,7 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); if self.services.admin.is_admin_command(pdu, &body).await { self.services @@ -531,10 +553,10 @@ impl Service { } if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { + if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } } @@ -545,14 +567,17 @@ impl Service { } => { // We need to do it again here, because replies don't have // event_id as a top level field - if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? { + if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } }, Relation::Thread(thread) => { - self.services.threads.add_to_thread(&thread.event_id, pdu)?; + self.services + .threads + .add_to_thread(&thread.event_id, pdu) + .await?; }, _ => {}, // TODO: Aggregate other types } @@ -562,7 +587,8 @@ impl Service { if self .services .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, appservice) + .await { self.services .sending @@ -596,15 +622,14 @@ impl Service { .as_ref() .map_or(false, |state_key| users.is_match(state_key)) }; - let matching_aliases = |aliases: &NamespaceRegex| { + let matching_aliases = |aliases: NamespaceRegex| { self.services .alias .local_aliases_for_room(&pdu.room_id) - .filter_map(Result::ok) - .any(|room_alias| aliases.is_match(room_alias.as_str())) + .ready_any(move |room_alias| aliases.is_match(room_alias.as_str())) }; - if matching_aliases(&appservice.aliases) + if matching_aliases(appservice.aliases.clone()).await || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { @@ -617,7 +642,7 @@ impl Service { Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -636,47 +661,59 @@ impl Service { let prev_events: Vec<_> = self .services .state - .get_forward_extremities(room_id)? 
- .into_iter() + .get_forward_extremities(room_id) .take(20) - .collect(); + .map(Arc::from) + .collect() + .await; // If there was no create event yet, assume we are creating a room - let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self + .services + .state + .get_room_version(room_id) + .await + .or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content = serde_json::from_str::(content.get()) + .expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - let auth_events = - self.services - .state - .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + let auth_events = self + .services + .state + .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content) + .await?; // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) + .stream() + .map(Ok) + .and_then(|event_id| self.get_pdu(event_id)) + .and_then(|pdu| future::ok(pdu.depth)) + .ignore_err() + .ready_fold(uint!(0), cmp::max) + .await .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.services - .state_accessor - .room_state_get(room_id, &event_type.to_string().into(), state_key)? + if let Ok(prev_pdu) = self + .services + .state_accessor + .room_state_get(room_id, &event_type.to_string().into(), state_key) + .await { unsigned.insert( "prev_content".to_owned(), @@ -727,19 +764,22 @@ impl Service { signatures: None, }; + let auth_fetch = |k: &StateEventType, s: &str| { + let key = (k.clone(), s.to_owned()); + ready(auth_events.get(&key)) + }; + let auth_check = state_res::auth_check( &room_version, &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), + None, // TODO: third_party_invite + auth_fetch, ) - .map_err(|e| { - error!("Auth check failed: {:?}", e); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed.") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Event is not authorized.")); + return Err!(Request(Forbidden("Event is not authorized."))); } // Hash and sign @@ -795,7 +835,8 @@ impl Service { let _shorteventid = self .services .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await; Ok((pdu, pdu_json)) } @@ -811,108 +852,117 @@ impl Service { room_id: &RoomId, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = self.services.admin.get_admin_room()? 
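The depth calculation above becomes a stream fold: fetch each `prev_event`, drop lookups that fail, and keep the running maximum, plus one. The same shape with stock `futures` combinators (`ignore_err` and `ready_fold` are conduwuit extension traits; plain `filter_map` and `fold` stand in here):

```rust
use std::cmp;
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in for fetching each prev_event's pdu and reading its depth;
    // failed lookups are skipped, as the refactor's .ignore_err() does.
    let prev_event_depths: Vec<Result<u64, ()>> = vec![Ok(4), Err(()), Ok(9), Ok(7)];

    // depth = max(prev depths) + 1, folded over the stream.
    let depth = stream::iter(prev_event_depths)
        .filter_map(|res| async move { res.ok() })
        .fold(0_u64, |acc, d| async move { cmp::max(acc, d) })
        .await
        .saturating_add(1);

    assert_eq!(depth, 10);
}
```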
{ - if admin_room == room_id { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Encryption is not allowed in the admins room", - )); - }, - TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_user = &self.services.globals.server_user.to_string(); + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu"))?; + if self.services.admin.is_admin_room(&pdu.room_id).await { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + warn!("Encryption is not allowed in the admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Encryption is not allowed in the admins room", + )); + }, + TimelineEventType::RoomMember => { + let target = pdu + .state_key() + .filter(|v| v.starts_with('@')) + .unwrap_or(sender.as_str()); + let server_user = &self.services.globals.server_user.to_string(); - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Server user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot leave from admins room.", - )); - } + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu"))?; - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot leave from admins room.", - )); - } + if content.membership == MembershipState::Leave { + if target == server_user { + warn!("Server user cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot leave from admins room.", + )); } - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Server user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot be banned in admins room.", - )); - } + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot be banned in admins room.", - )); - } + if count < 2 { + warn!("Last admin cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot leave from admins room.", + )); } - }, - _ => {}, - } + } + + if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if target == server_user { + warn!("Server user cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot 
be banned in admins room.", + )); + } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + warn!("Last admin cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot be banned in admins room.", + )); + } + } + }, + _ => {}, } } // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match self.services.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id).await? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } }; }, _ => { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + .map_err(|e| err!(Database("Invalid content in redaction pdu: {e:?}")))?; if let Some(redact_id) = &content.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } } }, @@ -922,7 +972,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehashid = self.services.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -939,14 +989,15 @@ impl Service { // in time where events in the current room state do not exist self.services .state - .set_room_state(room_id, statehashid, state_lock)?; + .set_room_state(&pdu.room_id, statehashid, state_lock); let mut servers: HashSet = self .services .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .room_servers(&pdu.room_id) + .map(ToOwned::to_owned) + .collect() + .await; // In case we are kicking or banning a user, we need to inform their server of // the change @@ -966,7 +1017,8 @@ impl Service { self.services .sending - .send_pdu_servers(servers.into_iter(), &pdu_id)?; + .send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -988,15 +1040,19 @@ impl Service { // fail. self.services .state - .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock) + .await; + return Ok(None); } @@ -1009,71 +1065,71 @@ impl Service { /// Returns an iterator over all PDUs in a room. 
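The last-admin guards now count qualifying members with stream filters (keep local users, exclude the target) and reject the event when fewer than two remain. A runnable analogue with stock `futures` combinators; the server name and user IDs are stand-ins:

```rust
use futures::{future, stream, StreamExt};

fn user_is_local(user: &str) -> bool { user.ends_with(":ourserver.example") }

#[tokio::main]
async fn main() {
    let members = vec![
        "@admin:ourserver.example",
        "@conduit:ourserver.example",
        "@friend:remote.example",
    ];
    let target = "@admin:ourserver.example";

    // Local members other than the target, as in the `count < 2` guard.
    let count = stream::iter(members)
        .filter(|user| future::ready(user_is_local(user)))
        .filter(|user| future::ready(*user != target))
        .count()
        .await;

    // Only @conduit remains, so the guard would refuse this leave/ban.
    assert_eq!(count, 1);
}
```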
#[inline] - pub fn all_pdus<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result> + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()) + pub async fn all_pdus<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, + ) -> Result + Send + 'a> { + self.pdus_after(user_id, room_id, PduCount::min()).await } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a> { - self.db.pdus_until(user_id, room_id, until) + pub async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_until(user_id, room_id, until).await } /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a> { - self.db.pdus_after(user_id, room_id, from) + pub async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_after(user_id, room_id, from).await } /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { // TODO: Don't reserialize, keep original json - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + let Ok(pdu_id) = self.get_pdu_id(event_id).await else { + // If event does not exist, just noop + return Ok(()); + }; - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(body) = content.body { - self.services - .search - .deindex_pdu(shortroomid, &pdu_id, &body)?; - } + let mut pdu = self + .get_pdu_from_id(&pdu_id) + .await + .map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?; + + if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Some(body) = content.body { + self.services + .search + .deindex_pdu(shortroomid, &pdu_id, &body); } - - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; - - pdu.redact(room_version_id, reason)?; - - self.replace_pdu( - &pdu_id, - &utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?, - &pdu, - )?; } - // If event does not exist, just noop - Ok(()) + + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; + + pdu.redact(room_version_id, reason)?; + + let obj = utils::to_canonical_object(&pdu) + .map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?; + + self.replace_pdu(&pdu_id, &obj, &pdu).await } #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { let first_pdu = self - .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? 
+ .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .expect("Room is not empty")?; + .await + .expect("Room is not empty"); if first_pdu.0 < from { // No backfill required, there are still events between them @@ -1083,17 +1139,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await .map(|ev| { serde_json::from_str(ev.content.get()) .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + .unwrap() }) - .transpose()? .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { - Some(user_id.server_name().to_owned()) + Some(user_id.server_name()) } else { None } @@ -1103,34 +1160,43 @@ impl Service { .services .alias .local_aliases_for_room(room_id) - .filter_map(|alias| { - alias - .ok() - .filter(|alias| !self.services.globals.server_is_ours(alias.server_name())) - .map(|alias| alias.server_name().to_owned()) + .ready_filter_map(|alias| { + self.services + .globals + .server_is_ours(alias.server_name()) + .then_some(alias.server_name()) }); - let servers = room_mods + let mut servers = room_mods + .stream() .chain(room_alias_servers) - .chain(self.services.server.config.trusted_servers.clone()) - .filter(|server_name| { - if self.services.globals.server_is_ours(server_name) { - return false; - } - + .map(ToOwned::to_owned) + .chain( + self.services + .server + .config + .trusted_servers + .iter() + .map(ToOwned::to_owned) + .stream(), + ) + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)) + .filter_map(|server_name| async move { self.services .state_cache - .server_in_room(server_name, room_id) - .unwrap_or(false) - }); + .server_in_room(&server_name, room_id) + .await + .then_some(server_name) + }) + .boxed(); - for backfill_server in servers { + while let Some(ref backfill_server) = servers.next().await { info!("Asking {backfill_server} for backfill"); let response = self .services .sending .send_federation_request( - &backfill_server, + backfill_server, federation::backfill::get_backfill::v1::Request { room_id: room_id.to_owned(), v: vec![first_pdu.1.event_id.as_ref().to_owned()], @@ -1142,7 +1208,7 @@ impl Service { Ok(response) => { let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(&backfill_server, pdu, &pub_key_map).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } @@ -1163,7 +1229,7 @@ impl Service { &self, origin: &ServerName, pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = self @@ -1174,7 +1240,7 @@ impl Service { .await; // Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.get_pdu_id(&event_id)? 
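Backfill target selection in `backfill_if_required` becomes one chained stream: servers of room moderators, then servers seen in local aliases, then configured trusted servers, narrowed to servers actually in the room. A sketch of that pipeline with stand-in names and a stand-in async membership check; like the original, it does not deduplicate candidates:

```rust
use futures::{future, stream, StreamExt};

fn server_is_ours(server: &str) -> bool { server == "ourserver.example" }

// Stand-in for the async state_cache.server_in_room() membership check.
async fn server_in_room(server: &str) -> bool { server != "remote2.example" }

#[tokio::main]
async fn main() {
    let room_mod_servers = vec!["remote1.example"];
    let alias_servers = vec!["remote2.example"];
    let trusted_servers = vec!["ourserver.example", "remote1.example"];

    let mut servers = stream::iter(room_mod_servers)
        .chain(stream::iter(alias_servers))
        .chain(stream::iter(trusted_servers))
        .filter(|s| future::ready(!server_is_ours(s)))
        .filter_map(|s| async move { server_in_room(s).await.then_some(s) })
        .boxed();

    // remote1.example is yielded twice: the chain does not deduplicate.
    while let Some(server) = servers.next().await {
        println!("asking {server} for backfill");
    }
}
```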
{ + if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { let pdu_id = pdu_id.to_vec(); debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); @@ -1190,36 +1256,38 @@ impl Service { .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .await?; - let value = self.get_pdu_json(&event_id)?.expect("We just created it"); - let pdu = self.get_pdu(&event_id)?.expect("We just created it"); + let value = self + .get_pdu_json(&event_id) + .await + .expect("We just created it"); + let pdu = self.get_pdu(&event_id).await.expect("We just created it"); let shortroomid = self .services .short - .get_shortroomid(&room_id)? + .get_shortroomid(&room_id) + .await .expect("room exists"); let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; - let count = self.services.globals.next_count()?; + let count = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); drop(insert_lock); if pdu.kind == TimelineEventType::RoomMessage { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + .map_err(|e| err!(Database("Invalid content in pdu: {e:?}")))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); } } drop(mutex_lock); diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 3cf1cdd5..bcfce616 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -46,7 +46,7 @@ impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout); + debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); // update clients self.typing .write() @@ -54,17 +54,19 @@ impl Service { .entry(room_id.to_owned()) .or_default() .insert(user_id.to_owned(), timeout); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true).await?; } Ok(()) @@ -72,7 +74,7 @@ impl Service { /// Removes a user from typing before the timeout is reached. 
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - debug_info!("typing stopped {:?} in {:?}", user_id, room_id); + debug_info!("typing stopped {user_id:?} in {room_id:?}"); // update clients self.typing .write() @@ -80,31 +82,31 @@ impl Service { .entry(room_id.to_owned()) .or_default() .remove(user_id); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false).await?; } Ok(()) } - pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { + pub async fn wait_for_update(&self, room_id: &RoomId) { let mut receiver = self.typing_update_sender.subscribe(); while let Ok(next) = receiver.recv().await { if next == room_id { break; } } - - Ok(()) } /// Makes sure that typing events with old timestamps get removed. @@ -123,30 +125,30 @@ impl Service { removable.push(user.clone()); } } - - drop(typing); }; if !removable.is_empty() { let typing = &mut self.typing.write().await; let room = typing.entry(room_id.to_owned()).or_default(); for user in &removable { - debug_info!("typing timeout {:?} in {:?}", &user, room_id); + debug_info!("typing timeout {user:?} in {room_id:?}"); room.remove(user); } + // update clients self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation - for user in removable { - if self.services.globals.user_is_local(&user) { - self.federation_send(room_id, &user, false)?; + for user in &removable { + if self.services.globals.user_is_local(user) { + self.federation_send(room_id, user, false).await?; } } } @@ -183,7 +185,7 @@ impl Service { }) } - fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!( self.services.globals.user_is_local(user_id), "tried to broadcast typing status of remote user", @@ -197,7 +199,8 @@ impl Service { self.services .sending - .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; + .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing")) + .await?; Ok(()) } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index c7131615..d4d9874c 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::Result; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{RoomId, UserId}; use crate::{globals, rooms, Dep}; @@ -11,13 +12,13 @@ pub(super) struct Data { userroomid_highlightcount: Arc, roomuserid_lastnotificationread: Arc, roomsynctoken_shortstatehash: Arc, - userroomid_joined: Arc, services: Services, } struct Services { globals: Dep, short: Dep, + state_cache: Dep, } impl Data { @@ -28,15 +29,15 @@ impl Data { userroomid_highlightcount: db["userroomid_highlightcount"].clone(), 
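`wait_for_update` is now infallible: subscribe to the typing broadcast channel and discard updates until one for the requested room arrives. The pattern as a runnable `tokio::sync::broadcast` sketch:

```rust
use std::time::Duration;
use tokio::sync::broadcast;

// Mirrors wait_for_update(): subscribe, then discard updates for other
// rooms until one for `room_id` is received.
async fn wait_for_update(sender: &broadcast::Sender<String>, room_id: &str) {
    let mut receiver = sender.subscribe();
    while let Ok(next) = receiver.recv().await {
        if next == room_id {
            break;
        }
    }
}

#[tokio::main]
async fn main() {
    let (sender, _keepalive) = broadcast::channel::<String>(16);

    // Publish a non-matching and then a matching update shortly after the
    // waiter below has subscribed.
    let publisher = tokio::spawn({
        let sender = sender.clone();
        async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            sender.send("!other:example.org".to_owned()).unwrap();
            sender.send("!room:example.org".to_owned()).unwrap();
        }
    });

    wait_for_update(&sender, "!room:example.org").await;
    publisher.await.unwrap();
}
```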
roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), services: Services { globals: args.depend::("globals"), short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), }, } } - pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -45,128 +46,73 @@ impl Data { roomuser_id.extend_from_slice(user_id.as_bytes()); self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.roomuserid_lastnotificationread - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) + pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastnotificationread + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn associate_token_shortstatehash( - &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { + pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { let shortroomid = self .services .short - .get_shortroomid(room_id)? 
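The counter getters now hand `(user_id, room_id)` tuples straight to `qry`, leaving the new query-key serializer to produce what the removed code assembled by hand: components joined with a `0xFF` separator. The legacy construction as a standalone sketch; that the serializer emits byte-identical keys is an assumption inferred from the removed code, not something this hunk itself shows:

```rust
// What the removed code assembled for every counter operation:
// key = user_id ++ 0xFF ++ room_id.
fn userroom_key(user_id: &str, room_id: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(room_id.as_bytes());
    key
}

fn main() {
    let key = userroom_key("@alice:example.org", "!room:example.org");
    assert!(key.starts_with(b"@alice:example.org"));
    assert!(key.ends_with(b"!room:example.org"));
    assert_eq!(key.iter().filter(|&&b| b == 0xFF).count(), 1);
}
```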
+ .get_shortroomid(room_id) + .await .expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) + .insert(&key, &shortstatehash.to_be_bytes()); } - pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self - .services - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + let shortroomid = self.services.short.get_shortroomid(room_id).await?; + let key: &[u64] = &[shortroomid, token]; self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) - }) - .transpose() + .qry(key) + .await + .deserialized() } + //TODO: optimize; replace point-queries with dual iteration pub(super) fn get_shared_rooms<'a>( - &'a self, users: Vec, - ) -> Result> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xFF) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - .saturating_add(1); // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(Result::ok) - }); - - // We use the default compare function because keys are sorted correctly (not - // reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.services + .state_cache + .rooms_joined(user_a) + .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) } } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 93d38470..d9d90ecf 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -3,7 +3,8 @@ mod data; use std::sync::Arc; use conduit::Result; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use futures::{pin_mut, Stream, StreamExt}; +use ruma::{RoomId, UserId}; use self::data::Data; @@ -22,32 +23,49 @@ impl crate::Service for Service { } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) + #[inline] + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + self.db.reset_notification_counts(user_id, room_id); } - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.notification_count(user_id, room_id) + #[inline] + pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.notification_count(user_id, room_id).await } - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.highlight_count(user_id, room_id) 
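`get_shared_rooms` drops the `userroomid_joined` prefix scans and the `common_elements()` merge in favor of dual iteration: stream one user's joined rooms and keep those the other user is also in, and `has_shared_rooms` then just probes that stream for a first element. A self-contained analogue:

```rust
use std::collections::{HashMap, HashSet};
use futures::{pin_mut, stream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in membership index: user -> rooms joined.
    let joined: HashMap<&str, HashSet<String>> = HashMap::from([
        ("@alice:example.org", HashSet::from(["!a:x".to_owned(), "!b:x".to_owned()])),
        ("@bob:example.org", HashSet::from(["!b:x".to_owned(), "!c:x".to_owned()])),
    ]);

    // get_shared_rooms(): stream one user's rooms and keep those the other
    // user is also joined to; dual iteration instead of common_elements().
    let shared = stream::iter(&joined["@alice:example.org"]).filter(|room| {
        let both = joined["@bob:example.org"].contains(*room);
        async move { both }
    });

    // has_shared_rooms(): pin the stream and probe it for a first element.
    pin_mut!(shared);
    assert_eq!(shared.next().await.map(String::as_str), Some("!b:x"));
}
```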
+ #[inline] + pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.highlight_count(user_id, room_id).await } - pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_notification_read(user_id, room_id) + #[inline] + pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_notification_read(user_id, room_id).await } - pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + #[inline] + pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { self.db .associate_token_shortstatehash(room_id, token, shortstatehash) + .await; } - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - self.db.get_token_shortstatehash(room_id, token) + #[inline] + pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + self.db.get_token_shortstatehash(room_id, token).await } - pub fn get_shared_rooms(&self, users: Vec) -> Result> + '_> { - self.db.get_shared_rooms(users) + #[inline] + pub fn get_shared_rooms<'a>( + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.db.get_shared_rooms(user_a, user_b) + } + + pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() } } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 6c8e2544..b96f9a03 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,14 +1,21 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Database, Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; use crate::{globals, Dep}; -type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; -type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; +pub(super) type OutgoingItem = (Key, SendingEvent, Destination); +pub(super) type SendingItem = (Key, SendingEvent); +pub(super) type QueueItem = (Key, SendingEvent); +pub(super) type Key = Vec; pub struct Data { servercurrentevent_data: Arc, @@ -36,85 +43,34 @@ impl Data { } } - #[inline] - pub fn active_requests(&self) -> OutgoingSendingIter<'_> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) - } + pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); } - #[inline] - pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { + pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) { let prefix = destination.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; } - pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) } - - pub(super) fn delete_all_active_requests_for(&self, 
destination: &Destination) -> Result<()> { + pub(super) async fn delete_all_requests_for(&self, destination: &Destination) { let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; - } + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; - Ok(()) - } - - pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } - - Ok(()) - } - - pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { - let mut batch = Vec::new(); - let mut keys = Vec::new(); - for (destination, event) in requests { - let mut key = destination.get_prefix(); - if let SendingEvent::Pdu(value) = &event { - key.extend_from_slice(value); - } else { - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - } - let value = if let SendingEvent::Edu(value) = &event { - &**value - } else { - &[] - }; - batch.push((key.clone(), value.to_owned())); - keys.push(key); - } self.servernameevent_data - .insert_batch(batch.iter().map(database::KeyVal::from))?; - Ok(keys) + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servernameevent_data.remove(key)) + .await; } - pub fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box)>> + 'a> { - let prefix = destination.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); - } - - pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { - for (e, key) in events { + pub(super) fn mark_as_active(&self, events: &[QueueItem]) { + for (key, e) in events { if key.is_empty() { continue; } @@ -124,29 +80,87 @@ impl Data { } else { &[] }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; + self.servercurrentevent_data.insert(key, value); + self.servernameevent_data.remove(key); + } + } + + #[inline] + pub fn active_requests(&self) -> impl Stream + Send + '_ { + self.servercurrentevent_data + .raw_stream() + .ignore_err() + .map(|(key, val)| { + let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); + + (key.to_vec(), event, dest) + }) + } + + #[inline] + pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { + let prefix = destination.get_prefix(); + self.servercurrentevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); + + (key.to_vec(), event) + }) + } + + pub(super) fn queue_requests(&self, requests: &[(&SendingEvent, &Destination)]) -> Vec> { + let mut batch = Vec::new(); + let mut keys = Vec::new(); + for (event, destination) in requests { + let mut key = destination.get_prefix(); + if let SendingEvent::Pdu(value) = &event { + key.extend_from_slice(value); + } else { + key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); + } + let value = if let SendingEvent::Edu(value) = &event { + 
&**value + } else { + &[] + }; + batch.push((key.clone(), value.to_owned())); + keys.push(key); } - Ok(()) + self.servernameevent_data.insert_batch(batch.iter()); + keys } - pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { - self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) - } + pub fn queued_requests<'a>(&'a self, destination: &Destination) -> impl Stream + Send + 'a { + let prefix = destination.get_prefix(); + self.servernameevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - pub fn get_latest_educount(&self, server_name: &ServerName) -> Result { - self.servername_educount - .get(server_name.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + (key.to_vec(), event) }) } + + pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) { + self.servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()); + } + + pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 { + self.servername_educount + .qry(server_name) + .await + .deserialized() + .unwrap_or(0) + } } #[tracing::instrument(skip(key), level = "debug")] -fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { +fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { let mut parts = key[1..].splitn(2, |&b| b == 0xFF); @@ -164,7 +178,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else if key.starts_with(b"$") { @@ -192,7 +206,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent::Pdu(event.to_vec()) } else { // I'm pretty sure this should never be called - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else { @@ -214,7 +228,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) }) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index b90ea361..e3582f2e 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,10 +7,11 @@ mod sender; use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, warn, Result, Server}; +use conduit::{err, utils::ReadyExt, warn, Result, Server}; +use futures::{future::ready, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, - OwnedServerName, RoomId, ServerName, UserId, + RoomId, ServerName, UserId, }; use tokio::sync::Mutex; @@ -104,7 +105,7 @@ impl Service { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -117,7 +118,7 @@ impl Service { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = self.db.db.cork(); - let keys 
= self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -126,30 +127,31 @@ impl Service { } #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] - pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_pdu_servers(servers, pdu_id) + self.send_pdu_servers(servers, pdu_id).await } #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] - pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) - .collect::>(); + pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into()))) + .collect::>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); + for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { dest, @@ -166,7 +168,7 @@ impl Service { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -175,30 +177,30 @@ impl Service { } #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] - pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_edu_servers(servers, serialized) + self.send_edu_servers(servers, serialized).await } #[tracing::instrument(skip(self, servers, serialized), level = "debug")] - pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) - .collect::>(); + pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))) + .collect::>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { @@ -212,29 +214,33 @@ impl Service { } 
#[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + pub async fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.flush_servers(servers) + self.flush_servers(servers).await } #[tracing::instrument(skip(self, servers), level = "debug")] - pub fn flush_servers>(&self, servers: I) -> Result<()> { - let requests = servers.into_iter().map(Destination::Normal); - for dest in requests { - self.dispatch(Msg { - dest, - event: SendingEvent::Flush, - queue_id: Vec::::new(), - })?; - } - - Ok(()) + pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()> + where + S: Stream + Send + 'a, + { + servers + .map(ToOwned::to_owned) + .map(Destination::Normal) + .map(Ok) + .try_for_each(|dest| { + ready(self.dispatch(Msg { + dest, + event: SendingEvent::Flush, + queue_id: Vec::::new(), + })) + }) + .await } #[tracing::instrument(skip_all, name = "request")] @@ -263,11 +269,10 @@ impl Service { /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self), level = "debug")] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + pub async fn cleanup_events(&self, appservice_id: String) { self.db - .delete_all_requests_for(&Destination::Appservice(appservice_id))?; - - Ok(()) + .delete_all_requests_for(&Destination::Appservice(appservice_id)) + .await; } fn dispatch(&self, msg: Msg) -> Result<()> { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 206bf92b..4db9922a 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,18 +7,15 @@ use std::{ use base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, error, trace, - utils::{calculate_hash, math::continue_exponential_backoff_secs}, + debug, debug_warn, err, trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, warn, Error, Result, }; -use federation::transactions::send_transaction_message; -use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ - api::federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, - }, + api::federation::transactions::{ + edu::{DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap}, + send_transaction_message, }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, @@ -28,7 +25,7 @@ use ruma::{ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service}; #[derive(Debug)] enum TransactionStatus { @@ -50,20 +47,20 @@ const CLEANUP_TIMEOUT_MS: u64 = 3500; impl Service { #[tracing::instrument(skip_all, name = "sender")] pub(super) async fn sender(&self) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut 
statuses: CurTransactionStatus = CurTransactionStatus::new(); + let mut futures: SendingFutures<'_> = FuturesUnordered::new(); + let receiver = self.receiver.lock().await; - self.initial_requests(&futures, &mut statuses); + self.initial_requests(&mut futures, &mut statuses).await; loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { request = receiver.recv_async() => match request { - Ok(request) => self.handle_request(request, &futures, &mut statuses), + Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &futures, &mut statuses); + self.handle_response(response, &mut futures, &mut statuses).await; }, } } @@ -72,18 +69,16 @@ impl Service { Ok(()) } - fn handle_response<'a>( - &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + async fn handle_response<'a>( + &'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { - Ok(dest) => self.handle_response_ok(&dest, futures, statuses), - Err((dest, e)) => Self::handle_response_err(dest, futures, statuses, &e), + Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await, + Err((dest, e)) => Self::handle_response_err(dest, statuses, &e), }; } - fn handle_response_err( - dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, - ) { + fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { *e = match e { @@ -94,39 +89,40 @@ impl Service { }); } - fn handle_response_ok<'a>( - &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); - self.db - .delete_all_active_requests_for(dest) - .expect("all active requests deleted"); + self.db.delete_all_active_requests_for(dest).await; // Find events that have been added since starting the last request let new_events = self .db .queued_requests(dest) - .filter_map(Result::ok) .take(DEQUEUE_LIMIT) - .collect::>(); + .collect::>() + .await; // Insert any pdus we found if !new_events.is_empty() { - self.db - .mark_as_active(&new_events) - .expect("marked as active"); - let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); + self.db.mark_as_active(&new_events); + + let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect(); + futures.push(self.send_events(dest.clone(), new_events_vec).boxed()); } else { statuses.remove(dest); } } - fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { - let iv = vec![(msg.event, msg.queue_id)]; - if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_request<'a>( + &'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, + ) { + let iv = vec![(msg.queue_id, msg.event)]; + if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if !events.is_empty() { - futures.push(Box::pin(self.send_events(msg.dest, 
events))); + futures.push(self.send_events(msg.dest, events).boxed()); } else { statuses.remove(&msg.dest); } @@ -142,7 +138,7 @@ impl Service { tokio::select! { () = sleep_until(deadline.into()) => break, response = futures.next() => match response { - Some(response) => self.handle_response(response, futures, statuses), + Some(response) => self.handle_response(response, futures, statuses).await, None => return, } } @@ -151,16 +147,17 @@ impl Service { debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::>::new(); - for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { + let mut active = self.db.active_requests().boxed(); + + while let Some((key, event, dest)) = active.next().await { let entry = txns.entry(dest.clone()).or_default(); if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { - warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); - self.db - .delete_active_request(&key) - .expect("active request deleted"); + warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key)); + self.db.delete_active_request(&key); } else { entry.push(event); } @@ -169,16 +166,16 @@ impl Service { for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(self.send_events(dest.clone(), events))); + futures.push(self.send_events(dest.clone(), events).boxed()); } } } #[tracing::instrument(skip_all, level = "debug")] - fn select_events( + async fn select_events( &self, dest: &Destination, - new_events: Vec<(SendingEvent, Vec)>, // Events we want to send: event and full key + new_events: Vec, // Events we want to send: event and full key statuses: &mut CurTransactionStatus, ) -> Result>> { let (allow, retry) = self.select_events_current(dest.clone(), statuses)?; @@ -195,8 +192,8 @@ impl Service { if retry { self.db .active_requests_for(dest) - .filter_map(Result::ok) - .for_each(|(_, e)| events.push(e)); + .ready_for_each(|(_, e)| events.push(e)) + .await; return Ok(Some(events)); } @@ -204,17 +201,17 @@ impl Service { // Compose the next transaction let _cork = self.db.db.cork(); if !new_events.is_empty() { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { + self.db.mark_as_active(&new_events); + for (_, e) in new_events { events.push(e); } } // Add EDU's into the transaction if let Destination::Normal(server_name) = dest { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { events.extend(select_edus.into_iter().map(SendingEvent::Edu)); - self.db.set_latest_educount(server_name, last_count)?; + self.db.set_latest_educount(server_name, last_count); } } @@ -248,26 +245,32 @@ impl Service { } #[tracing::instrument(skip_all, level = "debug")] - fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu - let since = self.db.get_latest_educount(server_name)?; + let since = 
self.db.get_latest_educount(server_name).await; let mut events = Vec::new(); let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - for room_id in self.services.state_cache.server_rooms(server_name) { - let room_id = room_id?; + let server_rooms = self.services.state_cache.server_rooms(server_name); + + pin_mut!(server_rooms); + while let Some(room_id) = server_rooms.next().await { // Look for device list updates in this room device_list_changes.extend( self.services .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok) - .filter(|user_id| self.services.globals.user_is_local(user_id)), + .keys_changed(room_id.as_str(), since, None) + .ready_filter(|user_id| self.services.globals.user_is_local(user_id)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); if self.server.config.allow_outgoing_read_receipts - && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? + && !self + .select_edus_receipts(room_id, since, &mut max_edu_count, &mut events) + .await? { break; } @@ -290,19 +293,22 @@ impl Service { } if self.server.config.allow_outgoing_presence { - self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) + .await?; } Ok((events, max_edu_count)) } /// Look for presence - fn select_edus_presence( + async fn select_edus_presence( &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - // Look for presence updates for this server + let presence_since = self.services.presence.presence_since(since); + + pin_mut!(presence_since); let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { + while let Some((user_id, count, presence_bytes)) = presence_since.next().await { *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { @@ -312,7 +318,8 @@ impl Service { if !self .services .state_cache - .server_sees_user(server_name, &user_id)? 
+ .server_sees_user(server_name, &user_id) + .await { continue; } @@ -320,7 +327,9 @@ impl Service { let presence_event = self .services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + presence_updates.push(PresenceUpdate { user_id, presence: presence_event.content.presence, @@ -346,32 +355,33 @@ impl Service { } /// Look for read receipts in this room - fn select_edus_receipts( + async fn select_edus_receipts( &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - for r in self + let receipts = self .services .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); + .readreceipts_since(room_id, since); + pin_mut!(receipts); + while let Some((user_id, count, read_receipt)) = receipts.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { continue; } let event = serde_json::from_str(read_receipt.json().get()) .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { let mut read = BTreeMap::new(); - let (event_id, mut receipt) = r .content .0 .into_iter() .next() .expect("we only use one event per read receipt"); + let receipt = receipt .remove(&ReceiptType::Read) .expect("our read receipts always set this") @@ -427,24 +437,17 @@ impl Service { async fn send_events_dest_appservice( &self, dest: &Destination, id: &str, events: Vec, ) -> SendingResult { - let mut pdu_jsons = Vec::new(); + let Some(appservice) = self.services.appservice.get_registration(id).await else { + return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration"))))); + }; + let mut pdu_jsons = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdu_jsons.push(pdu.to_room_event()); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Appservices don't need EDUs (?) 
and flush only; @@ -453,32 +456,24 @@ impl Service { } } + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::>(), + )); + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); let client = &self.services.client.appservice; match appservice::send_request( client, - self.services - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, + appservice, ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::>(), - ))) - .into(), + txn_id: txn_id.into(), ephemeral: Vec::new(), to_device: Vec::new(), }, @@ -494,23 +489,17 @@ impl Service { async fn send_events_dest_push( &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec, ) -> SendingResult { - let mut pdus = Vec::new(); + let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else { + return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher"))))); + }; + let mut pdus = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdus.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdus.push(pdu); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Push gateways don't need EDUs (?) and flush only; @@ -529,28 +518,22 @@ impl Service { } } - let Some(pusher) = self - .services - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { - continue; - }; - let rules_for_user = self .services .account_data .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_or_else( + |_| push::Ruleset::server_default(userid), + |ev: PushRulesEvent| ev.content.global, + ); let unread: UInt = self .services .user .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? + .await .try_into() .expect("notification count can't go that high"); @@ -559,7 +542,6 @@ impl Service { .pusher .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) .await - .map(|_response| dest.clone()) .map_err(|e| (dest.clone(), e)); } @@ -586,21 +568,11 @@ impl Service { for event in &events { match event { // TODO: check room version and remove event_id if needed - SendingEvent::Pdu(pdu_id) => pdu_jsons.push( - self.convert_to_outgoing_federation_event( - self.services - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? 
- .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - ), - ), + SendingEvent::Pdu(pdu_id) => { + if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await { + pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await); + } + }, SendingEvent::Edu(edu) => { if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(raw); @@ -647,7 +619,7 @@ impl Service { } /// This does not return a full `Pdu` it is only to satisfy ruma's types. - pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { + pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { if let Some(unsigned) = pdu_json .get_mut("unsigned") .and_then(|val| val.as_object_mut()) @@ -660,7 +632,7 @@ impl Service { .get("room_id") .and_then(|val| RoomId::parse(val.as_str()?).ok()) { - match self.services.state.get_room_version(&room_id) { + match self.services.state.get_room_version(&room_id).await { Ok(room_version_id) => match room_version_id { RoomVersionId::V1 | RoomVersionId::V2 => {}, _ => _ = pdu_json.remove("event_id"), diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index a565e500..ae2b8c3c 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ discovery::{ @@ -179,7 +179,8 @@ impl Service { let result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -236,7 +237,8 @@ impl Service { .services .globals .db - .add_signing_key(&k.server_name, k.clone())? + .add_signing_key(&k.server_name, k.clone()) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>(); @@ -283,7 +285,8 @@ impl Service { .services .globals .db - .add_signing_key(&origin, key)? + .add_signing_key(&origin, key) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -384,7 +387,8 @@ impl Service { let mut result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? 
.into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -431,7 +435,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() @@ -462,7 +467,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -495,7 +501,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -545,7 +552,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() diff --git a/src/service/services.rs b/src/service/services.rs index 3aa095b8..da22fb2d 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -14,7 +14,7 @@ use crate::{ manager::Manager, media, presence, pusher, resolver, rooms, sending, server_keys, service, service::{Args, Map, Service}, - transaction_ids, uiaa, updates, users, + sync, transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -32,6 +32,7 @@ pub struct Services { pub rooms: rooms::Service, pub sending: Arc, pub server_keys: Arc, + pub sync: Arc, pub transaction_ids: Arc, pub uiaa: Arc, pub updates: Arc, @@ -96,6 +97,7 @@ impl Services { }, sending: build!(sending::Service), server_keys: build!(server_keys::Service), + sync: build!(sync::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), updates: build!(updates::Service), diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs new file mode 100644 index 00000000..1bf4610f --- /dev/null +++ b/src/service/sync/mod.rs @@ -0,0 +1,233 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::{Arc, Mutex, Mutex as StdMutex}, +}; + +use conduit::Result; +use ruma::{ + api::client::sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + OwnedDeviceId, OwnedRoomId, OwnedUserId, +}; + +pub struct Service { + connections: DbConnections, +} + +struct SlidingSyncCache { + lists: BTreeMap, + subscriptions: BTreeMap, + known_rooms: BTreeMap>, // For every room, the roomsince number + extensions: ExtensionsConfig, +} + +type DbConnections = Mutex>; +type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); +type DbConnectionsVal = Arc>; + +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + connections: StdMutex::new(BTreeMap::new()), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { + self.connections + .lock() + .unwrap() + .contains_key(&(user_id, device_id, conn_id)) + } + + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { + self.connections + .lock() + .expect("locked") + .remove(&(user_id, device_id, conn_id)); + } + + pub fn update_sync_request_with_cache( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, + ) -> BTreeMap> { + let Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; + + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + 
.or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort.clone_from(&cached_list.sort); + }; + if list.room_details.required_state.is_empty() { + list.room_details + .required_state + .clone_from(&cached_list.room_details.required_state); + }; + list.room_details.timeline_limit = list + .room_details + .timeline_limit + .or(cached_list.room_details.timeline_limit); + list.include_old_rooms = list + .include_old_rooms + .clone() + .or_else(|| cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = list_filters + .room_name_like + .clone() + .or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = cached_filters.not_tags; + } + }, + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (Some(list_filters), _) => list.filters = Some(list_filters.clone()), + (..) 
=> {}, + } + if list.bump_event_types.is_empty() { + list.bump_event_types + .clone_from(&cached_list.bump_event_types); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } + + cached + .subscriptions + .extend(request.room_subscriptions.clone()); + request + .room_subscriptions + .extend(cached.subscriptions.clone()); + + request.extensions.e2ee.enabled = request + .extensions + .e2ee + .enabled + .or(cached.extensions.e2ee.enabled); + + request.extensions.to_device.enabled = request + .extensions + .to_device + .enabled + .or(cached.extensions.to_device.enabled); + + request.extensions.account_data.enabled = request + .extensions + .account_data + .enabled + .or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = request + .extensions + .account_data + .lists + .clone() + .or_else(|| cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = request + .extensions + .account_data + .rooms + .clone() + .or_else(|| cached.extensions.account_data.rooms.clone()); + + cached.extensions = request.extensions.clone(); + + cached.known_rooms.clone() + } + + pub fn update_sync_subscriptions( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, + subscriptions: BTreeMap, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + cached.subscriptions = subscriptions; + } + + pub fn update_sync_known_rooms( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, + new_cached_rooms: BTreeSet, globalsince: u64, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (roomid, lastsince) in cached + .known_rooms + .entry(list_id.clone()) + .or_default() + .iter_mut() + { + if !new_cached_rooms.contains(roomid) { + *lastsince = 0; + } + } + let list = cached.known_rooms.entry(list_id).or_default(); + for roomid in new_cached_rooms { + list.insert(roomid, globalsince); + } + } +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs deleted file mode 100644 index 791b46f0..00000000 --- a/src/service/transaction_ids/data.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, TransactionId, UserId}; - -pub struct Data { - userdevicetxnid_response: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicetxnid_response: db["userdevicetxnid_response"].clone(), - } - } - - pub(super) fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - 
self.userdevicetxnid_response.insert(&key, data)?; - - Ok(()) - } - - pub(super) fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } -} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 78e6337f..72f60adb 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,35 +1,45 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Handle, Map}; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { - pub db: Data, + db: Data, +} + +struct Data { + userdevicetxnid_response: Arc, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - self.db.add_txnid(user_id, device_id, txn_id, data) - } +#[implement(Service)] +pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - pub fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - self.db.existing_txnid(user_id, device_id, txn_id) - } + self.db.userdevicetxnid_response.insert(&key, data); +} + +// If there's no entry, this is a new transaction +#[implement(Service)] +pub async fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, +) -> Result> { + let key = (user_id, device_id, txn_id); + self.db.userdevicetxnid_response.qry(&key).await } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs deleted file mode 100644 index ce071da0..00000000 --- a/src/service/uiaa/data.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, -}; - -pub struct Data { - userdevicesessionid_uiaarequest: RwLock>, - userdevicesessionid_uiaainfo: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - userdevicesessionid_uiaainfo: db["userdevicesessionid_uiaainfo"].clone(), - } - } - - pub(super) fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - pub(super) fn get_uiaa_request( - &self, user_id: &UserId, 
device_id: &DeviceId, session: &str, - ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(ToOwned::to_owned) - } - - pub(super) fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 6041bbd3..7e231514 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,174 +1,243 @@ -mod data; +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; -use std::sync::Arc; - -use conduit::{error, utils, utils::hash, Error, Result, Server}; -use data::Data; +use conduit::{ + err, error, implement, utils, + utils::{hash, string::EMPTY}, + Error, Result, Server, +}; +use database::{Deserialized, Map}; use ruma::{ api::client::{ error::ErrorKind, uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, }, - CanonicalJsonValue, DeviceId, UserId, + CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, }; use crate::{globals, users, Dep}; -pub const SESSION_ID_LENGTH: usize = 32; - pub struct Service { - server: Arc, + userdevicesessionid_uiaarequest: RwLock, + db: Data, services: Services, - pub db: Data, } struct Services { + server: Arc, globals: Dep, users: Dep, } +struct Data { + userdevicesessionid_uiaainfo: Arc, +} + +type RequestMap = BTreeMap; +type RequestKey = (OwnedUserId, OwnedDeviceId, String); + +pub const SESSION_ID_LENGTH: usize = 32; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - server: args.server.clone(), + userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + db: Data { + userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), + }, services: Services { + server: args.server.clone(), globals: args.depend::("globals"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Creates a new Uiaa session. Make sure the session token is unique. 
- pub fn create( - &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.db.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why - * is it optional in ruma?) */ - json_body, - )?; - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) +/// Creates a new Uiaa session. Make sure the session token is unique. +#[implement(Service)] +pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) { + // TODO: better session error handling (why is uiaainfo.session optional in + // ruma?) + self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + json_body, + ); + + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ); +} + +#[implement(Service)] +pub async fn try_auth( + &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, +) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = if let Some(session) = auth.session() { + self.get_uiaa_session(user_id, device_id, session).await? + } else { + uiaainfo.clone() + }; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); } - pub fn try_auth( - &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth.session().map_or_else( - || Ok(uiaainfo.clone()), - |session| self.db.get_uiaa_session(user_id, device_id, session), - )?; + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + #[cfg(feature = "element_hacks")] + user, + .. + }) => { + #[cfg(feature = "element_hacks")] + let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { + username + } else if let Some(username) = user { + username + } else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } + #[cfg(not(feature = "element_hacks"))] + let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier + else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; - match auth { - // Find out what the user completed - AuthData::Password(Password { - identifier, - password, - #[cfg(feature = "element_hacks")] - user, - .. 
- }) => { - #[cfg(feature = "element_hacks")] - let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { - username - } else if let Some(username) = user { - username - } else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - #[cfg(not(feature = "element_hacks"))] - let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier - else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - - // Check if password is correct - if let Some(hash) = self.services.users.password_hash(&user_id)? { - let hash_matches = hash::verify_password(password, &hash).is_ok(); - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::forbidden(), - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - }, - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { + // Check if password is correct + if let Ok(hash) = self.services.users.password_hash(&user_id).await { + let hash_matches = hash::verify_password(password, &hash).is_ok(); + if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { kind: ErrorKind::forbidden(), - message: "Invalid registration token.".to_owned(), + message: "Invalid username or password.".to_owned(), }); return Ok((false, uiaainfo)); } - }, - AuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - }, - k => error!("type not supported: {:?}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } } - // We didn't break, so this flow succeeded! - completed = true; - } - if !completed { - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); - } + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + }, + AuthData::RegistrationToken(t) => { + if Some(t.token.trim()) == self.services.server.config.registration_token.as_deref() { + uiaainfo.completed.push(AuthType::RegistrationToken); + } else { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + }, + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + }, + k => error!("type not supported: {:?}", k), + } - // UIAA was successful! 
Remove this session and return true
-		self.db.update_uiaa_session(
+	// Check if a flow now succeeds
+	let mut completed = false;
+	'flows: for flow in &mut uiaainfo.flows {
+		for stage in &flow.stages {
+			if !uiaainfo.completed.contains(stage) {
+				continue 'flows;
+			}
+		}
+		// We didn't break, so this flow succeeded!
+		completed = true;
+	}
+
+	if !completed {
+		self.update_uiaa_session(
 			user_id,
 			device_id,
 			uiaainfo.session.as_ref().expect("session is always set"),
-			None,
-		)?;
-		Ok((true, uiaainfo))
+			Some(&uiaainfo),
+		);
+
+		return Ok((false, uiaainfo));
 	}
 
-	#[must_use]
-	pub fn get_uiaa_request(
-		&self, user_id: &UserId, device_id: &DeviceId, session: &str,
-	) -> Option<CanonicalJsonValue> {
-		self.db.get_uiaa_request(user_id, device_id, session)
+	// UIAA was successful! Remove this session and return true
+	self.update_uiaa_session(
+		user_id,
+		device_id,
+		uiaainfo.session.as_ref().expect("session is always set"),
+		None,
+	);
+
+	Ok((true, uiaainfo))
+}
+
+#[implement(Service)]
+fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) {
+	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
+	self.userdevicesessionid_uiaarequest
+		.write()
+		.expect("locked for writing")
+		.insert(key, request.to_owned());
+}
+
+#[implement(Service)]
+pub fn get_uiaa_request(
+	&self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str,
+) -> Option<CanonicalJsonValue> {
+	let key = (
+		user_id.to_owned(),
+		device_id.unwrap_or_else(|| EMPTY.into()).to_owned(),
+		session.to_owned(),
+	);
+
+	self.userdevicesessionid_uiaarequest
+		.read()
+		.expect("locked for reading")
+		.get(&key)
+		.cloned()
+}
+
+#[implement(Service)]
+fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) {
+	let mut userdevicesessionid = user_id.as_bytes().to_vec();
+	userdevicesessionid.push(0xFF);
+	userdevicesessionid.extend_from_slice(device_id.as_bytes());
+	userdevicesessionid.push(0xFF);
+	userdevicesessionid.extend_from_slice(session.as_bytes());
+
+	if let Some(uiaainfo) = uiaainfo {
+		self.db.userdevicesessionid_uiaainfo.insert(
+			&userdevicesessionid,
+			&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
+		);
+	} else {
+		self.db
+			.userdevicesessionid_uiaainfo
+			.remove(&userdevicesessionid);
 	}
 }
+
+#[implement(Service)]
+async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
+	let key = (user_id, device_id, session);
+	self.db
+		.userdevicesessionid_uiaainfo
+		.qry(&key)
+		.await
+		.deserialized_json()
+		.map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
+}
diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs
index 3c69b243..4e16e22b 100644
--- a/src/service/updates/mod.rs
+++ b/src/service/updates/mod.rs
@@ -1,19 +1,22 @@
 use std::{sync::Arc, time::Duration};
 
 use async_trait::async_trait;
-use conduit::{debug, err, info, utils, warn, Error, Result};
-use database::Map;
+use conduit::{debug, info, warn, Result};
+use database::{Deserialized, Map};
 use ruma::events::room::message::RoomMessageEventContent;
 use serde::Deserialize;
-use tokio::{sync::Notify, time::interval};
+use tokio::{
+	sync::Notify,
+	time::{interval, MissedTickBehavior},
+};
 
 use crate::{admin, client, globals, Dep};
 
 pub struct Service {
-	services: Services,
-	db: Arc<Map>,
-	interrupt: Notify,
 	interval: Duration,
+	interrupt: Notify,
+	db: Arc<Map>,
+	services: Services,
 }
 
 struct Services {
@@ -22,12 +25,12 @@ struct Services {
 	globals: Dep<globals::Service>,
 }
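For orientation, here is a hedged sketch of how an API handler might drive the UIAA flow refactored above. The handler shape, the `services` plumbing, and the `Error::Uiaa` variant are illustrative assumptions, not part of this patch.

```rust
// Illustrative only: how a client-API endpoint might use create()/try_auth().
// `services`, `Error::Uiaa`, and the surrounding types are assumed here.
async fn require_uiaa(
	services: &Services, user_id: &UserId, device_id: &DeviceId,
	auth: Option<&AuthData>, flows: &UiaaInfo, json: &CanonicalJsonValue,
) -> Result<()> {
	match auth {
		// Subsequent requests: evaluate the submitted stage.
		Some(auth) => {
			let (worked, uiaainfo) = services
				.uiaa
				.try_auth(user_id, device_id, auth, flows)
				.await?;
			if worked {
				Ok(()) // all stages of some flow completed
			} else {
				Err(Error::Uiaa(uiaainfo)) // challenge the client again
			}
		},
		// First request: open a session and return the available flows.
		None => {
			let mut uiaainfo = flows.clone();
			uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
			services.uiaa.create(user_id, device_id, &uiaainfo, json);
			Err(Error::Uiaa(uiaainfo))
		},
	}
}
```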
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct CheckForUpdatesResponse {
 	updates: Vec<CheckForUpdatesResponseEntry>,
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct CheckForUpdatesResponseEntry {
 	id: u64,
 	date: String,
@@ -42,33 +45,38 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
 
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
+			interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
+			interrupt: Notify::new(),
+			db: args.db["global"].clone(),
 			services: Services {
 				globals: args.depend::<globals::Service>("globals"),
 				admin: args.depend::<admin::Service>("admin"),
 				client: args.depend::<client::Service>("client"),
 			},
-			db: args.db["global"].clone(),
-			interrupt: Notify::new(),
-			interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
 		}))
 	}
 
+	#[tracing::instrument(skip_all, name = "updates", level = "trace")]
 	async fn worker(self: Arc<Self>) -> Result<()> {
 		if !self.services.globals.allow_check_for_updates() {
 			debug!("Disabling update check");
 			return Ok(());
 		}
 
+		let mut i = interval(self.interval);
+		i.set_missed_tick_behavior(MissedTickBehavior::Delay);
 		loop {
 			tokio::select! {
-				() = self.interrupt.notified() => return Ok(()),
+				() = self.interrupt.notified() => break,
 				_ = i.tick() => (),
 			}
 
-			if let Err(e) = self.handle_updates().await {
+			if let Err(e) = self.check().await {
 				warn!(%e, "Failed to check for updates");
 			}
 		}
+
+		Ok(())
 	}
 
 	fn interrupt(&self) { self.interrupt.notify_waiters(); }
@@ -77,52 +85,52 @@ impl crate::Service for Service {
 }
 
 impl Service {
-	#[tracing::instrument(skip_all)]
-	async fn handle_updates(&self) -> Result<()> {
+	#[tracing::instrument(skip_all, level = "trace")]
+	async fn check(&self) -> Result<()> {
 		let response = self
 			.services
 			.client
 			.default
 			.get(CHECK_FOR_UPDATES_URL)
 			.send()
+			.await?
+			.text()
 			.await?;
 
-		let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?)
-			.map_err(|e| err!("Bad check for updates response: {e}"))?;
-
-		let mut last_update_id = self.last_check_for_updates_id()?;
-		for update in response.updates {
-			last_update_id = last_update_id.max(update.id);
-			if update.id > self.last_check_for_updates_id()? {
-				info!("{:#}", update.message);
-				self.services
-					.admin
-					.send_message(RoomMessageEventContent::text_markdown(format!(
-						"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
-						update.date, update.message
-					)))
-					.await;
+		let response = serde_json::from_str::<CheckForUpdatesResponse>(&response)?;
+		for update in &response.updates {
+			if update.id > self.last_check_for_updates_id().await {
+				self.handle(update).await;
+				self.update_check_for_updates_id(update.id);
 			}
 		}
 
-		self.update_check_for_updates_id(last_update_id)?;
 		Ok(())
 	}
 
+	async fn handle(&self, update: &CheckForUpdatesResponseEntry) {
+		info!("{} {:#}", update.date, update.message);
+		self.services
+			.admin
+			.send_message(RoomMessageEventContent::text_markdown(format!(
+				"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
+				update.date, update.message
+			)))
+			.await
+			.ok();
+	}
+
 	#[inline]
-	pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
+	pub fn update_check_for_updates_id(&self, id: u64) {
 		self.db
-			.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
-
-		Ok(())
+			.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes());
 	}
 
-	pub fn last_check_for_updates_id(&self) -> Result<u64> {
+	pub async fn last_check_for_updates_id(&self) -> u64 {
 		self.db
-			.get(LAST_CHECK_FOR_UPDATES_COUNT)?
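The worker above now configures its timer with `MissedTickBehavior::Delay`. A minimal, self-contained sketch of what that changes (the one-hour period is illustrative):

```rust
use std::time::Duration;
use tokio::time::{interval, MissedTickBehavior};

#[tokio::main]
async fn main() {
	let mut i = interval(Duration::from_secs(3600));
	// The default behavior is Burst: ticks missed while the task was busy
	// fire back-to-back to "catch up". Delay instead schedules the next tick
	// one full period after the late tick completes, so a slow update check
	// never triggers a burst of immediate re-checks.
	i.set_missed_tick_behavior(MissedTickBehavior::Delay);

	for _ in 0..2 {
		i.tick().await;
		// ... the periodic work would run here ...
	}
}
```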
-			.map_or(Ok(0_u64), |bytes| {
-				utils::u64_from_bytes(&bytes)
-					.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
-			})
+			.qry(LAST_CHECK_FOR_UPDATES_COUNT)
+			.await
+			.deserialized()
+			.unwrap_or(0_u64)
 	}
 }
diff --git a/src/service/users/data.rs b/src/service/users/data.rs
deleted file mode 100644
index 70ff12e3..00000000
--- a/src/service/users/data.rs
+++ /dev/null
@@ -1,1098 +0,0 @@
-use std::{collections::BTreeMap, mem::size_of, sync::Arc};
-
-use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server};
-use database::Map;
-use ruma::{
-	api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
-	encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
-	events::{AnyToDeviceEvent, StateEventType},
-	serde::Raw,
-	uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
-	OwnedMxcUri, OwnedUserId, UInt, UserId,
-};
-
-use crate::{globals, rooms, users::clean_signatures, Dep};
-
-pub struct Data {
-	keychangeid_userid: Arc<Map>,
-	keyid_key: Arc<Map>,
-	onetimekeyid_onetimekeys: Arc<Map>,
-	openidtoken_expiresatuserid: Arc<Map>,
-	todeviceid_events: Arc<Map>,
-	token_userdeviceid: Arc<Map>,
-	userdeviceid_metadata: Arc<Map>,
-	userdeviceid_token: Arc<Map>,
-	userfilterid_filter: Arc<Map>,
-	userid_avatarurl: Arc<Map>,
-	userid_blurhash: Arc<Map>,
-	userid_devicelistversion: Arc<Map>,
-	userid_displayname: Arc<Map>,
-	userid_lastonetimekeyupdate: Arc<Map>,
-	userid_masterkeyid: Arc<Map>,
-	userid_password: Arc<Map>,
-	userid_selfsigningkeyid: Arc<Map>,
-	userid_usersigningkeyid: Arc<Map>,
-	useridprofilekey_value: Arc<Map>,
-	services: Services,
-}
-
-struct Services {
-	server: Arc<Server>,
-	globals: Dep<globals::Service>,
-	state_cache: Dep<rooms::state_cache::Service>,
-	state_accessor: Dep<rooms::state_accessor::Service>,
-}
-
-impl Data {
-	pub(super) fn new(args: &crate::Args<'_>) -> Self {
-		let db = &args.db;
-		Self {
-			keychangeid_userid: db["keychangeid_userid"].clone(),
-			keyid_key: db["keyid_key"].clone(),
-			onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(),
-			openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(),
-			todeviceid_events: db["todeviceid_events"].clone(),
-			token_userdeviceid: db["token_userdeviceid"].clone(),
-			userdeviceid_metadata: db["userdeviceid_metadata"].clone(),
-			userdeviceid_token: db["userdeviceid_token"].clone(),
-			userfilterid_filter: db["userfilterid_filter"].clone(),
-			userid_avatarurl: db["userid_avatarurl"].clone(),
-			userid_blurhash: db["userid_blurhash"].clone(),
-			userid_devicelistversion: db["userid_devicelistversion"].clone(),
-			userid_displayname: db["userid_displayname"].clone(),
-			userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
-			userid_masterkeyid: db["userid_masterkeyid"].clone(),
-			userid_password: db["userid_password"].clone(),
-			userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(),
-			userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(),
-			useridprofilekey_value: db["useridprofilekey_value"].clone(),
-			services: Services {
-				server: args.server.clone(),
-				globals: args.depend::<globals::Service>("globals"),
-				state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
-				state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
-			},
-		}
-	}
-
-	/// Check if a user has an account on this homeserver.
-	#[inline]
-	pub(super) fn exists(&self, user_id: &UserId) -> Result<bool> {
-		Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
-	}
-
-	/// Check if account is deactivated
-	pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
-		Ok(self
-			.userid_password
-			.get(user_id.as_bytes())?
- .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? - .is_empty()) - } - - /// Returns the number of users registered on this server. - #[inline] - pub(super) fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } - - /// Find out which user an access token belongs to. - pub(super) fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = parts - .next() - .ok_or_else(|| err!(Database("User ID in token_userdeviceid is invalid.")))?; - let device_bytes = parts - .next() - .ok_or_else(|| err!(Database("Device ID in token_userdeviceid is invalid.")))?; - - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid. {e}")))?, - utils::string_from_bytes(device_bytes) - .map_err(|e| err!(Database("Device ID in token_userdeviceid is invalid. {e}")))?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - pub fn iter<'a>(&'a self) -> Box> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("User ID in userid_password is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in userid_password is invalid. {e}"))) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - pub(super) fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - pub(super) fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::hash::password(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - pub(super) fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("Displayname in db is invalid. {e}")))?, - )) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. 
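The deleted getters above all repeat the same manual pattern: a synchronous `get()?`, hand-rolled byte decoding, and a bespoke `bad_database` error at every call site. A hedged before/after sketch of the shape this refactor moves to (signatures paraphrased, not exact):

```rust
// Before: sync get() plus manual decoding and a custom error per call site.
fn displayname_old(map: &Map, user_id: &UserId) -> Result<Option<String>> {
	map.get(user_id.as_bytes())?
		.map(|bytes| {
			utils::string_from_bytes(&bytes)
				.map_err(|e| err!(Database("Displayname in db is invalid. {e}")))
		})
		.transpose()
}

// After: async qry() keyed by the typed value itself, decoded through the
// Deserialized trait; a missing row surfaces as Err rather than Ok(None).
async fn displayname_new(map: &Map, user_id: &UserId) -> Result<String> {
	map.qry(user_id).await.deserialized()
}
```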
- pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the `avatar_url` of a user. - pub(super) fn avatar_url(&self, user_id: &UserId) -> Result> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database(warn!("Avatar URL in db is invalid: {e}"))))?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - pub(super) fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Avatar URL in db is invalid. {e}"))) - }) - .transpose() - } - - /// Gets a specific user profile key - pub(super) fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - self.useridprofilekey_value - .get(&key)? - .map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes).unwrap()))) - } - - /// Gets all the user's profile keys and values in an iterator - pub(super) fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - let prefix = user_id.as_bytes().to_vec(); - - Box::new( - self.useridprofilekey_value - .scan_prefix(prefix) - .map(|(key, value)| { - let profile_key_name = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("Profile key in db is invalid")))?, - ) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - let profile_key_value = serde_json::from_slice(&value) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - Ok((profile_key_name, profile_key_value)) - }), - ) - } - - /// Sets a new profile key value, removes the key if value is None - pub(super) fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(value) = profile_key_value { - let value = serde_json::to_vec(&value).unwrap(); - - self.useridprofilekey_value.insert(&key, &value) - } else { - self.useridprofilekey_value.remove(&key) - } - } - - /// Get the timezone of a user. - pub(super) fn timezone(&self, user_id: &UserId) -> Result> { - // first check the unstable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - let value = self - .useridprofilekey_value - .get(&key)? - .map(|bytes| utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. 
{e}")))) - .transpose() - .unwrap(); - - // TODO: transparently migrate unstable key usage to the stable key once MSC4133 - // and MSC4175 are stable, likely a remove/insert in this block - if value.is_none() || value.as_ref().is_some_and(String::is_empty) { - // check the stable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"m.tz"); - - return self - .useridprofilekey_value - .get(&key)? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}"))) - }) - .transpose(); - } - - Ok(value) - } - - /// Sets a new timezone or removes it if timezone is None. - pub(super) fn set_timezone(&self, user_id: &UserId, timezone: Option) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(timezone) = timezone { - self.useridprofilekey_value - .insert(&key, timezone.as_bytes())?; - } else { - self.useridprofilekey_value.remove(&key)?; - } - - Ok(()) - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - pub(super) fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - client_ip: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert - // though... - if !self.exists(user_id)? { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. - pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xFF); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. 
- pub(super) fn all_device_ids<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("UserDevice ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("Device ID in userdeviceid_metadata is invalid. {e}")))? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - pub(super) fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&key)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Count in roomid_lastroomactiveupdate is invalid. 
{e}"))) - }) - } - - pub(super) fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?, - ) - .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?, - serde_json::from_slice(&value).map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?, - )) - }) - .transpose() - } - - pub(super) fn count_one_time_keys( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKey ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("DeviceKeyId in db is invalid. {e}")))? - .algorithm(), - ) - }) { - let count: &mut UInt = counts.entry(algorithm?).or_default(); - *count = count.saturating_add(uint!(1)); - } - - Ok(counts) - } - - pub(super) fn add_device_keys( - &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - pub(super) fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, - user_signing_key: &Option>, notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? 
- .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - pub(super) fn sign_key( - &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = serde_json::from_slice( - &self - .keyid_key - .get(&key)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, - ) - .map_err(|e| err!(Database("key in keyid_key is invalid. {e}")))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? - .as_object_mut() - .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? 
- .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - pub(super) fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?, - ) - .map_err(|e| err!(Database("User ID in devicekeychangeid_userid is invalid. {e}"))) - }), - ) - } - - pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = self.services.globals.next_count()?.to_be_bytes(); - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - // Don't send key updates to unencrypted rooms - if self - .services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes).map_err(|e| err!(Database("DeviceKeys in db are invalid. 
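The `keys_changed` scan above works because big-endian counts sort byte-wise in numeric order after the `0xFF`-terminated prefix; a hedged sketch of how the range start is formed (the helper name is illustrative):

```rust
// Sketch: building the inclusive start key for the keys_changed scan above.
// The count is appended big-endian so the database's byte-wise ordering
// matches the numeric ordering of the global counter.
fn keys_changed_start(user_or_room_id: &str, from: u64) -> Vec<u8> {
	let mut start = user_or_room_id.as_bytes().to_vec();
	start.push(0xFF);
	// from + 1 because callers want changes strictly after `from`.
	start.extend_from_slice(&from.saturating_add(1).to_be_bytes());
	start
}
```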
{e}")))?, - )) - }) - } - - pub(super) fn parse_master_key( - user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - pub(super) fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?; - clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), - ))) - }) - } - - pub(super) fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. 
{e}")))?, - )) - }) - }) - } - - pub(super) fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - pub(super) fn get_to_device_events( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|e| err!(Database("Event in todeviceid_events is invalid. {e}")))?, - ); - } - - Ok(events) - } - - pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::())..key.len()]) - .map_err(|e| err!(Database("ToDeviceId has invalid count bytes. {e}")))?, - )) - }) - .filter_map(Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ - metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. - pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? 
- .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Invalid devicelistversion in db. {e}"))) - .map(Some) - }) - } - - pub(super) fn all_devices_metadata<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("Device in userdeviceid_metadata is invalid. {e}"))) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; - - Ok(filter_id) - } - - pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|e| err!(Database("Invalid filter event in db. {e}"))) - } else { - Ok(None) - } - } - - /// Creates an OpenID token, which can be used to prove that a user has - /// access to an account (primarily for integrations) - pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - use std::num::Saturating as Sat; - - let expires_in = self.services.server.config.openid_token_ttl; - let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); - - let mut value = expires_at.0.to_be_bytes().to_vec(); - value.extend_from_slice(user_id.as_bytes()); - - self.openidtoken_expiresatuserid - .insert(token.as_bytes(), value.as_slice())?; - - Ok(expires_in) - } - - /// Find out which user an OpenID access token belongs to. - pub(super) fn find_from_openid_token(&self, token: &str) -> Result { - let Some(value) = self.openidtoken_expiresatuserid.get(token.as_bytes())? else { - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is unrecognised")); - }; - - let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); - - let expires_at = u64::from_be_bytes( - expires_at_bytes - .try_into() - .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, - ); - - if expires_at < utils::millis_since_unix_epoch() { - debug_info!("OpenID token is expired, removing"); - self.openidtoken_expiresatuserid.remove(token.as_bytes())?; - - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is expired")); - } - - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}"))) - } -} - -/// Will only return with Some(username) if the password was not empty and the -/// username could be successfully parsed. -/// If `utils::string_from_bytes`(...) 
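The OpenID token storage above packs the expiry timestamp and the owning user into a single value; a small sketch of the layout and its decode step (function names are illustrative):

```rust
// Sketch of the openidtoken_expiresatuserid value layout used above:
// [ expires_at: u64 big-endian | user_id bytes ].
fn encode_openid_value(expires_at: u64, user_id: &ruma::UserId) -> Vec<u8> {
	let mut value = expires_at.to_be_bytes().to_vec();
	value.extend_from_slice(user_id.as_bytes());
	value
}

fn decode_openid_value(value: &[u8]) -> Option<(u64, &[u8])> {
	if value.len() < 8 {
		return None;
	}
	let (ts, user) = value.split_at(8);
	Some((u64::from_be_bytes(ts.try_into().ok()?), user))
}
```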
returns an error that username will be -/// skipped and the error will be logged. -pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); - None - }, - } - } -} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 80897b5f..9a058ba9 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,552 +1,984 @@ -mod data; +use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; -use std::{ - collections::{BTreeMap, BTreeSet}, - mem, - sync::{Arc, Mutex, Mutex as StdMutex}, +use conduit::{ + debug_warn, err, utils, + utils::{stream::TryIgnore, string::Unquoted, ReadyExt, TryReadyExt}, + warn, Err, Error, Result, Server, }; - -use conduit::{Error, Result}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - device::Device, - filter::FilterDefinition, - sync::sync_events::{ - self, - v4::{ExtensionsConfig, SyncRequestList}, - }, - }, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, + events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedRoomId, OwnedUserId, - UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use self::data::Data; -use crate::{admin, rooms, Dep}; +use crate::{admin, globals, rooms, Dep}; pub struct Service { - connections: DbConnections, - pub db: Data, services: Services, + db: Data, } struct Services { + server: Arc, admin: Dep, + globals: Dep, + state_accessor: Dep, state_cache: Dep, } +struct Data { + keychangeid_userid: Arc, + keyid_key: Arc, + onetimekeyid_onetimekeys: Arc, + openidtoken_expiresatuserid: Arc, + todeviceid_events: Arc, + token_userdeviceid: Arc, + userdeviceid_metadata: Arc, + userdeviceid_token: Arc, + userfilterid_filter: Arc, + userid_avatarurl: Arc, + userid_blurhash: Arc, + userid_devicelistversion: Arc, + userid_displayname: Arc, + userid_lastonetimekeyupdate: Arc, + userid_masterkeyid: Arc, + userid_password: Arc, + userid_selfsigningkeyid: Arc, + userid_usersigningkeyid: Arc, + useridprofilekey_value: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - connections: StdMutex::new(BTreeMap::new()), - db: Data::new(&args), services: Services { + server: args.server.clone(), admin: args.depend::("admin"), + globals: args.depend::("globals"), + state_accessor: args.depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), }, + db: Data { + keychangeid_userid: args.db["keychangeid_userid"].clone(), + keyid_key: args.db["keyid_key"].clone(), + onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: args.db["todeviceid_events"].clone(), + token_userdeviceid: args.db["token_userdeviceid"].clone(), + userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(), + userdeviceid_token: args.db["userdeviceid_token"].clone(), + 
userfilterid_filter: args.db["userfilterid_filter"].clone(), + userid_avatarurl: args.db["userid_avatarurl"].clone(), + userid_blurhash: args.db["userid_blurhash"].clone(), + userid_devicelistversion: args.db["userid_devicelistversion"].clone(), + userid_displayname: args.db["userid_displayname"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + userid_masterkeyid: args.db["userid_masterkeyid"].clone(), + userid_password: args.db["userid_password"].clone(), + userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), + userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), + useridprofilekey_value: args.db["useridprofilekey_value"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -type DbConnections = Mutex>; -type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); -type DbConnectionsVal = Arc>; - -struct SlidingSyncCache { - lists: BTreeMap, - subscriptions: BTreeMap, - known_rooms: BTreeMap>, // For every room, the roomsince number - extensions: ExtensionsConfig, -} - impl Service { - /// Check if a user has an account on this homeserver. - #[inline] - pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } - - pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { - self.connections - .lock() - .unwrap() - .contains_key(&(user_id, device_id, conn_id)) - } - - pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); - } - - pub fn update_sync_request_with_cache( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, - ) -> BTreeMap> { - let Some(conn_id) = request.conn_id.clone() else { - return BTreeMap::new(); - }; - - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (list_id, list) in &mut request.lists { - if let Some(cached_list) = cached.lists.get(list_id) { - if list.sort.is_empty() { - list.sort.clone_from(&cached_list.sort); - }; - if list.room_details.required_state.is_empty() { - list.room_details - .required_state - .clone_from(&cached_list.room_details.required_state); - }; - list.room_details.timeline_limit = list - .room_details - .timeline_limit - .or(cached_list.room_details.timeline_limit); - list.include_old_rooms = list - .include_old_rooms - .clone() - .or_else(|| cached_list.include_old_rooms.clone()); - match (&mut list.filters, cached_list.filters.clone()) { - (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); - if list_filters.spaces.is_empty() { - list_filters.spaces = cached_filters.spaces; - } - list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); - list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); - if list_filters.room_types.is_empty() { - list_filters.room_types = cached_filters.room_types; - } - if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; - } - 
list_filters.room_name_like = list_filters - .room_name_like - .clone() - .or(cached_filters.room_name_like); - if list_filters.tags.is_empty() { - list_filters.tags = cached_filters.tags; - } - if list_filters.not_tags.is_empty() { - list_filters.not_tags = cached_filters.not_tags; - } - }, - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (..) => {}, - } - if list.bump_event_types.is_empty() { - list.bump_event_types - .clone_from(&cached_list.bump_event_types); - }; - } - cached.lists.insert(list_id.clone(), list.clone()); - } - - cached - .subscriptions - .extend(request.room_subscriptions.clone()); - request - .room_subscriptions - .extend(cached.subscriptions.clone()); - - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); - - request.extensions.to_device.enabled = request - .extensions - .to_device - .enabled - .or(cached.extensions.to_device.enabled); - - request.extensions.account_data.enabled = request - .extensions - .account_data - .enabled - .or(cached.extensions.account_data.enabled); - request.extensions.account_data.lists = request - .extensions - .account_data - .lists - .clone() - .or_else(|| cached.extensions.account_data.lists.clone()); - request.extensions.account_data.rooms = request - .extensions - .account_data - .rooms - .clone() - .or_else(|| cached.extensions.account_data.rooms.clone()); - - cached.extensions = request.extensions.clone(); - - cached.known_rooms.clone() - } - - pub fn update_sync_subscriptions( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, - subscriptions: BTreeMap, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - cached.subscriptions = subscriptions; - } - - pub fn update_sync_known_rooms( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, - new_cached_rooms: BTreeSet, globalsince: u64, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() - { - if !new_cached_rooms.contains(roomid) { - *lastsince = 0; - } - } - let list = cached.known_rooms.entry(list_id).or_default(); - for roomid in new_cached_rooms { - list.insert(roomid, globalsince); - } - } - - /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } - /// Check if a user is an admin - pub fn is_admin(&self, user_id: &UserId) -> Result { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? 
{ - self.services.state_cache.is_joined(user_id, &admin_room_id) - } else { - Ok(false) - } - } + #[inline] + pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } /// Create a new user account on this homeserver. #[inline] pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password)?; - Ok(()) - } - - /// Returns the number of users registered on this server. - #[inline] - pub fn count(&self) -> Result { self.db.count() } - - /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result> { - self.db.find_from_token(token) - } - - /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - pub fn list_local_users(&self) -> Result> { self.db.list_local_users() } - - /// Returns the password hash for the given user. - pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } - - /// Hash and set the user's password to the Argon2 hash - #[inline] - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password) - } - - /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - pub async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - self.db.set_displayname(user_id, displayname) - } - - /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - self.db.set_avatar_url(user_id, avatar_url) - } - - /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } - - pub fn timezone(&self, user_id: &UserId) -> Result> { self.db.timezone(user_id) } - - /// Gets a specific user profile key - pub fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - self.db.profile_key(user_id, profile_key) - } - - /// Gets all the user's profile keys and values in an iterator - pub fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - self.db.all_profile_keys(user_id) - } - - /// Sets a new profile key value, removes the key if value is None - pub fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - self.db - .set_profile_key(user_id, profile_key, profile_key_value) - } - - /// Sets a new tz or removes it if tz is None. - pub async fn set_timezone(&self, user_id: &UserId, tz: Option) -> Result<()> { - self.db.set_timezone(user_id, tz) - } - - /// Sets a new blurhash or removes it if blurhash is None. - pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - self.db.set_blurhash(user_id, blurhash) - } - - /// Adds a new device to a user. 
-	pub fn create_device(
-		&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
-		client_ip: Option<String>,
-	) -> Result<()> {
-		self.db
-			.create_device(user_id, device_id, token, initial_device_display_name, client_ip)
-	}
-
-	/// Removes a device from a user.
-	pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
-		self.db.remove_device(user_id, device_id)
-	}
-
-	/// Returns an iterator over all device ids of this user.
-	pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedDeviceId>> + 'a {
-		self.db.all_device_ids(user_id)
-	}
-
-	/// Replaces the access token of one device.
-	#[inline]
-	pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
-		self.db.set_token(user_id, device_id, token)
-	}
-
-	pub fn add_one_time_key(
-		&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
-		one_time_key_value: &Raw<OneTimeKey>,
-	) -> Result<()> {
-		self.db
-			.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
-	}
-
-	// TODO: use this ?
-	#[allow(dead_code)]
-	pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
-		self.db.last_one_time_keys_update(user_id)
-	}
-
-	pub fn take_one_time_key(
-		&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
-	) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
-		self.db.take_one_time_key(user_id, device_id, key_algorithm)
-	}
-
-	pub fn count_one_time_keys(
-		&self, user_id: &UserId, device_id: &DeviceId,
-	) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
-		self.db.count_one_time_keys(user_id, device_id)
-	}
-
-	pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
-		self.db.add_device_keys(user_id, device_id, device_keys)
-	}
-
-	pub fn add_cross_signing_keys(
-		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
-		user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
-	) -> Result<()> {
-		self.db
-			.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify)
-	}
-
-	pub fn sign_key(
-		&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
-	) -> Result<()> {
-		self.db.sign_key(target_id, key_id, signature, sender_id)
-	}
-
-	pub fn keys_changed<'a>(
-		&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
-	) -> impl Iterator<Item = Result<OwnedUserId>> + 'a {
-		self.db.keys_changed(user_or_room_id, from, to)
-	}
-
-	#[inline]
-	pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) }
-
-	pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
-		self.db.get_device_keys(user_id, device_id)
-	}
-
-	#[inline]
-	pub fn parse_master_key(
-		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
-	) -> Result<(Vec<u8>, CrossSigningKey)> {
-		Data::parse_master_key(user_id, master_key)
-	}
-
-	#[inline]
-	pub fn get_key(
-		&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_key(key, sender_user, user_id, allowed_signatures)
-	}
-
-	pub fn get_master_key(
-		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_master_key(sender_user, user_id, allowed_signatures)
-	}
-
-	pub fn get_self_signing_key(
-		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_self_signing_key(sender_user, user_id, allowed_signatures)
-	}
-
-	pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db.get_user_signing_key(user_id)
-	}
-
-	pub fn add_to_device_event(
-		&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
-		content: serde_json::Value,
-	) -> Result<()> {
-		self.db
-			.add_to_device_event(sender, target_user_id, target_device_id, event_type, content)
-	}
-
-	pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
-		self.db.get_to_device_events(user_id, device_id)
-	}
-
-	pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
-		self.db.remove_to_device_events(user_id, device_id, until)
-	}
-
-	pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
-		self.db.update_device_metadata(user_id, device_id, device)
-	}
-
-	/// Get device metadata.
-	pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
-		self.db.get_device_metadata(user_id, device_id)
-	}
-
-	pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
-		self.db.get_devicelist_version(user_id)
-	}
-
-	pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> + 'a {
-		self.db.all_devices_metadata(user_id)
+		self.set_password(user_id, password)
 	}
 
 	/// Deactivate account
-	pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
+	pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
 		// Remove all associated devices
-		for device_id in self.all_device_ids(user_id) {
-			self.remove_device(user_id, &device_id?)?;
-		}
+		self.all_device_ids(user_id)
+			.for_each(|device_id| self.remove_device(user_id, device_id))
+			.await;
 
 		// Set the password to "" to indicate a deactivated account. Hashes will never
 		// result in an empty string, so the user will not be able to log in again.
 		// Systems like changing the password without logging in should check if the
 		// account is deactivated.
-		self.db.set_password(user_id, None)?;
+		self.set_password(user_id, None)?;
 
 		// TODO: Unhook 3PID
 		Ok(())
 	}
 
-	/// Creates a new sync filter. Returns the filter id.
-	pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
-		self.db.create_filter(user_id, filter)
+	/// Check if a user has an account on this homeserver.
+	#[inline]
+	pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.qry(user_id).await.is_ok() }
+
+	/// Check if account is deactivated
+	pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
+		self.db
+			.userid_password
+			.qry(user_id)
+			.map_ok(|val| val.is_empty())
+			.map_err(|_| err!(Request(NotFound("User does not exist."))))
+			.await
 	}
 
-	pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
-		self.db.get_filter(user_id, filter_id)
+	/// Check if account is active, infallible
+	pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) }
+
+	/// Check if account is active and local, infallible
+	pub async fn is_active_local(&self, user_id: &UserId) -> bool {
+		self.services.globals.user_is_local(user_id) && self.is_active(user_id).await
+	}
+
+	/// Returns the number of users registered on this server.
+	#[inline]
+	pub async fn count(&self) -> usize { self.db.userid_password.count().await }
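// [Editor's note] A hedged, standalone sketch (not part of the patch) of the
// `is_active` inversion above: a missing row yields Err, which `unwrap_or(true)`
// treats as "deactivated", so the negation correctly reports nonexistent users
// as inactive. The `lookup` parameter stands in for the real database call.
fn is_active(lookup: Result<bool, ()>) -> bool {
	// Err (no such user) -> assume deactivated (true) -> not active (false)
	!lookup.unwrap_or(true)
}

fn main() {
	assert!(is_active(Ok(false))); // exists, not deactivated
	assert!(!is_active(Ok(true))); // exists, deactivated
	assert!(!is_active(Err(()))); // nonexistent user
}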
+	/// Find out which user an access token belongs to.
+	pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
+		self.db.token_userdeviceid.qry(token).await.deserialized()
+	}
+
+	/// Returns an iterator over all users on this homeserver (offered for
+	/// compatibility)
+	#[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)]
+	pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ { self.stream().map(ToOwned::to_owned) }
+
+	/// Returns an iterator over all users on this homeserver.
+	pub fn stream(&self) -> impl Stream<Item = &UserId> + Send { self.db.userid_password.keys().ignore_err() }
+
+	/// Returns a list of local users as a list of usernames.
+	///
+	/// A user account is considered `local` if the length of its password is
+	/// greater than zero.
+	pub fn list_local_users(&self) -> impl Stream<Item = &UserId> + Send + '_ {
+		self.db
+			.userid_password
+			.stream()
+			.ignore_err()
+			.ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u))
+	}
+
+	/// Returns the password hash for the given user.
+	pub async fn password_hash(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_password.qry(user_id).await.deserialized()
+	}
+
+	/// Hash and set the user's password to the Argon2 hash
+	pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
+		if let Some(password) = password {
+			if let Ok(hash) = utils::hash::password(password) {
+				self.db
+					.userid_password
+					.insert(user_id.as_bytes(), hash.as_bytes());
+				Ok(())
+			} else {
+				Err(Error::BadRequest(
+					ErrorKind::InvalidParam,
+					"Password does not meet the requirements.",
+				))
+			}
+		} else {
+			self.db.userid_password.insert(user_id.as_bytes(), b"");
+			Ok(())
+		}
+	}
+
+	/// Returns the displayname of a user on this homeserver.
+	pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_displayname.qry(user_id).await.deserialized()
+	}
+
+	/// Sets a new displayname or removes it if displayname is None. You still
+	/// need to notify all rooms of this change.
+	pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
+		if let Some(displayname) = displayname {
+			self.db
+				.userid_displayname
+				.insert(user_id.as_bytes(), displayname.as_bytes());
+		} else {
+			self.db.userid_displayname.remove(user_id.as_bytes());
+		}
+	}
+
+	/// Get the `avatar_url` of a user.
+	pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
+		self.db.userid_avatarurl.qry(user_id).await.deserialized()
+	}
+
+	/// Sets a new avatar_url or removes it if avatar_url is None.
+	pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
+		if let Some(avatar_url) = avatar_url {
+			self.db
+				.userid_avatarurl
+				.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes());
+		} else {
+			self.db.userid_avatarurl.remove(user_id.as_bytes());
+		}
+	}
+
+	/// Get the blurhash of a user.
+	pub async fn blurhash(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_blurhash.qry(user_id).await.deserialized()
+	}
+
+	/// Sets a new blurhash or removes it if blurhash is None.
+	pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) {
+		if let Some(blurhash) = blurhash {
+			self.db
+				.userid_blurhash
+				.insert(user_id.as_bytes(), blurhash.as_bytes());
+		} else {
+			self.db.userid_blurhash.remove(user_id.as_bytes());
+		}
+	}
+
+	/// Adds a new device to a user.
+	pub async fn create_device(
+		&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
+		client_ip: Option<String>,
+	) -> Result<()> {
+		// This method should never be called for nonexistent users. We shouldn't assert
+		// though...
+		if !self.exists(user_id).await {
+			warn!("Called create_device for non-existent user {} in database", user_id);
+			return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
+		}
+
+		let mut userdeviceid = user_id.as_bytes().to_vec();
+		userdeviceid.push(0xFF);
+		userdeviceid.extend_from_slice(device_id.as_bytes());
+
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+
+		self.db.userdeviceid_metadata.insert(
+			&userdeviceid,
+			&serde_json::to_vec(&Device {
+				device_id: device_id.into(),
+				display_name: initial_device_display_name,
+				last_seen_ip: client_ip,
+				last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
+			})
+			.expect("Device::to_string never fails."),
+		);
+
+		self.set_token(user_id, device_id, token).await?;
+
+		Ok(())
+	}
+
+	/// Removes a device from a user.
+	pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
+		let mut userdeviceid = user_id.as_bytes().to_vec();
+		userdeviceid.push(0xFF);
+		userdeviceid.extend_from_slice(device_id.as_bytes());
+
+		// Remove tokens
+		if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
+			self.db.userdeviceid_token.remove(&userdeviceid);
+			self.db.token_userdeviceid.remove(&old_token);
+		}
+
+		// Remove todevice events
+		let prefix = (user_id, device_id, Interfix);
+		self.db
+			.todeviceid_events
+			.keys_raw_prefix(&prefix)
+			.ignore_err()
+			.ready_for_each(|key| self.db.todeviceid_events.remove(key))
+			.await;
+
+		// TODO: Remove onetimekeys
+
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+
+		self.db.userdeviceid_metadata.remove(&userdeviceid);
+	}
+
+	/// Returns an iterator over all device ids of this user.
+	pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &DeviceId> + Send + 'a {
+		let prefix = (user_id, Interfix);
+		self.db
+			.userdeviceid_metadata
+			.keys_prefix(&prefix)
+			.ignore_err()
+			.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
+	}
+
+	/// Replaces the access token of one device.
+	pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
+		let key = (user_id, device_id);
+		// should not be None, but we shouldn't assert either lol...
+		if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
+			return Err!(Database(error!(
+				?user_id,
+				?device_id,
+				"User does not exist or device has no metadata."
+			)));
+		}
+
+		// Remove old token
+		if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
+			self.db.token_userdeviceid.remove(&old_token);
+			// It will be removed from userdeviceid_token by the insert later
+		}
+
+		// Assign token to user device combination
+		let mut userdeviceid = user_id.as_bytes().to_vec();
+		userdeviceid.push(0xFF);
+		userdeviceid.extend_from_slice(device_id.as_bytes());
+		self.db
+			.userdeviceid_token
+			.insert(&userdeviceid, token.as_bytes());
+		self.db
+			.token_userdeviceid
+			.insert(token.as_bytes(), &userdeviceid);
+
+		Ok(())
+	}
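// [Editor's note] Illustrative sketch, not part of the patch: the
// `userdeviceid` keys built above follow a 0xFF-separated layout; 0xFF cannot
// occur in UTF-8 identifiers, so it is a safe component separator. `user` and
// `device` are plain-string stand-ins for the real UserId/DeviceId types.
fn userdevice_key(user: &str, device: &str) -> Vec<u8> {
	let mut key = user.as_bytes().to_vec();
	key.push(0xFF); // separator between user id and device id
	key.extend_from_slice(device.as_bytes());
	key
}

fn main() {
	let key = userdevice_key("@alice:example.org", "ABCDEFGH");
	assert_eq!(key.iter().filter(|&&b| b == 0xFF).count(), 1);
}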
+	pub async fn add_one_time_key(
+		&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
+		one_time_key_value: &Raw<OneTimeKey>,
+	) -> Result<()> {
+		// All devices have metadata
+		// Only existing devices should be able to call this, but we shouldn't assert
+		// either...
+		let key = (user_id, device_id);
+		if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
+			return Err!(Database(error!(
+				?user_id,
+				?device_id,
+				"User does not exist or device has no metadata."
+			)));
+		}
+
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(device_id.as_bytes());
+		key.push(0xFF);
+		// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
+		// because there are no wrapping quotation marks anymore)
+		key.extend_from_slice(
+			serde_json::to_string(one_time_key_key)
+				.expect("DeviceKeyId::to_string always works")
+				.as_bytes(),
+		);
+
+		self.db.onetimekeyid_onetimekeys.insert(
+			&key,
+			&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
+		);
+
+		self.db
+			.userid_lastonetimekeyupdate
+			.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
+
+		Ok(())
+	}
+
+	pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
+		self.db
+			.userid_lastonetimekeyupdate
+			.qry(user_id)
+			.await
+			.deserialized()
+			.unwrap_or(0)
+	}
+
+	pub async fn take_one_time_key(
+		&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
+	) -> Result<(OwnedDeviceKeyId, Raw<OneTimeKey>)> {
+		self.db
+			.userid_lastonetimekeyupdate
+			.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
+
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+		prefix.extend_from_slice(device_id.as_bytes());
+		prefix.push(0xFF);
+		prefix.push(b'"'); // Annoying quotation mark
+		prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
+		prefix.push(b':');
+
+		let one_time_key = self
+			.db
+			.onetimekeyid_onetimekeys
+			.raw_stream_prefix(&prefix)
+			.ignore_err()
+			.map(|(key, val)| {
+				self.db.onetimekeyid_onetimekeys.remove(key);
+
+				let key = key
+					.rsplit(|&b| b == 0xFF)
+					.next()
+					.ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))
+					.unwrap();
+
+				let key = serde_json::from_slice(key)
+					.map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))
+					.unwrap();
+
+				let val = serde_json::from_slice(val)
+					.map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))
+					.unwrap();
+
+				(key, val)
+			})
+			.next()
+			.await;
+
+		one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))
+	}
+
+	pub async fn count_one_time_keys(
+		&self, user_id: &UserId, device_id: &DeviceId,
+	) -> BTreeMap<DeviceKeyAlgorithm, UInt> {
+		type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
+
+		let mut algorithm_counts = BTreeMap::<DeviceKeyAlgorithm, UInt>::new();
+		let query = (user_id, device_id);
+		self.db
+			.onetimekeyid_onetimekeys
+			.stream_prefix(&query)
+			.ignore_err()
+			.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
+				let device_key_id: &DeviceKeyId = device_key_id
+					.as_str()
+					.try_into()
+					.expect("Invalid DeviceKeyID in database");
+
+				let count: &mut UInt = algorithm_counts
+					.entry(device_key_id.algorithm())
+					.or_default();
+
+				*count = count.saturating_add(1_u32.into());
+			})
+			.await;
+
+		algorithm_counts
+	}
+
+	pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
+		let mut userdeviceid = user_id.as_bytes().to_vec();
+		userdeviceid.push(0xFF);
+		userdeviceid.extend_from_slice(device_id.as_bytes());
+
+		self.db.keyid_key.insert(
+			&userdeviceid,
+			&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
+		);
+
+		self.mark_device_key_update(user_id).await;
+	}
+	pub async fn add_cross_signing_keys(
+		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
+		user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
+	) -> Result<()> {
+		// TODO: Check signatures
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+
+		let (master_key_key, _) = parse_master_key(user_id, master_key)?;
+
+		self.db
+			.keyid_key
+			.insert(&master_key_key, master_key.json().get().as_bytes());
+
+		self.db
+			.userid_masterkeyid
+			.insert(user_id.as_bytes(), &master_key_key);
+
+		// Self-signing key
+		if let Some(self_signing_key) = self_signing_key {
+			let mut self_signing_key_ids = self_signing_key
+				.deserialize()
+				.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
+				.keys
+				.into_values();
+
+			let self_signing_key_id = self_signing_key_ids
+				.next()
+				.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
+
+			if self_signing_key_ids.next().is_some() {
+				return Err(Error::BadRequest(
+					ErrorKind::InvalidParam,
+					"Self signing key contained more than one key.",
+				));
+			}
+
+			let mut self_signing_key_key = prefix.clone();
+			self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
+
+			self.db
+				.keyid_key
+				.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
+
+			self.db
+				.userid_selfsigningkeyid
+				.insert(user_id.as_bytes(), &self_signing_key_key);
+		}
+
+		// User-signing key
+		if let Some(user_signing_key) = user_signing_key {
+			let mut user_signing_key_ids = user_signing_key
+				.deserialize()
+				.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
+				.keys
+				.into_values();
+
+			let user_signing_key_id = user_signing_key_ids
+				.next()
+				.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
+
+			if user_signing_key_ids.next().is_some() {
+				return Err!(Request(InvalidParam("User signing key contained more than one key.")));
+			}
+
+			let mut user_signing_key_key = prefix;
+			user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
+
+			self.db
+				.keyid_key
+				.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes());
+
+			self.db
+				.userid_usersigningkeyid
+				.insert(user_id.as_bytes(), &user_signing_key_key);
+		}
+
+		if notify {
+			self.mark_device_key_update(user_id).await;
+		}
+
+		Ok(())
+	}
+
+	pub async fn sign_key(
+		&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
+	) -> Result<()> {
+		let key = (target_id, key_id);
+
+		let mut cross_signing_key: serde_json::Value = self
+			.db
+			.keyid_key
+			.qry(&key)
+			.await
+			.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))?
+			.deserialized_json()
+			.map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?;
+
+		let signatures = cross_signing_key
+			.get_mut("signatures")
+			.ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))?
+			.as_object_mut()
+			.ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))?
+			.entry(sender_id.to_string())
+			.or_insert_with(|| serde_json::Map::new().into());
+
+		signatures
+			.as_object_mut()
+			.ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))?
+			.insert(signature.0, signature.1.into());
+
+		let mut key = target_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(key_id.as_bytes());
+		self.db.keyid_key.insert(
+			&key,
+			&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
+		);
+
+		self.mark_device_key_update(target_id).await;
+
+		Ok(())
+	}
+
+	pub fn keys_changed<'a>(
+		&'a self, user_or_room_id: &'a str, from: u64, to: Option<u64>,
+	) -> impl Stream<Item = &UserId> + Send + 'a {
+		type KeyVal<'a> = ((&'a str, u64), &'a UserId);
+
+		let to = to.unwrap_or(u64::MAX);
+		let start = (user_or_room_id, from.saturating_add(1));
+		self.db
+			.keychangeid_userid
+			.stream_from(&start)
+			.ignore_err()
+			.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to)
+			.map(|((..), user_id): KeyVal<'_>| user_id)
+	}
+
+	pub async fn mark_device_key_update(&self, user_id: &UserId) {
+		let count = self.services.globals.next_count().unwrap().to_be_bytes();
+		let rooms_joined = self.services.state_cache.rooms_joined(user_id);
+		pin_mut!(rooms_joined);
+		while let Some(room_id) = rooms_joined.next().await {
+			// Don't send key updates to unencrypted rooms
+			if self
+				.services
+				.state_accessor
+				.room_state_get(room_id, &StateEventType::RoomEncryption, "")
+				.await
+				.is_err()
+			{
+				continue;
+			}
+
+			let mut key = room_id.as_bytes().to_vec();
+			key.push(0xFF);
+			key.extend_from_slice(&count);
+
+			self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
+		}
+
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(&count);
+		self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
+	}
+
+	pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
+		let key_id = (user_id, device_id);
+		self.db.keyid_key.qry(&key_id).await.deserialized_json()
+	}
+
+	pub async fn get_key<F>(
+		&self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key = self
+			.db
+			.keyid_key
+			.qry(key_id)
+			.await
+			.deserialized_json::<serde_json::Value>()?;
+
+		let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
+		let raw_value = serde_json::value::to_raw_value(&cleaned)?;
+		Ok(Raw::from_json(raw_value))
+	}
+
+	pub async fn get_master_key<F>(
+		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key_id = self.db.userid_masterkeyid.qry(user_id).await?;
+
+		self.get_key(&key_id, sender_user, user_id, allowed_signatures)
+			.await
+	}
+
+	pub async fn get_self_signing_key<F>(
+		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key_id = self.db.userid_selfsigningkeyid.qry(user_id).await?;
+
+		self.get_key(&key_id, sender_user, user_id, allowed_signatures)
+			.await
+	}
+
+	pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
+		let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?;
+
+		self.db.keyid_key.qry(&*key_id).await.deserialized_json()
+	}
+
+	pub async fn add_to_device_event(
+		&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
+		content: serde_json::Value,
+	) {
+		let mut key = target_user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(target_device_id.as_bytes());
+		key.push(0xFF);
+		key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
+
+		let mut json = serde_json::Map::new();
+		json.insert("type".to_owned(), event_type.to_owned().into());
+		json.insert("sender".to_owned(), sender.to_string().into());
+		json.insert("content".to_owned(), content);
+
+		let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
+
+		self.db.todeviceid_events.insert(&key, &value);
+	}
+
+	pub fn get_to_device_events<'a>(
+		&'a self, user_id: &'a UserId, device_id: &'a DeviceId,
+	) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
+		let prefix = (user_id, device_id, Interfix);
+		self.db
+			.todeviceid_events
+			.stream_raw_prefix(&prefix)
+			.ready_and_then(|(_, val)| serde_json::from_slice(val).map_err(Into::into))
+			.ignore_err()
+	}
+
+	pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) {
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+		prefix.extend_from_slice(device_id.as_bytes());
+		prefix.push(0xFF);
+
+		let mut last = prefix.clone();
+		last.extend_from_slice(&until.to_be_bytes());
+
+		self.db
+			.todeviceid_events
+			.rev_raw_keys_from(&last) // this includes last
+			.ignore_err()
+			.ready_take_while(move |key| key.starts_with(&prefix))
+			.map(|key| {
+				let len = key.len();
+				let start = len.saturating_sub(size_of::<u64>());
+				let count = utils::u64_from_u8(&key[start..len]);
+				(key, count)
+			})
+			.ready_take_while(move |(_, count)| *count <= until)
+			.ready_for_each(|(key, _)| self.db.todeviceid_events.remove(&key))
+			.boxed()
+			.await;
+	}
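// [Editor's note] Hedged sketch, not part of the patch: `remove_to_device_events`
// above reads the last 8 bytes of each key as a big-endian u64 stream counter;
// this standalone helper shows that extraction in isolation.
use std::mem::size_of;

fn trailing_count(key: &[u8]) -> u64 {
	let start = key.len().saturating_sub(size_of::<u64>());
	let mut buf = [0_u8; 8];
	buf.copy_from_slice(&key[start..]); // panics if key is shorter than 8 bytes
	u64::from_be_bytes(buf)
}

fn main() {
	let mut key = b"@alice:example.org\xff_device\xff".to_vec();
	key.extend_from_slice(&42_u64.to_be_bytes());
	assert_eq!(trailing_count(&key), 42);
}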
+	pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
+		// Only existing devices should be able to call this, but we shouldn't assert
+		// either...
+		let key = (user_id, device_id);
+		if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
+			return Err!(Database(error!(
+				?user_id,
+				?device_id,
+				"Called update_device_metadata for a non-existent user and/or device"
+			)));
+		}
+
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+
+		let mut userdeviceid = user_id.as_bytes().to_vec();
+		userdeviceid.push(0xFF);
+		userdeviceid.extend_from_slice(device_id.as_bytes());
+		self.db.userdeviceid_metadata.insert(
+			&userdeviceid,
+			&serde_json::to_vec(device).expect("Device::to_string always works"),
+		);
+
+		Ok(())
+	}
+
+	/// Get device metadata.
+	pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Device> {
+		self.db
+			.userdeviceid_metadata
+			.qry(&(user_id, device_id))
+			.await
+			.deserialized_json()
+	}
+
+	pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
+		self.db
+			.userid_devicelistversion
+			.qry(user_id)
+			.await
+			.deserialized()
+	}
+
+	pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = Device> + Send + 'a {
+		self.db
+			.userdeviceid_metadata
+			.stream_raw_prefix(&(user_id, Interfix))
+			.ready_and_then(|(_, val)| serde_json::from_slice::<Device>(val).map_err(Into::into))
+			.ignore_err()
+	}
+
+	/// Creates a new sync filter. Returns the filter id.
+	pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
+		let filter_id = utils::random_string(4);
+
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(filter_id.as_bytes());
+
+		self.db
+			.userfilterid_filter
+			.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"));
+
+		filter_id
+	}
+
+	pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
+		self.db
+			.userfilterid_filter
+			.qry(&(user_id, filter_id))
+			.await
+			.deserialized_json()
 	}
 
 	/// Creates an OpenID token, which can be used to prove that a user has
 	/// access to an account (primarily for integrations)
 	pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
-		self.db.create_openid_token(user_id, token)
+		use std::num::Saturating as Sat;
+
+		let expires_in = self.services.server.config.openid_token_ttl;
+		let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
+
+		let mut value = expires_at.0.to_be_bytes().to_vec();
+		value.extend_from_slice(user_id.as_bytes());
+
+		self.db
+			.openidtoken_expiresatuserid
+			.insert(token.as_bytes(), value.as_slice());
+
+		Ok(expires_in)
 	}
 
 	/// Find out which user an OpenID access token belongs to.
-	pub fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> { self.db.find_from_openid_token(token) }
+	pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
+		let Ok(value) = self.db.openidtoken_expiresatuserid.qry(token).await else {
+			return Err!(Request(Unauthorized("OpenID token is unrecognised")));
+		};
+
+		let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
+		let expires_at = u64::from_be_bytes(
+			expires_at_bytes
+				.try_into()
+				.map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?,
+		);
+
+		if expires_at < utils::millis_since_unix_epoch() {
+			debug_warn!("OpenID token is expired, removing");
+			self.db.openidtoken_expiresatuserid.remove(token.as_bytes());
+
+			return Err!(Request(Unauthorized("OpenID token is expired")));
+		}
+
+		let user_string = utils::string_from_bytes(user_bytes)
+			.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
+
+		UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
+	}
+
+	/// Gets a specific user profile key
+	pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<serde_json::Value> {
+		let key = (user_id, profile_key);
+		self.db
+			.useridprofilekey_value
+			.qry(&key)
+			.await
+			.deserialized()
+	}
+
+	/// Gets all the user's profile keys and values in an iterator
+	pub fn all_profile_keys<'a>(
+		&'a self, user_id: &'a UserId,
+	) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
+		type KeyVal = ((Ignore, String), serde_json::Value);
+
+		let prefix = (user_id, Interfix);
+		self.db
+			.useridprofilekey_value
+			.stream_prefix(&prefix)
+			.ignore_err()
+			.map(|((_, key), val): KeyVal| (key, val))
+	}
+
+	/// Sets a new profile key value, removes the key if value is None
+	pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(profile_key.as_bytes());
+
+		// TODO: insert to the stable MSC4175 key when it's stable
+		if let Some(value) = profile_key_value {
+			let value = serde_json::to_vec(&value).unwrap();
+
+			self.db.useridprofilekey_value.insert(&key, &value);
+		} else {
+			self.db.useridprofilekey_value.remove(&key);
+		}
+	}
+
+	/// Get the timezone of a user.
+	pub async fn timezone(&self, user_id: &UserId) -> Result<String> {
+		// TODO: transparently migrate unstable key usage to the stable key once MSC4133
+		// and MSC4175 are stable, likely a remove/insert in this block.
+
+		// first check the unstable prefix then check the stable prefix
+		let unstable_key = (user_id, "us.cloke.msc4175.tz");
+		let stable_key = (user_id, "m.tz");
+		self.db
+			.useridprofilekey_value
+			.qry(&unstable_key)
+			.or_else(|_| self.db.useridprofilekey_value.qry(&stable_key))
+			.await
+			.deserialized()
+	}
+	/// Sets a new timezone or removes it if timezone is None.
+	pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(b"us.cloke.msc4175.tz");
+
+		// TODO: insert to the stable MSC4175 key when it's stable
+		if let Some(timezone) = timezone {
+			self.db
+				.useridprofilekey_value
+				.insert(&key, timezone.as_bytes());
+		} else {
+			self.db.useridprofilekey_value.remove(&key);
+		}
+	}
+}
+
+pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) -> Result<(Vec<u8>, CrossSigningKey)> {
+	let mut prefix = user_id.as_bytes().to_vec();
+	prefix.push(0xFF);
+
+	let master_key = master_key
+		.deserialize()
+		.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
+	let mut master_key_ids = master_key.keys.values();
+	let master_key_id = master_key_ids
+		.next()
+		.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
+	if master_key_ids.next().is_some() {
+		return Err(Error::BadRequest(
+			ErrorKind::InvalidParam,
+			"Master key contained more than one key.",
+		));
+	}
+	let mut master_key_key = prefix.clone();
+	master_key_key.extend_from_slice(master_key_id.as_bytes());
+	Ok((master_key_key, master_key))
 }
 
 /// Ensure that a user only sees signatures from themselves and the target user
-pub fn clean_signatures<F: Fn(&UserId) -> bool>(
-	cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F,
-) -> Result<(), Error> {
+fn clean_signatures<F>(
+	mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+) -> Result<serde_json::Value>
+where
+	F: Fn(&UserId) -> bool + Send + Sync,
+{
 	if let Some(signatures) = cross_signing_key
 		.get_mut("signatures")
 		.and_then(|v| v.as_object_mut())
@@ -563,5 +995,12 @@ pub fn clean_signatures<F: Fn(&UserId) -> bool>(
 		}
 	}
 
-	Ok(())
+	Ok(cross_signing_key)
+}
+
+//TODO: this is an ABA
+fn increment(db: &Arc<Map>, key: &[u8]) {
+	let old = db.get(key);
+	let new = utils::increment(old.ok().as_deref());
+	db.insert(key, &new);
 }
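// [Editor's note] Closing sketch, not part of the patch: the `increment` above
// is a non-atomic read-modify-write, which is what the ABA TODO refers to; two
// concurrent callers can read the same old value and both write old+1. A
// serialized in-memory equivalent is shown below (a real fix would push the
// increment into the storage engine, e.g. a RocksDB merge operator).
use std::collections::HashMap;
use std::sync::Mutex;

fn increment_atomic(db: &Mutex<HashMap<Vec<u8>, [u8; 8]>>, key: &[u8]) -> u64 {
	// Holding the lock across read + write serializes concurrent increments.
	let mut db = db.lock().expect("not poisoned");
	let old = db.get(key).map_or(0, |v| u64::from_be_bytes(*v));
	let new = old.saturating_add(1);
	db.insert(key.to_vec(), new.to_be_bytes());
	new
}

fn main() {
	let db = Mutex::new(HashMap::new());
	assert_eq!(increment_atomic(&db, b"version"), 1);
	assert_eq!(increment_atomic(&db, b"version"), 2);
}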