Merge branch 'backfill' into 'next'

feat: handle backfill requests

See merge request famedly/conduit!459
Timo Kösters 2023-03-16 13:32:42 +00:00
commit 664ee7d89a
22 changed files with 1638 additions and 1073 deletions

View file

@@ -129,7 +129,7 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
auth_error: None,
};
-if !body.from_appservice {
+if !body.from_appservice && !is_guest {
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(
&UserId::parse_with_server_name("", services().globals.server_name())

View file

@@ -27,36 +27,35 @@ pub async fn get_context_route(
let mut lazy_loaded = HashSet::new();
-let base_pdu_id = services()
+let base_token = services()
.rooms
.timeline
-.get_pdu_id(&body.event_id)?
+.get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event id not found.",
))?;
-let base_token = services().rooms.timeline.pdu_count(&base_pdu_id)?;
-let base_event = services()
-.rooms
-.timeline
-.get_pdu_from_id(&base_pdu_id)?
-.ok_or(Error::BadRequest(
-ErrorKind::NotFound,
-"Base event not found.",
-))?;
+let base_event =
+services()
+.rooms
+.timeline
+.get_pdu(&body.event_id)?
+.ok_or(Error::BadRequest(
+ErrorKind::NotFound,
+"Base event not found.",
+))?;
let room_id = base_event.room_id.clone();
if !services()
.rooms
-.state_cache
-.is_joined(sender_user, &room_id)?
+.state_accessor
+.user_can_see_event(sender_user, &room_id, &body.event_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
-"You don't have permission to view this room.",
+"You don't have permission to view this event.",
));
}
@@ -83,6 +82,13 @@ pub async fn get_context_route(
/ 2,
)
.filter_map(|r| r.ok()) // Remove buggy events
+.filter(|(_, pdu)| {
+services()
+.rooms
+.state_accessor
+.user_can_see_event(sender_user, &room_id, &pdu.event_id)
+.unwrap_or(false)
+})
.collect();
for (_, event) in &events_before {
@@ -97,10 +103,7 @@ pub async fn get_context_route(
}
}
-let start_token = events_before
-.last()
-.and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok())
-.map(|count| count.to_string());
+let start_token = events_before.last().map(|(count, _)| count.stringify());
let events_before: Vec<_> = events_before
.into_iter()
@@ -118,6 +121,13 @@ pub async fn get_context_route(
/ 2,
)
.filter_map(|r| r.ok()) // Remove buggy events
+.filter(|(_, pdu)| {
+services()
+.rooms
+.state_accessor
+.user_can_see_event(sender_user, &room_id, &pdu.event_id)
+.unwrap_or(false)
+})
.collect();
for (_, event) in &events_after {
@@ -151,10 +161,7 @@ pub async fn get_context_route(
.state_full_ids(shortstatehash)
.await?;
-let end_token = events_after
-.last()
-.and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok())
-.map(|count| count.to_string());
+let end_token = events_after.last().map(|(count, _)| count.stringify());
let events_after: Vec<_> = events_after
.into_iter()
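The events on either side of the base event are now filtered through user_can_see_event before they are returned. The check is fallible, and collapsing it with unwrap_or(false) fails closed: an event whose visibility cannot be determined is dropped rather than leaked. A minimal, self-contained sketch of that pattern (the closure and event type here are illustrative stand-ins, not conduit's actual types):

    fn visible_events<E>(
        events: Vec<E>,
        user_can_see: impl Fn(&E) -> Result<bool, ()>,
    ) -> Vec<E> {
        events
            .into_iter()
            // Treat "visibility unknown" the same as "not visible".
            .filter(|event| user_can_see(event).unwrap_or(false))
            .collect()
    }

    fn main() {
        let events = vec!["visible", "hidden", "check-failed"];
        let filtered = visible_events(events, |event| match *event {
            "visible" => Ok(true),
            "hidden" => Ok(false),
            _ => Err(()),
        });
        assert_eq!(filtered, vec!["visible"]);
    }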

View file

@@ -29,7 +29,7 @@ use std::{
sync::{Arc, RwLock},
time::{Duration, Instant},
};
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};
use crate::{
service::pdu::{gen_event_id_canonical_json, PduBuilder},
@@ -396,11 +396,10 @@ pub async fn get_member_events_route(
) -> Result<get_member_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-// TODO: check history visibility?
if !services()
.rooms
-.state_cache
-.is_joined(sender_user, &body.room_id)?
+.state_accessor
+.user_can_see_state_events(&sender_user, &body.room_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
@@ -434,12 +433,12 @@ pub async fn joined_members_route(
if !services()
.rooms
-.state_cache
-.is_joined(sender_user, &body.room_id)?
+.state_accessor
+.user_can_see_state_events(&sender_user, &body.room_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
-"You aren't a member of the room.",
+"You don't have permission to view this room.",
));
}
@ -491,9 +490,13 @@ async fn join_room_by_id_helper(
.state_cache .state_cache
.server_in_room(services().globals.server_name(), room_id)? .server_in_room(services().globals.server_name(), room_id)?
{ {
info!("Joining {room_id} over federation.");
let (make_join_response, remote_server) = let (make_join_response, remote_server) =
make_join_request(sender_user, room_id, servers).await?; make_join_request(sender_user, room_id, servers).await?;
info!("make_join finished");
let room_version_id = match make_join_response.room_version { let room_version_id = match make_join_response.room_version {
Some(room_version) Some(room_version)
if services() if services()
@ -578,6 +581,7 @@ async fn join_room_by_id_helper(
// It has enough fields to be called a proper event now // It has enough fields to be called a proper event now
let mut join_event = join_event_stub; let mut join_event = join_event_stub;
info!("Asking {remote_server} for send_join");
let send_join_response = services() let send_join_response = services()
.sending .sending
.send_federation_request( .send_federation_request(
@ -590,7 +594,10 @@ async fn join_room_by_id_helper(
) )
.await?; .await?;
info!("send_join finished");
if let Some(signed_raw) = &send_join_response.room_state.event { if let Some(signed_raw) = &send_join_response.room_state.event {
info!("There is a signed event. This room is probably using restricted joins");
let (signed_event_id, signed_value) = let (signed_event_id, signed_value) =
match gen_event_id_canonical_json(signed_raw, &room_version_id) { match gen_event_id_canonical_json(signed_raw, &room_version_id) {
Ok(t) => t, Ok(t) => t,
@ -630,24 +637,29 @@ async fn join_room_by_id_helper(
.expect("we created a valid pdu") .expect("we created a valid pdu")
.insert(remote_server.to_string(), signature.clone()); .insert(remote_server.to_string(), signature.clone());
} else { } else {
warn!("Server {} sent invalid sendjoin event", remote_server); warn!(
"Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}",
);
} }
} }
services().rooms.short.get_or_create_shortroomid(room_id)?; services().rooms.short.get_or_create_shortroomid(room_id)?;
info!("Parsing join event");
let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone())
.map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?;
let mut state = HashMap::new(); let mut state = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
info!("Fetching join signing keys");
services() services()
.rooms .rooms
.event_handler .event_handler
.fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map)
.await?; .await?;
info!("Going through send_join response room_state");
for result in send_join_response for result in send_join_response
.room_state .room_state
.state .state
@ -677,6 +689,7 @@ async fn join_room_by_id_helper(
} }
} }
info!("Going through send_join response auth_chain");
for result in send_join_response for result in send_join_response
.room_state .room_state
.auth_chain .auth_chain
@ -694,6 +707,7 @@ async fn join_room_by_id_helper(
.add_pdu_outlier(&event_id, &value)?; .add_pdu_outlier(&event_id, &value)?;
} }
info!("Running send_join auth check");
if !state_res::event_auth::auth_check( if !state_res::event_auth::auth_check(
&state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"),
&parsed_join_pdu, &parsed_join_pdu,
@ -714,14 +728,17 @@ async fn join_room_by_id_helper(
.ok()? .ok()?
}, },
) )
.map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? .map_err(|e| {
{ warn!("Auth check failed: {e}");
Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")
})? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Auth check failed", "Auth check failed",
)); ));
} }
info!("Saving state from send_join");
let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state(
room_id, room_id,
state state
@ -741,12 +758,14 @@ async fn join_room_by_id_helper(
.force_state(room_id, statehash_before_join, new, removed, &state_lock) .force_state(room_id, statehash_before_join, new, removed, &state_lock)
.await?; .await?;
info!("Updating joined counts for new room");
services().rooms.state_cache.update_joined_count(room_id)?; services().rooms.state_cache.update_joined_count(room_id)?;
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?;
info!("Appending new room join event");
services().rooms.timeline.append_pdu( services().rooms.timeline.append_pdu(
&parsed_join_pdu, &parsed_join_pdu,
join_event, join_event,
@ -754,6 +773,7 @@ async fn join_room_by_id_helper(
&state_lock, &state_lock,
)?; )?;
info!("Setting final room state for new room");
// We set the room state after inserting the pdu, so that we never have a moment in time // We set the room state after inserting the pdu, so that we never have a moment in time
// where events in the current room state do not exist // where events in the current room state do not exist
services() services()
@ -761,6 +781,8 @@ async fn join_room_by_id_helper(
.state .state
.set_room_state(room_id, statehash_after_join, &state_lock)?; .set_room_state(room_id, statehash_after_join, &state_lock)?;
} else { } else {
info!("We can join locally");
let join_rules_event = services().rooms.state_accessor.room_state_get( let join_rules_event = services().rooms.state_accessor.room_state_get(
room_id, room_id,
&StateEventType::RoomJoinRules, &StateEventType::RoomJoinRules,
@ -879,8 +901,9 @@ async fn join_room_by_id_helper(
}; };
if !restriction_rooms.is_empty() { if !restriction_rooms.is_empty() {
// We couldn't do the join locally, maybe federation can help to satisfy the restricted info!(
// join requirements "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"
);
let (make_join_response, remote_server) = let (make_join_response, remote_server) =
make_join_request(sender_user, room_id, servers).await?; make_join_request(sender_user, room_id, servers).await?;
@ -1038,6 +1061,7 @@ async fn make_join_request(
if remote_server == services().globals.server_name() { if remote_server == services().globals.server_name() {
continue; continue;
} }
info!("Asking {remote_server} for make_join");
let make_join_response = services() let make_join_response = services()
.sending .sending
.send_federation_request( .send_federation_request(

View file

@@ -1,4 +1,7 @@
-use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma};
+use crate::{
+service::{pdu::PduBuilder, rooms::timeline::PduCount},
+services, utils, Error, Result, Ruma,
+};
use ruma::{
api::client::{
error::ErrorKind,
@@ -110,29 +113,18 @@ pub async fn get_message_events_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
-if !services()
-.rooms
-.state_cache
-.is_joined(sender_user, &body.room_id)?
-{
-return Err(Error::BadRequest(
-ErrorKind::Forbidden,
-"You don't have permission to view this room.",
-));
-}
let from = match body.from.clone() {
-Some(from) => from
-.parse()
-.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from` value."))?,
+Some(from) => PduCount::try_from_string(&from)?,
None => match body.dir {
-ruma::api::client::Direction::Forward => 0,
-ruma::api::client::Direction::Backward => u64::MAX,
+ruma::api::client::Direction::Forward => PduCount::min(),
+ruma::api::client::Direction::Backward => PduCount::max(),
},
};
-let to = body.to.as_ref().map(|t| t.parse());
+let to = body
+.to
+.as_ref()
+.and_then(|t| PduCount::try_from_string(&t).ok());
services().rooms.lazy_loading.lazy_load_confirm_delivery(
sender_user,
@@ -158,15 +150,14 @@ pub async fn get_message_events_route(
.pdus_after(sender_user, &body.room_id, from)?
.take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
-.filter_map(|(pdu_id, pdu)| {
+.filter(|(_, pdu)| {
services()
.rooms
-.timeline
-.pdu_count(&pdu_id)
-.map(|pdu_count| (pdu_count, pdu))
-.ok()
+.state_accessor
+.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
+.unwrap_or(false)
})
-.take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to`
+.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.collect();
for (_, event) in &events_after {
@@ -192,26 +183,30 @@ pub async fn get_message_events_route(
.map(|(_, pdu)| pdu.to_room_event())
.collect();
-resp.start = from.to_string();
-resp.end = next_token.map(|count| count.to_string());
+resp.start = from.stringify();
+resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_after;
}
ruma::api::client::Direction::Backward => {
+services()
+.rooms
+.timeline
+.backfill_if_required(&body.room_id, from)
+.await?;
let events_before: Vec<_> = services()
.rooms
.timeline
.pdus_until(sender_user, &body.room_id, from)?
.take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
-.filter_map(|(pdu_id, pdu)| {
+.filter(|(_, pdu)| {
services()
.rooms
-.timeline
-.pdu_count(&pdu_id)
-.map(|pdu_count| (pdu_count, pdu))
-.ok()
+.state_accessor
+.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
+.unwrap_or(false)
})
-.take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to`
+.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.collect();
for (_, event) in &events_before {
@@ -237,8 +232,8 @@ pub async fn get_message_events_route(
.map(|(_, pdu)| pdu.to_room_event())
.collect();
-resp.start = from.to_string();
-resp.end = next_token.map(|count| count.to_string());
+resp.start = from.stringify();
+resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_before;
}
}
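The /messages pagination tokens change from plain u64 counts to PduCount values, so a token can point either into the normally appended timeline or into the backfilled (prepended) part of it. PduCount itself is defined in the timeline service, which is not part of this excerpt; the sketch below only illustrates one plausible shape for such a token type, with a textual form that round-trips through stringify / try_from_string:

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum PduCount {
        // Events fetched via backfill, counted backwards from the start of
        // the locally known timeline.
        Backfilled(u64),
        // Events appended locally, counted forwards.
        Normal(u64),
    }

    impl PduCount {
        fn min() -> Self { PduCount::Backfilled(u64::MAX) }
        fn max() -> Self { PduCount::Normal(u64::MAX) }

        // Token form handed to clients in `resp.start` / `resp.end`.
        fn stringify(&self) -> String {
            match self {
                PduCount::Backfilled(x) => format!("-{x}"),
                PduCount::Normal(x) => x.to_string(),
            }
        }

        // Parse a client-supplied `from` / `to` token.
        fn try_from_string(token: &str) -> Result<Self, ()> {
            match token.strip_prefix('-') {
                Some(rest) => rest.parse().map(PduCount::Backfilled).map_err(|_| ()),
                None => token.parse().map(PduCount::Normal).map_err(|_| ()),
            }
        }
    }

    fn main() {
        let token = PduCount::Backfilled(3);
        assert_eq!(PduCount::try_from_string(&token.stringify()), Ok(token));
        assert_eq!(PduCount::try_from_string("42"), Ok(PduCount::Normal(42)));
    }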

View file

@@ -1,4 +1,4 @@
-use crate::{services, Error, Result, Ruma};
+use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
events::{
@@ -42,18 +42,28 @@ pub async fn set_read_marker_route(
}
if let Some(event) = &body.private_read_receipt {
-services().rooms.edus.read_receipt.private_read_set(
-&body.room_id,
-sender_user,
-services()
-.rooms
-.timeline
-.get_pdu_count(event)?
-.ok_or(Error::BadRequest(
-ErrorKind::InvalidParam,
-"Event does not exist.",
-))?,
-)?;
+let count = services()
+.rooms
+.timeline
+.get_pdu_count(event)?
+.ok_or(Error::BadRequest(
+ErrorKind::InvalidParam,
+"Event does not exist.",
+))?;
+let count = match count {
+PduCount::Backfilled(_) => {
+return Err(Error::BadRequest(
+ErrorKind::InvalidParam,
+"Read receipt is in backfilled timeline",
+))
+}
+PduCount::Normal(c) => c,
+};
+services()
+.rooms
+.edus
+.read_receipt
+.private_read_set(&body.room_id, sender_user, count)?;
}
if let Some(event) = &body.read_receipt {
@@ -142,17 +152,27 @@ pub async fn create_receipt_route(
)?;
}
create_receipt::v3::ReceiptType::ReadPrivate => {
+let count = services()
+.rooms
+.timeline
+.get_pdu_count(&body.event_id)?
+.ok_or(Error::BadRequest(
+ErrorKind::InvalidParam,
+"Event does not exist.",
+))?;
+let count = match count {
+PduCount::Backfilled(_) => {
+return Err(Error::BadRequest(
+ErrorKind::InvalidParam,
+"Read receipt is in backfilled timeline",
+))
+}
+PduCount::Normal(c) => c,
+};
services().rooms.edus.read_receipt.private_read_set(
&body.room_id,
sender_user,
-services()
-.rooms
-.timeline
-.get_pdu_count(&body.event_id)?
-.ok_or(Error::BadRequest(
-ErrorKind::InvalidParam,
-"Event does not exist.",
-))?,
+count,
)?;
}
_ => return Err(Error::bad_database("Unsupported receipt type")),

View file

@@ -425,24 +425,25 @@ pub async fn get_room_event_route(
) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-if !services()
-.rooms
-.state_cache
-.is_joined(sender_user, &body.room_id)?
-{
+let event = services()
+.rooms
+.timeline
+.get_pdu(&body.event_id)?
+.ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
+if !services().rooms.state_accessor.user_can_see_event(
+sender_user,
+&event.room_id,
+&body.event_id,
+)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
-"You don't have permission to view this room.",
+"You don't have permission to view this event.",
));
}
Ok(get_room_event::v3::Response {
-event: services()
-.rooms
-.timeline
-.get_pdu(&body.event_id)?
-.ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?
-.to_room_event(),
+event: event.to_room_event(),
})
}

View file

@@ -81,6 +81,21 @@ pub async fn search_events_route(
let results: Vec<_> = results
.iter()
+.filter_map(|result| {
+services()
+.rooms
+.timeline
+.get_pdu_from_id(result)
+.ok()?
+.filter(|pdu| {
+services()
+.rooms
+.state_accessor
+.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
+.unwrap_or(false)
+})
+.map(|pdu| pdu.to_room_event())
+})
.map(|result| {
Ok::<_, Error>(SearchResult {
context: EventContextResult {
@@ -91,11 +106,7 @@ pub async fn search_events_route(
start: None,
},
rank: None,
-result: services()
-.rooms
-.timeline
-.get_pdu_from_id(result)?
-.map(|pdu| pdu.to_room_event()),
+result: Some(result),
})
})
.filter_map(|r| r.ok())

View file

@ -7,11 +7,7 @@ use ruma::{
state::{get_state_events, get_state_events_for_key, send_state_event}, state::{get_state_events, get_state_events_for_key, send_state_event},
}, },
events::{ events::{
room::{ room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType,
canonical_alias::RoomCanonicalAliasEventContent,
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
},
AnyStateEventContent, StateEventType,
}, },
serde::Raw, serde::Raw,
EventId, RoomId, UserId, EventId, RoomId, UserId,
@ -85,29 +81,10 @@ pub async fn get_state_events_route(
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
#[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable
if !services() if !services()
.rooms .rooms
.state_cache .state_accessor
.is_joined(sender_user, &body.room_id)? .user_can_see_state_events(&sender_user, &body.room_id)?
&& !matches!(
services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| {
serde_json::from_str(event.content.get())
.map(|e: RoomHistoryVisibilityEventContent| e.history_visibility)
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in database.",
)
})
}),
Some(Ok(HistoryVisibility::WorldReadable))
)
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
@ -137,29 +114,10 @@ pub async fn get_state_events_for_key_route(
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
#[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable
if !services() if !services()
.rooms .rooms
.state_cache .state_accessor
.is_joined(sender_user, &body.room_id)? .user_can_see_state_events(&sender_user, &body.room_id)?
&& !matches!(
services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| {
serde_json::from_str(event.content.get())
.map(|e: RoomHistoryVisibilityEventContent| e.history_visibility)
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in database.",
)
})
}),
Some(Ok(HistoryVisibility::WorldReadable))
)
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
@ -192,29 +150,10 @@ pub async fn get_state_events_for_empty_key_route(
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
#[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable
if !services() if !services()
.rooms .rooms
.state_cache .state_accessor
.is_joined(sender_user, &body.room_id)? .user_can_see_state_events(&sender_user, &body.room_id)?
&& !matches!(
services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| {
serde_json::from_str(event.content.get())
.map(|e: RoomHistoryVisibilityEventContent| e.history_visibility)
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in database.",
)
})
}),
Some(Ok(HistoryVisibility::WorldReadable))
)
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,

File diff suppressed because it is too large.

View file

@@ -12,6 +12,7 @@ use ruma::{
client::error::{Error as RumaError, ErrorKind},
federation::{
authorization::get_event_authorization,
+backfill::get_backfill,
device::get_devices::{self, v1::UserDevice},
directory::{get_public_rooms, get_public_rooms_filtered},
discovery::{get_server_keys, get_server_version, ServerSigningKeys, VerifyKey},
@@ -42,8 +43,9 @@ use ruma::{
},
serde::{Base64, JsonObject, Raw},
to_device::DeviceIdOrAllDevices,
-CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
-OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName,
+uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
+OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId,
+ServerName,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use std::{
use std::{ use std::{
@ -123,6 +125,8 @@ where
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
debug!("Preparing to send request to {destination}");
let mut write_destination_to_cache = false; let mut write_destination_to_cache = false;
let cached_result = services() let cached_result = services()
@ -229,11 +233,13 @@ where
let url = reqwest_request.url().clone(); let url = reqwest_request.url().clone();
debug!("Sending request to {destination} at {url}");
let response = services() let response = services()
.globals .globals
.federation_client() .federation_client()
.execute(reqwest_request) .execute(reqwest_request)
.await; .await;
debug!("Received response from {destination} at {url}");
match response { match response {
Ok(mut response) => { Ok(mut response) => {
@ -249,10 +255,12 @@ where
.expect("http::response::Builder is usable"), .expect("http::response::Builder is usable"),
); );
debug!("Getting response bytes from {destination}");
let body = response.bytes().await.unwrap_or_else(|e| { let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error {}", e); warn!("server error {}", e);
Vec::new().into() Vec::new().into()
}); // TODO: handle timeout }); // TODO: handle timeout
debug!("Got response bytes from {destination}");
if status != 200 { if status != 200 {
warn!( warn!(
@ -271,6 +279,7 @@ where
.expect("reqwest body is valid http body"); .expect("reqwest body is valid http body");
if status == 200 { if status == 200 {
debug!("Parsing response bytes from {destination}");
let response = T::IncomingResponse::try_from_http_response(http_response); let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && write_destination_to_cache { if response.is_ok() && write_destination_to_cache {
services() services()
@ -292,6 +301,7 @@ where
Error::BadServerResponse("Server returned bad 200 response.") Error::BadServerResponse("Server returned bad 200 response.")
}) })
} else { } else {
debug!("Returning error from {destination}");
Err(Error::FederationError( Err(Error::FederationError(
destination.to_owned(), destination.to_owned(),
RumaError::from_http_response(http_response), RumaError::from_http_response(http_response),
@ -330,36 +340,38 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest {
/// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names /// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names
/// Numbers in comments below refer to bullet points in linked section of specification /// Numbers in comments below refer to bullet points in linked section of specification
async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) {
debug!("Finding actual destination for {destination}");
let destination_str = destination.as_str().to_owned(); let destination_str = destination.as_str().to_owned();
let mut hostname = destination_str.clone(); let mut hostname = destination_str.clone();
let actual_destination = match get_ip_with_port(&destination_str) { let actual_destination = match get_ip_with_port(&destination_str) {
Some(host_port) => { Some(host_port) => {
// 1: IP literal with provided or default port debug!("1: IP literal with provided or default port");
host_port host_port
} }
None => { None => {
if let Some(pos) = destination_str.find(':') { if let Some(pos) = destination_str.find(':') {
// 2: Hostname with included port debug!("2: Hostname with included port");
let (host, port) = destination_str.split_at(pos); let (host, port) = destination_str.split_at(pos);
FedDest::Named(host.to_owned(), port.to_owned()) FedDest::Named(host.to_owned(), port.to_owned())
} else { } else {
debug!("Requesting well known for {destination}");
match request_well_known(destination.as_str()).await { match request_well_known(destination.as_str()).await {
// 3: A .well-known file is available
Some(delegated_hostname) => { Some(delegated_hostname) => {
debug!("3: A .well-known file is available");
hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); hostname = add_port_to_hostname(&delegated_hostname).into_uri_string();
match get_ip_with_port(&delegated_hostname) { match get_ip_with_port(&delegated_hostname) {
Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
None => { None => {
if let Some(pos) = delegated_hostname.find(':') { if let Some(pos) = delegated_hostname.find(':') {
// 3.2: Hostname with port in .well-known file debug!("3.2: Hostname with port in .well-known file");
let (host, port) = delegated_hostname.split_at(pos); let (host, port) = delegated_hostname.split_at(pos);
FedDest::Named(host.to_owned(), port.to_owned()) FedDest::Named(host.to_owned(), port.to_owned())
} else { } else {
// Delegated hostname has no port in this branch debug!("Delegated hostname has no port in this branch");
if let Some(hostname_override) = if let Some(hostname_override) =
query_srv_record(&delegated_hostname).await query_srv_record(&delegated_hostname).await
{ {
// 3.3: SRV lookup successful debug!("3.3: SRV lookup successful");
let force_port = hostname_override.port(); let force_port = hostname_override.port();
if let Ok(override_ip) = services() if let Ok(override_ip) = services()
@ -390,18 +402,18 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
add_port_to_hostname(&delegated_hostname) add_port_to_hostname(&delegated_hostname)
} }
} else { } else {
// 3.4: No SRV records, just use the hostname from .well-known debug!("3.4: No SRV records, just use the hostname from .well-known");
add_port_to_hostname(&delegated_hostname) add_port_to_hostname(&delegated_hostname)
} }
} }
} }
} }
} }
// 4: No .well-known or an error occured
None => { None => {
debug!("4: No .well-known or an error occured");
match query_srv_record(&destination_str).await { match query_srv_record(&destination_str).await {
// 4: SRV record found
Some(hostname_override) => { Some(hostname_override) => {
debug!("4: SRV record found");
let force_port = hostname_override.port(); let force_port = hostname_override.port();
if let Ok(override_ip) = services() if let Ok(override_ip) = services()
@ -432,14 +444,17 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
add_port_to_hostname(&hostname) add_port_to_hostname(&hostname)
} }
} }
// 5: No SRV record found None => {
None => add_port_to_hostname(&destination_str), debug!("5: No SRV record found");
add_port_to_hostname(&destination_str)
}
} }
} }
} }
} }
} }
}; };
debug!("Actual destination: {actual_destination:?}");
// Can't use get_ip_with_port here because we don't want to add a port // Can't use get_ip_with_port here because we don't want to add a port
// to an IP address if it wasn't specified // to an IP address if it wasn't specified
@@ -457,10 +472,11 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
}
async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
+let hostname = hostname.trim_end_matches('.');
if let Ok(Some(host_port)) = services()
.globals
.dns_resolver()
-.srv_lookup(format!("_matrix._tcp.{hostname}"))
+.srv_lookup(format!("_matrix._tcp.{hostname}."))
.await
.map(|srv| {
srv.iter().next().map(|result| {
@@ -478,19 +494,16 @@ async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
}
async fn request_well_known(destination: &str) -> Option<String> {
-let body: serde_json::Value = serde_json::from_str(
-&services()
-.globals
-.default_client()
-.get(&format!("https://{destination}/.well-known/matrix/server"))
-.send()
-.await
-.ok()?
-.text()
-.await
-.ok()?,
-)
-.ok()?;
+let response = services()
+.globals
+.default_client()
+.get(&format!("https://{destination}/.well-known/matrix/server"))
+.send()
+.await;
+debug!("Got well known response");
+let text = response.ok()?.text().await;
+debug!("Got well known response text");
+let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?;
Some(body.get("m.server")?.as_str()?.to_owned())
}
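request_well_known is rewritten so the HTTP send, the body read, and the JSON parse are separate steps with debug logging in between, which makes it easier to see where a delegation lookup stalls. The .well-known body itself is plain JSON with the delegated host under "m.server"; a standalone sketch of just the parsing step (using serde_json, as the route does, with an example hostname):

    // Extract the delegated server name from a /.well-known/matrix/server body.
    fn delegated_server(body: &str) -> Option<String> {
        let value: serde_json::Value = serde_json::from_str(body).ok()?;
        Some(value.get("m.server")?.as_str()?.to_owned())
    }

    fn main() {
        let body = r#"{ "m.server": "matrix.example.org:443" }"#;
        assert_eq!(
            delegated_server(body).as_deref(),
            Some("matrix.example.org:443")
        );
        // Anything malformed simply yields None, mirroring the Option-based flow above.
        assert_eq!(delegated_server("not json"), None);
    }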
@ -627,6 +640,37 @@ pub async fn get_public_rooms_route(
}) })
} }
pub fn parse_incoming_pdu(
pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let room_id: OwnedRoomId = value
.get("room_id")
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid room id in pdu",
))?;
let room_version_id = services().rooms.state.get_room_version(&room_id)?;
let (event_id, value) = match gen_event_id_canonical_json(&pdu, &room_version_id) {
Ok(t) => t,
Err(_) => {
// Event could not be converted to canonical json
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Could not convert event to canonical json.",
));
}
};
Ok((event_id, value, room_id))
}
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
@@ -655,33 +699,11 @@ pub async fn send_transaction_message_route(
// let mut auth_cache = EventMap::new();
for pdu in &body.pdus {
-let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
-warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
-Error::BadServerResponse("Invalid PDU in server response")
-})?;
-let room_id: OwnedRoomId = match value
-.get("room_id")
-.and_then(|id| RoomId::parse(id.as_str()?).ok())
-{
-Some(id) => id,
-None => {
-// Event is invalid
-continue;
-}
-};
-let room_version_id = match services().rooms.state.get_room_version(&room_id) {
-Ok(v) => v,
-Err(_) => {
-continue;
-}
-};
-let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) {
+let r = parse_incoming_pdu(&pdu);
+let (event_id, value, room_id) = match r {
Ok(t) => t,
-Err(_) => {
-// Event could not be converted to canonical json
+Err(e) => {
+warn!("Could not parse pdu: {e}");
continue;
}
};
@ -943,6 +965,17 @@ pub async fn get_event_route(
)); ));
} }
if !services().rooms.state_accessor.server_can_see_event(
sender_servername,
&room_id,
&body.event_id,
)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Server is not allowed to see event.",
));
}
Ok(get_event::v1::Response { Ok(get_event::v1::Response {
origin: services().globals.server_name().to_owned(), origin: services().globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
@ -950,6 +983,83 @@ pub async fn get_event_route(
}) })
} }
/// # `GET /_matrix/federation/v1/backfill/<room_id>`
///
/// Retrieves events from before the sender joined the room, if the room's
/// history visibility allows.
pub async fn get_backfill_route(
body: Ruma<get_backfill::v1::Request>,
) -> Result<get_backfill::v1::Response> {
if !services().globals.allow_federation() {
return Err(Error::bad_config("Federation is disabled."));
}
let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
info!("Got backfill request from: {}", sender_servername);
if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Server is not in room.",
));
}
services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let until = body
.v
.iter()
.map(|eventid| services().rooms.timeline.get_pdu_count(eventid))
.filter_map(|r| r.ok().flatten())
.max()
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"No known eventid in v",
))?;
let limit = body.limit.min(uint!(100));
let all_events = services()
.rooms
.timeline
.pdus_until(&user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)?
.take(limit.try_into().unwrap());
let events = all_events
.filter_map(|r| r.ok())
.filter(|(_, e)| {
matches!(
services().rooms.state_accessor.server_can_see_event(
sender_servername,
&e.room_id,
&e.event_id,
),
Ok(true),
)
})
.map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id))
.filter_map(|r| r.ok().flatten())
.map(|pdu| PduEvent::convert_to_outgoing_federation_event(pdu))
.collect();
Ok(get_backfill::v1::Response {
origin: services().globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus: events,
})
}
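The new /backfill handler resolves the requester's v list to locally known counts, walks backwards from the most recent one, and caps the number of returned events at 100. A small standalone sketch of that token and limit selection (types simplified to bare u64 counts for illustration; conduit's real values are PduCount tokens):

    // Pick the starting point and page size for a backfill response.
    // `known_counts` holds the ordering counts of the events from `v` that we
    // actually know about (None for unknown event ids).
    fn backfill_window(
        known_counts: &[Option<u64>],
        requested_limit: u64,
    ) -> Option<(u64, usize)> {
        // Walk backwards from the newest event the requester already has.
        let until = known_counts.iter().copied().flatten().max()?;
        // Never hand out more than 100 events per response.
        let limit = requested_limit.min(100) as usize;
        Some((until, limit))
    }

    fn main() {
        assert_eq!(backfill_window(&[Some(7), None, Some(12)], 500), Some((12, 100)));
        // If none of the requested event ids are known, the request is rejected.
        assert_eq!(backfill_window(&[None], 50), None);
    }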
/// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}`
/// ///
/// Retrieves events that the sender is missing. /// Retrieves events that the sender is missing.
@ -1010,6 +1120,16 @@ pub async fn get_missing_events_route(
i += 1; i += 1;
continue; continue;
} }
if !services().rooms.state_accessor.server_can_see_event(
sender_servername,
&body.room_id,
&queued_events[i],
)? {
i += 1;
continue;
}
queued_events.extend_from_slice( queued_events.extend_from_slice(
&serde_json::from_value::<Vec<OwnedEventId>>( &serde_json::from_value::<Vec<OwnedEventId>>(
serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| {

View file

@@ -1,5 +1,3 @@
-use std::mem::size_of;
use ruma::RoomId;
use crate::{database::KeyValueDatabase, service, services, utils, Result};
@@ -15,7 +13,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xff);
-key.extend_from_slice(pdu_id);
+key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new())
});
@@ -34,7 +32,6 @@ impl service::rooms::search::Data for KeyValueDatabase {
.expect("room exists")
.to_be_bytes()
.to_vec();
-let prefix_clone = prefix.clone();
let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
@@ -46,6 +43,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xff);
+let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
@@ -53,7 +51,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
-.map(|(key, _)| key[key.len() - size_of::<u64>()..].to_vec())
+.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
let common_elements = match utils::common_elements(iterators, |a, b| {
@@ -64,12 +62,6 @@ impl service::rooms::search::Data for KeyValueDatabase {
None => return Ok(None),
};
-let mapped = common_elements.map(move |id| {
-let mut pduid = prefix_clone.clone();
-pduid.extend_from_slice(&id);
-pduid
-});
-Ok(Some((Box::new(mapped), words)))
+Ok(Some((Box::new(common_elements), words)))
}
}

View file

@ -7,30 +7,10 @@ use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use service::rooms::timeline::PduCount;
impl service::rooms::timeline::Data for KeyValueDatabase { impl service::rooms::timeline::Data for KeyValueDatabase {
fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
// Look for PDUs in that room.
self.pduid_pdu
.iter_from(&prefix, false)
.filter(|(k, _)| k.starts_with(&prefix))
.map(|(_, pdu)| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid first PDU in db."))
.map(Arc::new)
})
.next()
.transpose()
}
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
match self match self
.lasttimelinecount_cache .lasttimelinecount_cache
.lock() .lock()
@ -39,20 +19,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
{ {
hash_map::Entry::Vacant(v) => { hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self if let Some(last_count) = self
.pdus_until(sender_user, room_id, u64::MAX)? .pdus_until(sender_user, room_id, PduCount::max())?
.filter_map(|r| { .find_map(|r| {
// Filter out buggy events // Filter out buggy events
if r.is_err() { if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r); error!("Bad pdu in pdus_since: {:?}", r);
} }
r.ok() r.ok()
}) })
.map(|(pduid, _)| self.pdu_count(&pduid))
.next()
{ {
Ok(*v.insert(last_count?)) Ok(*v.insert(last_count.0))
} else { } else {
Ok(0) Ok(PduCount::Normal(0))
} }
} }
hash_map::Entry::Occupied(o) => Ok(*o.get()), hash_map::Entry::Occupied(o) => Ok(*o.get()),
@@ -60,29 +38,28 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the `count` of this pdu's id.
-fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
-self.eventid_pduid
-.get(event_id.as_bytes())?
-.map(|pdu_id| self.pdu_count(&pdu_id))
-.transpose()
-}
+fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
+Ok(self
+.eventid_pduid
+.get(event_id.as_bytes())?
+.map(|pdu_id| pdu_count(&pdu_id))
+.transpose()?)
+}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
-self.eventid_pduid
-.get(event_id.as_bytes())?
-.map_or_else(
-|| self.eventid_outlierpdu.get(event_id.as_bytes()),
-|pduid| {
-Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
-Error::bad_database("Invalid pduid in eventid_pduid.")
-})?))
-},
-)?
-.map(|pdu| {
-serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
-})
-.transpose()
+self.get_non_outlier_pdu_json(event_id)?.map_or_else(
+|| {
+self.eventid_outlierpdu
+.get(event_id.as_bytes())?
+.map(|pdu| {
+serde_json::from_slice(&pdu)
+.map_err(|_| Error::bad_database("Invalid PDU in db."))
+})
+.transpose()
+},
+|x| Ok(Some(x)),
+)
}
/// Returns the json of a pdu.
@@ -103,7 +80,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
-self.eventid_pduid.get(event_id.as_bytes())
+Ok(self.eventid_pduid.get(event_id.as_bytes())?)
}
/// Returns the pdu.
@ -133,22 +110,20 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
if let Some(pdu) = self if let Some(pdu) = self
.eventid_pduid .get_non_outlier_pdu(event_id)?
.get(event_id.as_bytes())?
.map_or_else( .map_or_else(
|| self.eventid_outlierpdu.get(event_id.as_bytes()), || {
|pduid| { self.eventid_outlierpdu
Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { .get(event_id.as_bytes())?
Error::bad_database("Invalid pduid in eventid_pduid.") .map(|pdu| {
})?)) serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
})
.transpose()
}, },
|x| Ok(Some(x)),
)? )?
.map(|pdu| { .map(Arc::new)
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))
.map(Arc::new)
})
.transpose()?
{ {
self.pdu_cache self.pdu_cache
.lock() .lock()
@ -182,12 +157,6 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}) })
} }
/// Returns the `count` of this pdu's id.
fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))
}
fn append_pdu( fn append_pdu(
&self, &self,
pdu_id: &[u8], pdu_id: &[u8],
@ -203,7 +172,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
self.lasttimelinecount_cache self.lasttimelinecount_cache
.lock() .lock()
.unwrap() .unwrap()
.insert(pdu.room_id.clone(), count); .insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
@ -211,6 +180,23 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn prepend_backfill_pdu(
&self,
pdu_id: &[u8],
event_id: &EventId,
json: &CanonicalJsonObject,
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
}
/// Removes a pdu and creates a new one with the same id. /// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() { if self.pduid_pdu.get(pdu_id)?.is_some() {
@ -227,70 +213,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
} }
} }
/// Returns an iterator over all events in a room that happened after the event with id `since`
/// in chronological order.
fn pdus_since<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
since: u64,
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
// Skip the first pdu if it's exactly at since, because we sent that last time
let mut first_pdu_id = prefix.clone();
first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes());
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&first_pdu_id, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((pdu_id, pdu))
}),
))
}
/// Returns an iterator over all events and their tokens in a room that happened before the
/// event with id `until` in reverse-chronological order.
fn pdus_until<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
-until: u64,
-) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
-// Create the first part of the full pdu id
-let prefix = services()
-.rooms
-.short
-.get_shortroomid(room_id)?
-.expect("room exists")
-.to_be_bytes()
-.to_vec();
-let mut current = prefix.clone();
-current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until`
-let current: &[u8] = &current;
+until: PduCount,
+) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
+let (prefix, current) = count_to_id(&room_id, until, 1, true)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
-.iter_from(current, true)
+.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
@@ -298,7 +235,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
-Ok((pdu_id, pdu))
+let count = pdu_count(&pdu_id)?;
+Ok((count, pdu))
}),
))
}
@@ -307,27 +245,15 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
&'a self,
user_id: &UserId,
room_id: &RoomId,
-from: u64,
-) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
-// Create the first part of the full pdu id
-let prefix = services()
-.rooms
-.short
-.get_shortroomid(room_id)?
-.expect("room exists")
-.to_be_bytes()
-.to_vec();
-let mut current = prefix.clone();
-current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
-let current: &[u8] = &current;
+from: PduCount,
+) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
+let (prefix, current) = count_to_id(&room_id, from, 1, false)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
-.iter_from(current, false)
+.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
@@ -335,7 +261,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
-Ok((pdu_id, pdu))
+let count = pdu_count(&pdu_id)?;
+Ok((count, pdu))
}),
))
}
@ -368,3 +295,60 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
} }
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 = utils::u64_from_bytes(
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
);
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX - last_u64))
} else {
Ok(PduCount::Normal(last_u64))
}
}
fn count_to_id(
room_id: &RoomId,
count: PduCount,
offset: u64,
subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x - offset
} else {
x + offset
}
}
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX - x;
if subtract {
if num > 0 {
num - offset
} else {
num
}
} else {
num + offset
}
}
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
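count_to_id and pdu_count encode the two halves of the timeline into the byte keys of the pduid_pdu tree: a backfilled event gets an extra 0_u64 marker after the short room id and stores u64::MAX minus its backfill count, so it sorts before every locally appended event while older history still comes before newer history. A self-contained sketch of that ordering property, using a BTreeMap as a stand-in for the byte-ordered key-value store:

    use std::collections::BTreeMap;

    fn normal_pdu_id(shortroomid: u64, count: u64) -> Vec<u8> {
        let mut id = shortroomid.to_be_bytes().to_vec();
        id.extend_from_slice(&count.to_be_bytes());
        id
    }

    fn backfilled_pdu_id(shortroomid: u64, n: u64) -> Vec<u8> {
        let mut id = shortroomid.to_be_bytes().to_vec();
        id.extend_from_slice(&0_u64.to_be_bytes()); // marker: sorts before any normal count
        id.extend_from_slice(&(u64::MAX - n).to_be_bytes()); // deeper backfill sorts earlier
        id
    }

    fn main() {
        // Stand-in for the pduid_pdu tree, which is ordered by raw key bytes.
        let mut tree = BTreeMap::new();
        tree.insert(normal_pdu_id(1, 10), "first locally appended event");
        tree.insert(normal_pdu_id(1, 11), "next locally appended event");
        tree.insert(backfilled_pdu_id(1, 1), "most recently backfilled event");
        tree.insert(backfilled_pdu_id(1, 2), "older backfilled event");

        let order: Vec<&str> = tree.values().copied().collect();
        assert_eq!(
            order,
            [
                "older backfilled event",
                "most recently backfilled event",
                "first locally appended event",
                "next locally appended event",
            ]
        );
    }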

View file

@@ -1,7 +1,10 @@
pub mod abstraction;
pub mod key_value;
-use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES};
+use crate::{
+service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services,
+SERVICES,
+};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use directories::ProjectDirs;
use lru_cache::LruCache;
@@ -161,7 +164,7 @@ pub struct KeyValueDatabase {
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
-pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>,
+pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
impl KeyValueDatabase {

View file

@@ -390,6 +390,7 @@ fn routes() -> Router {
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
+.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)

View file

@@ -77,7 +77,15 @@ impl Services {
search: rooms::search::Service { db },
short: rooms::short::Service { db },
state: rooms::state::Service { db },
-state_accessor: rooms::state_accessor::Service { db },
+state_accessor: rooms::state_accessor::Service {
+db,
+server_visibility_cache: Mutex::new(LruCache::new(
+(100.0 * config.conduit_cache_capacity_modifier) as usize,
+)),
+user_visibility_cache: Mutex::new(LruCache::new(
+(100.0 * config.conduit_cache_capacity_modifier) as usize,
+)),
+},
state_cache: rooms::state_cache::Service { db },
state_compressor: rooms::state_compressor::Service {
db,

View file

@ -111,9 +111,11 @@ impl PduEvent {
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"unsigned": self.unsigned,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
@ -133,10 +135,12 @@ impl PduEvent {
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"unsigned": self.unsigned,
"room_id": self.room_id, "room_id": self.room_id,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
@ -155,10 +159,12 @@ impl PduEvent {
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"unsigned": self.unsigned,
"room_id": self.room_id, "room_id": self.room_id,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
@ -171,32 +177,38 @@ impl PduEvent {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_state_event(&self) -> Raw<AnyStateEvent> { pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
let json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"unsigned": self.unsigned,
"room_id": self.room_id, "room_id": self.room_id,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
let json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"unsigned": self.unsigned,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
@ -214,18 +226,21 @@ impl PduEvent {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
let json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"redacts": self.redacts, "redacts": self.redacts,
"unsigned": self.unsigned,
"room_id": self.room_id, "room_id": self.room_id,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned);
}
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
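Illustration (not part of the patch; the event id is made up): the repeated change in this file only inserts the unsigned key when there is unsigned data, because json! serializes a None field as an explicit null. A standalone serde_json comparison:
use serde_json::json;

fn unsigned_null_vs_omitted() {
    // Old shape: a None field still shows up as "unsigned": null.
    let old_style = json!({ "event_id": "$abc", "unsigned": Option::<u64>::None });
    assert_eq!(old_style.to_string(), r#"{"event_id":"$abc","unsigned":null}"#);

    // New shape: the key is only inserted when there is something to put in it.
    let mut new_style = json!({ "event_id": "$abc" });
    if let Some(unsigned) = Some(json!({ "age": 1234 })) {
        new_style["unsigned"] = unsigned;
    }
    assert_eq!(
        new_style.to_string(),
        r#"{"event_id":"$abc","unsigned":{"age":1234}}"#
    );
}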

View file

@ -392,11 +392,12 @@ impl Service {
} }
// The original create event must be in the auth events // The original create event must be in the auth events
-if auth_events
-.get(&(StateEventType::RoomCreate, "".to_owned()))
-.map(|a| a.as_ref())
-!= Some(create_event)
-{
+if !matches!(
+auth_events
+.get(&(StateEventType::RoomCreate, "".to_owned()))
+.map(|a| a.as_ref()),
+Some(_) | None
+) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Incoming event refers to wrong create event.", "Incoming event refers to wrong create event.",
@ -1459,12 +1460,12 @@ impl Service {
} }
if servers.is_empty() { if servers.is_empty() {
// We had all keys locally info!("We had all keys locally");
return Ok(()); return Ok(());
} }
for server in services().globals.trusted_servers() { for server in services().globals.trusted_servers() {
trace!("Asking batch signing keys from trusted server {}", server); info!("Asking batch signing keys from trusted server {}", server);
if let Ok(keys) = services() if let Ok(keys) = services()
.sending .sending
.send_federation_request( .send_federation_request(
@ -1507,10 +1508,12 @@ impl Service {
} }
if servers.is_empty() { if servers.is_empty() {
info!("Trusted server supplied all signing keys");
return Ok(()); return Ok(());
} }
} }
info!("Asking individual servers for signing keys: {servers:?}");
let mut futures: FuturesUnordered<_> = servers let mut futures: FuturesUnordered<_> = servers
.into_keys() .into_keys()
.map(|server| async move { .map(|server| async move {
@ -1525,21 +1528,27 @@ impl Service {
.collect(); .collect();
while let Some(result) = futures.next().await { while let Some(result) = futures.next().await {
info!("Received new result");
if let (Ok(get_keys_response), origin) = result { if let (Ok(get_keys_response), origin) = result {
+info!("Result is from {origin}");
+if let Ok(key) = get_keys_response.server_key.deserialize() {
let result: BTreeMap<_, _> = services()
.globals
-.add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())?
+.add_signing_key(&origin, key)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
+}
} }
info!("Done handling result");
} }
info!("Search for signing keys done");
Ok(()) Ok(())
} }

View file

@ -9,11 +9,13 @@ use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::Result; use crate::Result;
use super::timeline::PduCount;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub lazy_load_waiting: pub lazy_load_waiting:
Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, u64), HashSet<OwnedUserId>>>, Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
} }
impl Service { impl Service {
@ -36,7 +38,7 @@ impl Service {
device_id: &DeviceId, device_id: &DeviceId,
room_id: &RoomId, room_id: &RoomId,
lazy_load: HashSet<OwnedUserId>, lazy_load: HashSet<OwnedUserId>,
count: u64, count: PduCount,
) { ) {
self.lazy_load_waiting.lock().unwrap().insert( self.lazy_load_waiting.lock().unwrap().insert(
( (
@ -55,7 +57,7 @@ impl Service {
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
room_id: &RoomId, room_id: &RoomId,
since: u64, since: PduCount,
) -> Result<()> { ) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
user_id.to_owned(), user_id.to_owned(),

View file

@ -1,13 +1,29 @@
mod data; mod data;
-use std::{collections::HashMap, sync::Arc};
+use std::{
+collections::HashMap,
+sync::{Arc, Mutex},
+};
pub use data::Data;
+use lru_cache::LruCache;
-use ruma::{events::StateEventType, EventId, RoomId};
+use ruma::{
+events::{
+room::{
+history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
+member::{MembershipState, RoomMemberEventContent},
+},
+StateEventType,
+},
+EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
+};
+use tracing::error;
-use crate::{PduEvent, Result};
+use crate::{services, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
} }
impl Service { impl Service {
@ -46,6 +62,178 @@ impl Service {
self.db.state_get(shortstatehash, event_type, state_key) self.db.state_get(shortstatehash, event_type, state_key)
} }
/// Get membership for given user in state
fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> {
self.state_get(
shortstatehash,
&StateEventType::RoomMember,
user_id.as_str(),
)?
.map_or(Ok(MembershipState::Leave), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomMemberEventContent| c.membership)
.map_err(|_| Error::bad_database("Invalid room membership event in database."))
})
}
/// The user was a joined member at this state (potentially in the past)
fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.map(|s| s == MembershipState::Join)
.unwrap_or_default() // Return sensible default, i.e. false
}
/// The user was an invited or joined room member at this state (potentially
/// in the past)
fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.map(|s| s == MembershipState::Join || s == MembershipState::Invite)
.unwrap_or_default() // Return sensible default, i.e. false
}
/// Whether a server is allowed to see an event through federation, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, origin, room_id, event_id))]
pub fn server_can_see_event(
&self,
origin: &ServerName,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
let shortstatehash = match self.pdu_shortstatehash(event_id)? {
Some(shortstatehash) => shortstatehash,
None => return Ok(true),
};
if let Some(visibility) = self
.server_visibility_cache
.lock()
.unwrap()
.get_mut(&(origin.to_owned(), shortstatehash))
{
return Ok(*visibility);
}
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|_| {
Error::bad_database("Invalid history visibility event in database.")
})
})?;
let mut current_server_members = services()
.rooms
.state_cache
.room_members(room_id)
.filter_map(|r| r.ok())
.filter(|member| member.server_name() == origin);
let visibility = match history_visibility {
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members.any(|member| self.user_was_invited(shortstatehash, &member))
}
HistoryVisibility::Joined => {
// Allow if any member on the requesting server was joined, else deny
current_server_members.any(|member| self.user_was_joined(shortstatehash, &member))
}
_ => {
error!("Unknown history visibility {history_visibility}");
false
}
};
self.server_visibility_cache
.lock()
.unwrap()
.insert((origin.to_owned(), shortstatehash), visibility);
Ok(visibility)
}
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id, event_id))]
pub fn user_can_see_event(
&self,
user_id: &UserId,
room_id: &RoomId,
event_id: &EventId,
) -> Result<bool> {
let shortstatehash = match self.pdu_shortstatehash(event_id)? {
Some(shortstatehash) => shortstatehash,
None => return Ok(true),
};
if let Some(visibility) = self
.user_visibility_cache
.lock()
.unwrap()
.get_mut(&(user_id.to_owned(), shortstatehash))
{
return Ok(*visibility);
}
let currently_member = services().rooms.state_cache.is_joined(&user_id, &room_id)?;
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|_| {
Error::bad_database("Invalid history visibility event in database.")
})
})?;
let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true,
HistoryVisibility::Shared => currently_member,
HistoryVisibility::Invited => {
// Allow if the user was at least invited (or joined) at that event's state, else deny
self.user_was_invited(shortstatehash, &user_id)
}
HistoryVisibility::Joined => {
// Allow if the user was joined at that event's state, else deny
self.user_was_joined(shortstatehash, &user_id)
}
_ => {
error!("Unknown history visibility {history_visibility}");
false
}
};
self.user_visibility_cache
.lock()
.unwrap()
.insert((user_id.to_owned(), shortstatehash), visibility);
Ok(visibility)
}
/// Whether a user is allowed to see the room's full current state, based on their
/// membership and the room's current history_visibility.
#[tracing::instrument(skip(self, user_id, room_id))]
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = services().rooms.state_cache.is_joined(&user_id, &room_id)?;
let history_visibility = self
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|_| {
Error::bad_database("Invalid history visibility event in database.")
})
})?;
Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable)
}
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.db.pdu_shortstatehash(event_id) self.db.pdu_shortstatehash(event_id)
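Illustration (not part of the patch): the user-visibility rules above reduce to a small decision table over history_visibility. A simplified standalone restatement with stand-in types (the real code uses ruma's HistoryVisibility and membership state, plus the per-shortstatehash caches); server_can_see_event differs only in that Shared is always visible to a remote server and the membership checks range over that server's members.
#[derive(Clone, Copy)]
enum Vis {
    WorldReadable,
    Shared,
    Invited,
    Joined,
}

fn user_can_see(vis: Vis, currently_joined: bool, was_invited: bool, was_joined: bool) -> bool {
    match vis {
        Vis::WorldReadable => true,      // visible to anyone, even non-members
        Vis::Shared => currently_joined, // any current member sees the whole timeline
        Vis::Invited => was_invited,     // must have been at least invited at that event
        Vis::Joined => was_joined,       // must have been joined at that event
    }
}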

View file

@ -4,12 +4,13 @@ use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use crate::{PduEvent, Result}; use crate::{PduEvent, Result};
use super::PduCount;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
-fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>>;
-fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64>;
+fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount>;
/// Returns the `count` of this pdu's id.
-fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>>;
+fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>>;
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
@ -38,9 +39,6 @@ pub trait Data: Send + Sync {
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>>; fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>>;
/// Returns the `count` of this pdu's id.
fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64>;
/// Adds a new pdu to the timeline /// Adds a new pdu to the timeline
fn append_pdu( fn append_pdu(
&self, &self,
@ -50,33 +48,34 @@ pub trait Data: Send + Sync {
count: u64, count: u64,
) -> Result<()>; ) -> Result<()>;
// Adds a new pdu to the backfilled timeline
fn prepend_backfill_pdu(
&self,
pdu_id: &[u8],
event_id: &EventId,
json: &CanonicalJsonObject,
) -> Result<()>;
/// Removes a pdu and creates a new one with the same id. /// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>;
-/// Returns an iterator over all events in a room that happened after the event with id `since`
-/// in chronological order.
-fn pdus_since<'a>(
-&'a self,
-user_id: &UserId,
-room_id: &RoomId,
-since: u64,
-) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>;
/// Returns an iterator over all events and their tokens in a room that happened before the
/// event with id `until` in reverse-chronological order.
fn pdus_until<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
-until: u64,
-) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>;
+until: PduCount,
+) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
+/// Returns an iterator over all events in a room that happened after the event with id `from`
+/// in chronological order.
fn pdus_after<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
-from: u64,
-) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>;
+from: PduCount,
+) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
fn increment_notification_counts( fn increment_notification_counts(
&self, &self,

View file

@ -1,7 +1,9 @@
mod data; mod data;
-use std::collections::HashMap;
+use std::cmp::Ordering;
+use std::collections::{BTreeMap, HashMap};
+use std::sync::RwLock;
use std::{ use std::{
collections::HashSet, collections::HashSet,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
@ -9,6 +11,8 @@ use std::{
pub use data::Data; pub use data::Data;
use regex::Regex; use regex::Regex;
use ruma::api::federation;
use ruma::serde::Base64;
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
canonical_json::to_canonical_value, canonical_json::to_canonical_value,
@ -27,11 +31,13 @@ use ruma::{
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
OwnedServerName, RoomAliasId, RoomId, UserId, OwnedServerName, RoomAliasId, RoomId, UserId,
}; };
use ruma::{user_id, ServerName};
use serde::Deserialize; use serde::Deserialize;
use serde_json::value::to_raw_value; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::MutexGuard; use tokio::sync::MutexGuard;
use tracing::{error, warn}; use tracing::{error, info, warn};
use crate::api::server_server;
use crate::{ use crate::{
service::pdu::{EventHash, PduBuilder}, service::pdu::{EventHash, PduBuilder},
services, utils, Error, PduEvent, Result, services, utils, Error, PduEvent, Result,
@ -39,23 +45,91 @@ use crate::{
use super::state_compressor::CompressedStateEvent; use super::state_compressor::CompressedStateEvent;
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
pub enum PduCount {
Backfilled(u64),
Normal(u64),
}
impl PduCount {
pub fn min() -> Self {
Self::Backfilled(u64::MAX)
}
pub fn max() -> Self {
Self::Normal(u64::MAX)
}
pub fn try_from_string(token: &str) -> Result<Self> {
if token.starts_with('-') {
token[1..].parse().map(PduCount::Backfilled)
} else {
token.parse().map(PduCount::Normal)
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token."))
}
pub fn stringify(&self) -> String {
match self {
PduCount::Backfilled(x) => format!("-{x}"),
PduCount::Normal(x) => x.to_string(),
}
}
}
impl PartialOrd for PduCount {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PduCount {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o),
(PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s),
(PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater,
(PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn comparisons() {
assert!(PduCount::Normal(1) < PduCount::Normal(2));
assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1));
assert!(PduCount::Normal(1) > PduCount::Backfilled(1));
assert!(PduCount::Backfilled(1) < PduCount::Normal(1));
}
}
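Usage sketch (not part of the patch): these counts double as the pagination tokens handed to clients, so the string form round-trips through stringify and try_from_string, with backfilled positions rendered with a leading minus sign.
// Editorial sketch: pagination-token round-trips for PduCount.
assert_eq!(PduCount::Normal(42).stringify(), "42");
assert_eq!(PduCount::Backfilled(7).stringify(), "-7");
assert_eq!(PduCount::try_from_string("42").ok(), Some(PduCount::Normal(42)));
assert_eq!(PduCount::try_from_string("-7").ok(), Some(PduCount::Backfilled(7)));
assert!(PduCount::try_from_string("not-a-token").is_err());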
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>, pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
} }
impl Service { impl Service {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
-self.db.first_pdu_in_room(room_id)
+self.all_pdus(&user_id!("@doesntmatter:conduit.rs"), &room_id)?
+.next()
+.map(|o| o.map(|(_, p)| Arc::new(p)))
+.transpose()
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
self.db.last_timeline_count(sender_user, room_id) self.db.last_timeline_count(sender_user, room_id)
} }
/// Returns the `count` of this pdu's id.
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.db.get_pdu_count(event_id)
}
// TODO Is this the same as the function above? // TODO Is this the same as the function above?
/* /*
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -79,11 +153,6 @@ impl Service {
} }
*/ */
/// Returns the `count` of this pdu's id.
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
self.db.get_pdu_count(event_id)
}
/// Returns the json of a pdu. /// Returns the json of a pdu.
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_pdu_json(event_id) self.db.get_pdu_json(event_id)
@ -128,11 +197,6 @@ impl Service {
self.db.get_pdu_json_from_id(pdu_id) self.db.get_pdu_json_from_id(pdu_id)
} }
/// Returns the `count` of this pdu's id.
pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
self.db.pdu_count(pdu_id)
}
/// Removes a pdu and creates a new one with the same id. /// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
@ -863,19 +927,8 @@ impl Service {
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
-) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
-self.pdus_since(user_id, room_id, 0)
+) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
+self.pdus_after(user_id, room_id, PduCount::min())
}
-/// Returns an iterator over all events in a room that happened after the event with id `since`
-/// in chronological order.
-pub fn pdus_since<'a>(
-&'a self,
-user_id: &UserId,
-room_id: &RoomId,
-since: u64,
-) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
-self.db.pdus_since(user_id, room_id, since)
-}
/// Returns an iterator over all events and their tokens in a room that happened before the /// Returns an iterator over all events and their tokens in a room that happened before the
@ -885,8 +938,8 @@ impl Service {
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
until: u64, until: PduCount,
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
self.db.pdus_until(user_id, room_id, until) self.db.pdus_until(user_id, room_id, until)
} }
@ -897,8 +950,8 @@ impl Service {
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
from: u64, from: PduCount,
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
self.db.pdus_after(user_id, room_id, from) self.db.pdus_after(user_id, room_id, from)
} }
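Hypothetical caller sketch (not part of the patch; assumes crate context with sender_user: &UserId and room_id: &RoomId in scope, inside a handler returning Result): paginating backwards from the newest event and producing a next-batch token now looks roughly like this.
// Editorial sketch: walk backwards from the most recent event.
let events: Vec<_> = services()
    .rooms
    .timeline
    .pdus_until(sender_user, room_id, PduCount::max())?
    .filter_map(|r| r.ok()) // skip events that fail to load
    .take(10)
    .collect();
// Hand this back to the client as the token to continue from.
let next_token = events.last().map(|(count, _)| count.stringify());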
@ -915,4 +968,159 @@ impl Service {
// If event does not exist, just noop // If event does not exist, just noop
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, room_id))]
pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> {
let first_pdu = self
.all_pdus(&user_id!("@doesntmatter:conduit.rs"), &room_id)?
.next()
.expect("Room is not empty")?;
if first_pdu.0 < from {
// No backfill required, there are still events between them
return Ok(());
}
let power_levels: RoomPowerLevelsEventContent = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
serde_json::from_str(ev.content.get())
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
})
.transpose()?
.unwrap_or_default();
let mut admin_servers = power_levels
.users
.iter()
.filter(|(_, level)| **level > power_levels.users_default)
.map(|(user_id, _)| user_id.server_name())
.collect::<HashSet<_>>();
admin_servers.remove(services().globals.server_name());
// Request backfill
for backfill_server in admin_servers {
info!("Asking {backfill_server} for backfill");
let response = services()
.sending
.send_federation_request(
backfill_server,
federation::backfill::get_backfill::v1::Request {
room_id: room_id.to_owned(),
v: vec![first_pdu.1.event_id.as_ref().to_owned()],
limit: uint!(100),
},
)
.await;
match response {
Ok(response) => {
let mut pub_key_map = RwLock::new(BTreeMap::new());
for pdu in response.pdus {
if let Err(e) = self
.backfill_pdu(backfill_server, pdu, &mut pub_key_map)
.await
{
warn!("Failed to add backfilled pdu: {e}");
}
}
return Ok(());
}
Err(e) => {
warn!("{backfill_server} could not provide backfill: {e}");
}
}
}
info!("No servers could backfill");
Ok(())
}
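Illustration (not part of the patch): the server selection above boils down to "ask the homeservers of every user whose power level is above users_default, except our own". A standalone restatement with simplified types (plain strings and i64 levels instead of ruma's UserId and Int):
use std::collections::{BTreeMap, HashSet};

fn admin_servers<'a>(
    users: &'a BTreeMap<&'a str, i64>, // user id -> power level
    users_default: i64,
    own_server: &str,
) -> HashSet<&'a str> {
    let mut servers: HashSet<&str> = users
        .iter()
        .filter(|(_, level)| **level > users_default)
        // naive server-name extraction for the sketch; the real code uses UserId::server_name()
        .filter_map(|(user, _)| user.split(':').nth(1))
        .collect();
    servers.remove(own_server);
    servers
}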
#[tracing::instrument(skip(self, pdu))]
pub async fn backfill_pdu(
&self,
origin: &ServerName,
pdu: Box<RawJsonValue>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex = Arc::clone(
services()
.globals
.roomid_mutex_federation
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default(),
);
let mutex_lock = mutex.lock().await;
// Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(&event_id)? {
info!("We already know {event_id} at {pdu_id:?}");
return Ok(());
}
services()
.rooms
.event_handler
.handle_incoming_pdu(origin, &event_id, &room_id, value, false, &pub_key_map)
.await?;
let value = self.get_pdu_json(&event_id)?.expect("We just created it");
let pdu = self.get_pdu(&event_id)?.expect("We just created it");
let shortroomid = services()
.rooms
.short
.get_shortroomid(&room_id)?
.expect("room exists");
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.unwrap()
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().unwrap();
let count = services().globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes());
// Insert pdu
self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?;
drop(insert_lock);
match pdu.kind {
RoomEventType::RoomMessage => {
#[derive(Deserialize)]
struct ExtractBody {
body: Option<String>,
}
let content = serde_json::from_str::<ExtractBody>(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body {
services()
.rooms
.search
.index_pdu(shortroomid, &pdu_id, &body)?;
}
}
_ => {}
}
drop(mutex_lock);
info!("Prepended backfill pdu");
Ok(())
}
} }

View file

@ -40,7 +40,7 @@ use tokio::{
select, select,
sync::{mpsc, Mutex, Semaphore}, sync::{mpsc, Mutex, Semaphore},
}; };
use tracing::{error, warn}; use tracing::{debug, error, warn};
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum OutgoingKind { pub enum OutgoingKind {
@ -683,8 +683,18 @@ impl Service {
where where
T: Debug, T: Debug,
{ {
debug!("Waiting for permit");
let permit = self.maximum_requests.acquire().await; let permit = self.maximum_requests.acquire().await;
-let response = server_server::send_request(destination, request).await;
+debug!("Got permit");
+let response = tokio::time::timeout(
+Duration::from_secs(2 * 60),
+server_server::send_request(destination, request),
+)
+.await
+.map_err(|_| {
+warn!("Timeout waiting for server response of {destination}");
+Error::BadServerResponse("Timeout waiting for server response")
+})?;
drop(permit); drop(permit);
response response
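Sketch of the new timeout pattern (not part of the patch; uses a unit error type instead of the crate's Error): outgoing federation requests are now bounded to two minutes, presumably so a stalled destination cannot hold its sending permit indefinitely.
use std::time::Duration;

// Editorial sketch: cap an arbitrary request future at two minutes.
async fn send_with_timeout<T>(
    request: impl std::future::Future<Output = Result<T, ()>>,
) -> Result<T, ()> {
    tokio::time::timeout(Duration::from_secs(2 * 60), request)
        .await
        .map_err(|_| ())? // the timeout elapsed before the request finished
}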