diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 4a1f3743..1a2fe73d 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1033,6 +1033,11 @@ async fn join_room_by_id_helper( drop(state_lock); let pub_key_map = RwLock::new(BTreeMap::new()); + services() + .rooms + .event_handler + .fetch_required_signing_keys([&signed_value], &pub_key_map) + .await?; services() .rooms .event_handler @@ -1259,6 +1264,12 @@ pub(crate) async fn invite_helper<'a>( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + services() + .rooms + .event_handler + .fetch_required_signing_keys([&value], &pub_key_map) + .await?; + let pdu_id: Vec = services() .rooms .event_handler diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index c8ef9fff..42eb6aa0 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -220,7 +220,7 @@ where let keys_result = services() .rooms .event_handler - .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_owned()]) + .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.to_owned()]) .await; let keys = match keys_result { diff --git a/src/api/server_server.rs b/src/api/server_server.rs index b8c42d80..357cafad 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -764,6 +764,7 @@ pub async fn send_transaction_message_route( // events that it references. // let mut auth_cache = EventMap::new(); + let mut parsed_pdus = vec![]; for pdu in &body.pdus { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { warn!("Error parsing incoming event {:?}: {:?}", pdu, e); @@ -791,8 +792,28 @@ pub async fn send_transaction_message_route( continue; } }; + parsed_pdus.push((event_id, value, room_id)); // We do not add the event_id field to the pdu here because of signature and hashes checks + } + // We go through all the signatures we see on the PDUs and fetch the corresponding + // signing keys + services() + .rooms + .event_handler + .fetch_required_signing_keys( + parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), + &pub_key_map, + ) + .await + .unwrap_or_else(|e| { + warn!( + "Could not fetch all signatures for PDUs from {}: {:?}", + sender_servername, e + ) + }); + + for (event_id, value, room_id) in parsed_pdus { let mutex = Arc::clone( services() .globals diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index b22f8ed4..1e9529d8 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -803,7 +803,7 @@ impl Service { services() .rooms .event_handler - .fetch_required_signing_keys(&value, &pub_key_map) + .fetch_required_signing_keys([&value], &pub_key_map) .await?; let pub_key_map = pub_key_map.read().unwrap(); diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index e26dfb01..db7158c4 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -286,16 +286,11 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>> { Box::pin(async move { - // 1.1. Remove unsigned field + // 1. Remove unsigned field value.remove("unsigned"); // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - self.fetch_required_signing_keys(&value, pub_key_map) - .await?; - // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match let create_event_content: RoomCreateEventContent = @@ -1032,14 +1027,14 @@ impl Service { hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), }; - let mut pdus = vec![]; + let mut events_with_auth_events = vec![]; for id in events { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { trace!("Found {} in db", id); - pdus.push((local_pdu, None)); + events_with_auth_events.push((id, Some(local_pdu), vec![])); continue; } @@ -1138,7 +1133,36 @@ impl Service { } } } + events_with_auth_events.push((id, None, events_in_reverse_order)) + } + // We go through all the signatures we see on the PDUs and their unresolved + // dependencies and fetch the corresponding signing keys + info!("fetch_required_signing_keys for {}", origin); + self.fetch_required_signing_keys( + events_with_auth_events + .iter() + .flat_map(|(_id, _local_pdu, events)| events) + .map(|(_event_id, event)| event), + pub_key_map, + ) + .await + .unwrap_or_else(|e| { + warn!( + "Could not fetch all signatures for PDUs from {}: {:?}", + origin, e + ) + }); + + let mut pdus = vec![]; + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {} in db", id); + pdus.push((local_pdu, None)); + } for (next_id, value) in events_in_reverse_order.iter().rev() { if let Some((time, tries)) = services() .globals @@ -1289,53 +1313,94 @@ impl Service { } #[tracing::instrument(skip_all)] - pub(crate) async fn fetch_required_signing_keys( - &self, - event: &BTreeMap, + pub(crate) async fn fetch_required_signing_keys<'a, E>( + &'a self, + events: E, pub_key_map: &RwLock>>, - ) -> Result<()> { - let signatures = event - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; + ) -> Result<()> + where + E: IntoIterator>, + { + let mut server_key_ids = HashMap::new(); - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; + for event in events.into_iter() { + for (signature_server, signature) in event + .get("signatures") + .ok_or(Error::BadServerResponse( + "No signatures in server response pdu.", + ))? + .as_object() + .ok_or(Error::BadServerResponse( + "Invalid signatures object in server response pdu.", + ))? + { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; - let signature_ids = signature_object.keys().cloned().collect::>(); - - let fetch_res = self - .fetch_signing_keys( - signature_server.as_str().try_into().map_err(|_| { - Error::BadServerResponse( - "Invalid servername in signatures of server response pdu.", - ) - })?, - signature_ids, - ) - .await; - - let keys = match fetch_res { - Ok(keys) => keys, - Err(_) => { - warn!("Signature verification failed: Could not fetch signing key.",); - continue; + for signature_id in signature_object.keys() { + server_key_ids + .entry(signature_server.clone()) + .or_insert_with(HashSet::new) + .insert(signature_id.clone()); } - }; + } + } - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(signature_server.clone(), keys); + if server_key_ids.is_empty() { + // Nothing to do, can exit early + return Ok(()); + } + + info!( + "Fetch keys for {}", + server_key_ids + .keys() + .cloned() + .collect::>() + .join(", ") + ); + + let mut server_keys: FuturesUnordered<_> = server_key_ids + .into_iter() + .map(|(signature_server, signature_ids)| async { + let signature_server2 = signature_server.clone(); + let fetch_res = self + .fetch_signing_keys_for_server( + signature_server2.as_str().try_into().map_err(|_| { + ( + signature_server.clone(), + Error::BadServerResponse( + "Invalid servername in signatures of server response pdu.", + ), + ) + })?, + signature_ids.into_iter().collect(), // HashSet to Vec + ) + .await; + + match fetch_res { + Ok(keys) => Ok((signature_server, keys)), + Err(e) => { + warn!("Signature verification failed: Could not fetch signing key.",); + Err((signature_server, e)) + } + } + }) + .collect(); + + while let Some(fetch_res) = server_keys.next().await { + match fetch_res { + Ok((signature_server, keys)) => { + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(signature_server.clone(), keys); + } + Err((signature_server, e)) => { + warn!("Failed to fetch keys for {}: {:?}", signature_server, e); + } + } } Ok(()) @@ -1594,7 +1659,7 @@ impl Service { /// Search the DB for the signing keys of the given server, if we don't have them /// fetch them from the server and save to our DB. #[tracing::instrument(skip_all)] - pub async fn fetch_signing_keys( + pub async fn fetch_signing_keys_for_server( &self, origin: &ServerName, signature_ids: Vec, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 25e1c54d..a6774ec6 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1139,6 +1139,12 @@ impl Service { return Ok(()); } + services() + .rooms + .event_handler + .fetch_required_signing_keys([&value], &pub_key_map) + .await?; + services() .rooms .event_handler