This commit is contained in:
Timo Kösters 2022-04-03 11:48:25 +02:00
parent 0066f20bdd
commit d4ccfa16dc
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
5 changed files with 129 additions and 57 deletions

22
Cargo.lock generated
View file

@ -96,6 +96,27 @@ dependencies = [
"tokio",
]
[[package]]
name = "async-stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.52"
@ -389,6 +410,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
name = "conduit"
version = "0.3.0-next"
dependencies = [
"async-stream",
"axum",
"axum-server",
"base64 0.13.0",

View file

@ -27,6 +27,7 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "fa2e3662a456bd8957b3e1293c
# Async runtime and utilities
tokio = { version = "1.11.0", features = ["fs", "macros", "signal", "sync"] }
async-stream = "0.3.2"
# Used for storing data permanently
sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true }
#sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] }
@ -76,7 +77,7 @@ crossbeam = { version = "0.8.1", optional = true }
num_cpus = "1.13.0"
threadpool = "1.8.1"
heed = { git = "https://github.com/timokoesters/heed.git", rev = "f6f825da7fb2c758867e05ad973ef800a6fe1d5d", optional = true }
rocksdb = { version = "0.17.0", default-features = false, features = ["multi-threaded-cf", "zstd"], optional = true }
rocksdb = { version = "0.17.0", default-features = true, features = ["multi-threaded-cf", "zstd"], optional = true }
thread_local = "1.1.3"
# used for TURN server authentication

View file

@ -36,8 +36,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O
db_opts.set_level_compaction_dynamic_level_bytes(true);
db_opts.set_target_file_size_base(256 * 1024 * 1024);
//db_opts.set_compaction_readahead_size(2 * 1024 * 1024);
//db_opts.set_use_direct_reads(true);
//db_opts.set_use_direct_io_for_flush_and_compaction(true);
db_opts.set_use_direct_reads(true);
db_opts.set_use_direct_io_for_flush_and_compaction(true);
db_opts.create_if_missing(true);
db_opts.increase_parallelism(num_cpus::get() as i32);
db_opts.set_max_open_files(max_open_files);

View file

@ -1,6 +1,7 @@
mod edus;
pub use edus::RoomEdus;
use futures_util::Stream;
use crate::{
pdu::{EventHash, PduBuilder},
@ -39,6 +40,7 @@ use std::{
sync::{Arc, Mutex, RwLock},
};
use tokio::sync::MutexGuard;
use async_stream::try_stream;
use tracing::{error, warn};
use super::{abstraction::Tree, pusher};
@ -1083,6 +1085,38 @@ impl Rooms {
.transpose()
}
pub async fn get_pdu_async(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
return Ok(Some(Arc::clone(p)));
}
let eventid_pduid = Arc::clone(&self.eventid_pduid);
let event_id_bytes = event_id.as_bytes().to_vec();
if let Some(pdu) = tokio::task::spawn_blocking(move || { eventid_pduid .get(&event_id_bytes)}).await.unwrap()?
.map_or_else(
|| self.eventid_outlierpdu.get(event_id.as_bytes()),
|pduid| {
Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
Error::bad_database("Invalid pduid in eventid_pduid.")
})?))
},
)?
.map(|pdu| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
.map(Arc::new)
})
.transpose()?
{
self.pdu_cache
.lock()
.unwrap()
.insert(event_id.to_owned(), Arc::clone(&pdu));
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
@ -2109,7 +2143,7 @@ impl Rooms {
user_id: &UserId,
room_id: &RoomId,
from: u64,
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
) -> Result<impl Stream<Item = Result<(Vec<u8>, PduEvent)>>> {
// Create the first part of the full pdu id
let prefix = self
.get_shortroomid(room_id)?
@ -2124,18 +2158,23 @@ impl Rooms {
let user_id = user_id.to_owned();
Ok(self
let iter = self
.pduid_pdu
.iter_from(current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
.iter_from(current, false);
Ok(try_stream! {
while let Some((k, v)) = tokio::task::spawn_blocking(|| { iter.next() }).await.unwrap() {
if !k.starts_with(&prefix) {
return;
}
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((pdu_id, pdu))
}))
yield (k, pdu)
}
})
}
/// Replace a PDU with the redacted form.

View file

@ -49,6 +49,7 @@ use ruma::{
},
int,
receipt::ReceiptType,
room_id,
serde::{Base64, JsonObject, Raw},
signatures::{CanonicalJsonObject, CanonicalJsonValue},
state_res::{self, RoomVersion, StateMap},
@ -681,7 +682,7 @@ pub async fn send_transaction_message_route(
.roomid_mutex_federation
.write()
.unwrap()
.entry(room_id.clone())
.entry(room_id!("!somewhere:example.org").to_owned()) // only allow one room at a time
.or_default(),
);
let mutex_lock = mutex.lock().await;
@ -1141,7 +1142,7 @@ fn handle_outlier_pdu<'a>(
// Build map of auth events
let mut auth_events = HashMap::new();
for id in &incoming_pdu.auth_events {
let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? {
let auth_event = match db.rooms.get_pdu_async(id).await.map_err(|e| e.to_string())? {
Some(e) => e,
None => {
warn!("Could not find auth event {}", id);
@ -1182,7 +1183,7 @@ fn handle_outlier_pdu<'a>(
&& incoming_pdu.prev_events == incoming_pdu.auth_events
{
db.rooms
.get_pdu(&incoming_pdu.auth_events[0])
.get_pdu_async(&incoming_pdu.auth_events[0]).await
.map_err(|e| e.to_string())?
.filter(|maybe_create| **maybe_create == *create_event)
} else {
@ -1265,10 +1266,13 @@ async fn upgrade_outlier_to_timeline_pdu(
if let Some(Ok(mut state)) = state {
warn!("Using cached state");
let prev_pdu =
db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| {
"Could not find prev event, but we know the state.".to_owned()
})?;
let prev_pdu = db
.rooms
.get_pdu_async(prev_event)
.await
.ok()
.flatten()
.ok_or_else(|| "Could not find prev event, but we know the state.".to_owned())?;
if let Some(state_key) = &prev_pdu.state_key {
let shortstatekey = db
@ -1288,7 +1292,7 @@ async fn upgrade_outlier_to_timeline_pdu(
let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events {
let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) {
let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu_async(prev_eventid).await {
pdu
} else {
okay = false;
@ -1337,7 +1341,7 @@ async fn upgrade_outlier_to_timeline_pdu(
}
auth_chain_sets.push(
get_auth_chain(room_id, starting_events, db)
get_auth_chain(room_id, starting_events, db).await
.map_err(|_| "Failed to load auth chain.".to_owned())?
.collect(),
);
@ -1350,7 +1354,7 @@ async fn upgrade_outlier_to_timeline_pdu(
&fork_states,
auth_chain_sets,
|id| {
let res = db.rooms.get_pdu(id);
let res = db.rooms.get_pdu_async(id).await;
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
@ -1462,28 +1466,33 @@ async fn upgrade_outlier_to_timeline_pdu(
&& incoming_pdu.prev_events == incoming_pdu.auth_events
{
db.rooms
.get_pdu(&incoming_pdu.auth_events[0])
.get_pdu_async(&incoming_pdu.auth_events[0])
.await
.map_err(|e| e.to_string())?
.filter(|maybe_create| **maybe_create == *create_event)
} else {
None
};
let check_result = state_res::event_auth::auth_check(
&room_version,
&incoming_pdu,
previous_create.as_ref(),
None::<PduEvent>, // TODO: third party invite
|k, s| {
db.rooms
.get_shortstatekey(k, s)
.ok()
.flatten()
.and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
.and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten())
},
)
.map_err(|_e| "Auth check failed.".to_owned())?;
let check_result = tokio::task::spawn_blocking(move || {
state_res::event_auth::auth_check(
&room_version,
&incoming_pdu,
previous_create.as_ref(),
None::<PduEvent>, // TODO: third party invite
|k, s| {
db.rooms
.get_shortstatekey(k, s)
.ok()
.flatten()
.and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
.and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten())
},
)
.map_err(|_e| "Auth check failed.".to_owned())
})
.await
.unwrap()?;
if !check_result {
return Err("Event has failed auth check with state at the event.".into());
@ -1591,7 +1600,8 @@ async fn upgrade_outlier_to_timeline_pdu(
for id in dbg!(&extremities) {
match db
.rooms
.get_pdu(id)
.get_pdu_async(id)
.await
.map_err(|_| "Failed to ask db for pdu.".to_owned())?
{
Some(leaf_pdu) => {
@ -1664,7 +1674,7 @@ async fn upgrade_outlier_to_timeline_pdu(
room_id,
state.iter().map(|(_, id)| id.clone()).collect(),
db,
)
).await
.map_err(|_| "Failed to load auth chain.".to_owned())?
.collect(),
);
@ -1685,20 +1695,20 @@ async fn upgrade_outlier_to_timeline_pdu(
})
.collect();
let state = match state_res::resolve(
room_version_id,
&fork_states,
auth_chain_sets,
|id| {
let state = match tokio::task::spawn_blocking(move || {
state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = db.rooms.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
res.ok().flatten()
},
) {
Ok(new_state) => new_state,
Err(_) => {
}).ok()
})
.await
.unwrap()
{
Some(new_state) => new_state,
None => {
return Err("State resolution failed, either an event could not be found or deserialization".into());
}
};
@ -1798,7 +1808,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>(
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) {
if let Ok(Some(local_pdu)) = db.rooms.get_pdu_async(id).await {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
@ -1815,7 +1825,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>(
continue;
}
if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) {
if let Ok(Some(_)) = db.rooms.get_pdu_async(&next_id).await {
trace!("Found {} in db", id);
continue;
}
@ -2153,7 +2163,7 @@ fn append_incoming_pdu<'a>(
}
#[tracing::instrument(skip(starting_events, db))]
pub(crate) fn get_auth_chain<'a>(
pub(crate) async fn get_auth_chain<'a>(
room_id: &RoomId,
starting_events: Vec<Arc<EventId>>,
db: &'a Database,
@ -2194,7 +2204,7 @@ pub(crate) fn get_auth_chain<'a>(
chunk_cache.extend(cached.iter().copied());
} else {
misses2 += 1;
let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?);
let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db).await?);
db.rooms
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
println!(
@ -2230,7 +2240,7 @@ pub(crate) fn get_auth_chain<'a>(
}
#[tracing::instrument(skip(event_id, db))]
fn get_auth_chain_inner(
async fn get_auth_chain_inner(
room_id: &RoomId,
event_id: &EventId,
db: &Database,
@ -2239,7 +2249,7 @@ fn get_auth_chain_inner(
let mut found = HashSet::new();
while let Some(event_id) = todo.pop() {
match db.rooms.get_pdu(&event_id) {
match db.rooms.get_pdu_async(&event_id).await {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
@ -2423,7 +2433,7 @@ pub async fn get_event_authorization_route(
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db)?;
let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?;
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
@ -2477,7 +2487,7 @@ pub async fn get_room_state_route(
})
.collect();
let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?;
let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?;
Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids
@ -2532,7 +2542,7 @@ pub async fn get_room_state_ids_route(
.map(|(_, id)| (*id).to_owned())
.collect();
let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?;
let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?;
Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),
@ -2795,7 +2805,7 @@ async fn create_join_event(
room_id,
state_ids.iter().map(|(_, id)| id.clone()).collect(),
db,
)?;
).await?;
let servers = db
.rooms