backoff to valhalla

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-04-24 13:21:36 -07:00 committed by June
parent 255bcf5243
commit 67f9553790

View file

@ -1,6 +1,6 @@
use std::{
collections::{hash_map, HashSet},
time::{Duration, Instant, SystemTime},
collections::HashSet,
time::{Duration, SystemTime},
};
use futures_util::{stream::FuturesUnordered, StreamExt};
@ -14,15 +14,15 @@ use ruma::{
membership::create_join_event,
},
serde::Base64,
CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedServerName,
OwnedServerSigningKeyId, RoomVersionId, ServerName,
CanonicalJsonObject, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId,
RoomVersionId, ServerName,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tokio::sync::{RwLock, RwLockWriteGuard};
use tracing::{debug, error, info, trace, warn};
use crate::{
service::{Arc, BTreeMap, HashMap, Result},
service::{BTreeMap, HashMap, Result},
services, Error,
};
@ -117,7 +117,7 @@ impl super::Service {
async fn get_server_keys_from_cache(
&self, pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
room_version: &RoomVersionId,
_room_version: &RoomVersionId,
pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
@ -125,31 +125,6 @@ impl super::Service {
Error::BadServerResponse("Invalid PDU in server response")
})?;
let event_id = format!(
"${}",
ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes")
);
let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(event_id)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {}", event_id);
return Err(Error::BadServerResponse("bad event, still backing off"));
}
}
let signatures = value
.get("signatures")
.ok_or(Error::BadServerResponse("No signatures in server response pdu."))?
@ -372,62 +347,6 @@ impl super::Service {
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.await
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = if let Some(p) = permit {
p
} else {
let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned()
}
.await;
let back_off = |id| async {
match services()
.globals
.bad_signature_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.await
.get(&signature_ids)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {:?}", signature_ids);
return Err(Error::BadServerResponse("bad signature, still backing off"));
}
}
let mut result: BTreeMap<_, _> = services()
.globals
.signing_keys_for(origin)?
@ -600,10 +519,6 @@ impl super::Service {
}
}
drop(permit);
back_off(signature_ids).await;
warn!("Failed to find public key for server: {origin}");
Err(Error::BadServerResponse("Failed to find public key for server"))
}