From 33172a70e6683248feae7a79398c1391d58ef2a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Thu, 26 Aug 2021 18:59:10 +0200
Subject: [PATCH] fix: improve key fetching

---
 src/client_server/keys.rs | 59 +++++++++++++++++++++++++--------------
 1 file changed, 38 insertions(+), 21 deletions(-)

diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs
index 8db7688d..f9895f91 100644
--- a/src/client_server/keys.rs
+++ b/src/client_server/keys.rs
@@ -1,5 +1,6 @@
 use super::SESSION_ID_LENGTH;
 use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma};
+use rocket::futures::{prelude::*, stream::FuturesUnordered};
 use ruma::{
     api::{
         client::{
@@ -18,7 +19,10 @@ use ruma::{
     DeviceId, DeviceKeyAlgorithm, UserId,
 };
 use serde_json::json;
-use std::collections::{BTreeMap, HashSet};
+use std::{
+    collections::{BTreeMap, HashMap, HashSet},
+    time::{Duration, Instant},
+};
 
 #[cfg(feature = "conduit_bin")]
 use rocket::{get, post};
@@ -294,7 +298,7 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
     let mut user_signing_keys = BTreeMap::new();
     let mut device_keys = BTreeMap::new();
 
-    let mut get_over_federation = BTreeMap::new();
+    let mut get_over_federation = HashMap::new();
 
     for (user_id, device_ids) in device_keys_input {
         if user_id.server_name() != db.globals.server_name() {
@@ -364,22 +368,30 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
 
     let mut failures = BTreeMap::new();
 
-    for (server, vec) in get_over_federation {
-        let mut device_keys_input_fed = BTreeMap::new();
-        for (user_id, keys) in vec {
-            device_keys_input_fed.insert(user_id.clone(), keys.clone());
-        }
-        match db
-            .sending
-            .send_federation_request(
-                &db.globals,
+    let mut futures = get_over_federation
+        .into_iter()
+        .map(|(server, vec)| async move {
+            let mut device_keys_input_fed = BTreeMap::new();
+            for (user_id, keys) in vec {
+                device_keys_input_fed.insert(user_id.clone(), keys.clone());
+            }
+            (
                 server,
-                federation::keys::get_keys::v1::Request {
-                    device_keys: device_keys_input_fed,
-                },
+                db.sending
+                    .send_federation_request(
+                        &db.globals,
+                        server,
+                        federation::keys::get_keys::v1::Request {
+                            device_keys: device_keys_input_fed,
+                        },
+                    )
+                    .await,
             )
-            .await
-        {
+        })
+        .collect::<FuturesUnordered<_>>();
+
+    while let Some((server, response)) = futures.next().await {
+        match response {
             Ok(response) => {
                 master_keys.extend(response.master_keys);
                 self_signing_keys.extend(response.self_signing_keys);
@@ -430,13 +442,15 @@ pub async fn claim_keys_helper(
         one_time_keys.insert(user_id.clone(), container);
     }
 
+    let mut failures = BTreeMap::new();
+
     for (server, vec) in get_over_federation {
         let mut one_time_keys_input_fed = BTreeMap::new();
         for (user_id, keys) in vec {
             one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
         }
         // Ignore failures
-        let keys = db
+        if let Ok(keys) = db
             .sending
             .send_federation_request(
                 &db.globals,
@@ -445,13 +459,16 @@ pub async fn claim_keys_helper(
                     one_time_keys: one_time_keys_input_fed,
                 },
             )
-            .await?;
-
-        one_time_keys.extend(keys.one_time_keys);
+            .await
+        {
+            one_time_keys.extend(keys.one_time_keys);
+        } else {
+            failures.insert(server.to_string(), json!({}));
+        }
     }
 
     Ok(claim_keys::Response {
-        failures: BTreeMap::new(),
+        failures,
         one_time_keys,
     })
 }