From b5f27d9ec2b7a1ed23f9f67ea62f8e7567be4577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 26 Aug 2021 18:59:10 +0200 Subject: [PATCH] fix: improve key fetching --- src/client_server/keys.rs | 59 +++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 8db7688d..f9895f91 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,5 +1,6 @@ use super::SESSION_ID_LENGTH; use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; +use rocket::futures::{prelude::*, stream::FuturesUnordered}; use ruma::{ api::{ client::{ @@ -18,7 +19,10 @@ use ruma::{ DeviceId, DeviceKeyAlgorithm, UserId, }; use serde_json::json; -use std::collections::{BTreeMap, HashSet}; +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + time::{Duration, Instant}, +}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -294,7 +298,7 @@ pub async fn get_keys_helper bool>( let mut user_signing_keys = BTreeMap::new(); let mut device_keys = BTreeMap::new(); - let mut get_over_federation = BTreeMap::new(); + let mut get_over_federation = HashMap::new(); for (user_id, device_ids) in device_keys_input { if user_id.server_name() != db.globals.server_name() { @@ -364,22 +368,30 @@ pub async fn get_keys_helper bool>( let mut failures = BTreeMap::new(); - for (server, vec) in get_over_federation { - let mut device_keys_input_fed = BTreeMap::new(); - for (user_id, keys) in vec { - device_keys_input_fed.insert(user_id.clone(), keys.clone()); - } - match db - .sending - .send_federation_request( - &db.globals, + let mut futures = get_over_federation + .into_iter() + .map(|(server, vec)| async move { + let mut device_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in vec { + device_keys_input_fed.insert(user_id.clone(), keys.clone()); + } + ( server, - federation::keys::get_keys::v1::Request { - device_keys: device_keys_input_fed, - }, + db.sending + .send_federation_request( + &db.globals, + server, + federation::keys::get_keys::v1::Request { + device_keys: device_keys_input_fed, + }, + ) + .await, ) - .await - { + }) + .collect::>(); + + while let Some((server, response)) = futures.next().await { + match response { Ok(response) => { master_keys.extend(response.master_keys); self_signing_keys.extend(response.self_signing_keys); @@ -430,13 +442,15 @@ pub async fn claim_keys_helper( one_time_keys.insert(user_id.clone(), container); } + let mut failures = BTreeMap::new(); + for (server, vec) in get_over_federation { let mut one_time_keys_input_fed = BTreeMap::new(); for (user_id, keys) in vec { one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); } // Ignore failures - let keys = db + if let Ok(keys) = db .sending .send_federation_request( &db.globals, @@ -445,13 +459,16 @@ pub async fn claim_keys_helper( one_time_keys: one_time_keys_input_fed, }, ) - .await?; - - one_time_keys.extend(keys.one_time_keys); + .await + { + one_time_keys.extend(keys.one_time_keys); + } else { + failures.insert(server.to_string(), json!({})); + } } Ok(claim_keys::Response { - failures: BTreeMap::new(), + failures, one_time_keys, }) }