diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 502de0f7..a94b9f86 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1348,16 +1348,7 @@ pub(crate) async fn invite_helper( "Could not accept incoming PDU as timeline event.", ))?; - // Bind to variable because of lifetimes - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); - - services().sending.send_pdu(servers, &pdu_id)?; - + services().sending.send_pdu_room(room_id, &pdu_id)?; return Ok(()); } diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index 57e2f114..128d3cda 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -37,7 +37,7 @@ pub async fn send_event_to_device_route( messages.insert(target_user_id.clone(), map); let count = services().globals.next_count()?; - services().sending.send_reliable_edu( + services().sending.send_edu_server( target_user_id.server_name(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { sender: sender_user.clone(), @@ -46,7 +46,6 @@ pub async fn send_event_to_device_route( messages, })) .expect("DirectToDevice EDU can be serialized"), - count, )?; continue; diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 88d2f786..69050a5f 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1672,14 +1672,7 @@ async fn create_join_event( .get_auth_chain(room_id, state_ids.values().cloned().collect()) .await?; - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); - - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu_room(room_id, &pdu_id)?; Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 52273390..68a65991 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -411,7 +411,7 @@ impl Service { } for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_push_pdu(&pdu_id, user, push_key?)?; + services().sending.send_pdu_push(&pdu_id, user, push_key?)?; } } @@ -958,7 +958,9 @@ impl Service { // room_servers() and/or the if statement above servers.remove(services().globals.server_name()); - services().sending.send_pdu(servers.into_iter(), &pdu_id)?; + services() + .sending + .send_pdu_servers(servers.into_iter(), &pdu_id)?; Ok(pdu.event_id) } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 579f6814..20884950 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -408,7 +408,7 @@ impl Service { } #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); let event = SendingEventType::Pdu(pdu_id.to_owned()); let _cork = services().globals.db.cork()?; @@ -420,41 +420,6 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) - .collect::>(); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; - for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { - self.sender - .send((outgoing_kind.clone(), event, key)) - .unwrap(); - } - - Ok(()) - } - - #[tracing::instrument(skip(self, server, serialized))] - pub fn send_reliable_edu(&self, server: &ServerName, serialized: Vec, id: u64) -> Result<()> { - let outgoing_kind = OutgoingKind::Normal(server.to_owned()); - let event = SendingEventType::Edu(serialized); - let _cork = services().globals.db.cork()?; - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - #[tracing::instrument(skip(self))] pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let outgoing_kind = OutgoingKind::Appservice(appservice_id); @@ -468,25 +433,102 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, room_id))] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { - let servers: HashSet = services() + #[tracing::instrument(skip(self, room_id, pdu_id))] + pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + let servers = services() .rooms .state_cache .room_servers(room_id) .filter_map(Result::ok) - .collect(); + .filter(|server| &**server != services().globals.server_name()); - self.flush_servers(servers.into_iter()) + self.send_pdu_servers(servers, pdu_id) + } + + #[tracing::instrument(skip(self, servers, pdu_id))] + pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) + .collect::>(); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + )?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender + .send((outgoing_kind.clone(), event, key)) + .unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, server, serialized))] + pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { + let outgoing_kind = OutgoingKind::Normal(server.to_owned()); + let event = SendingEventType::Edu(serialized); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id, serialized))] + pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(Result::ok) + .filter(|server| &**server != services().globals.server_name()); + + self.send_edu_servers(servers, serialized) + } + + #[tracing::instrument(skip(self, servers, serialized))] + pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| (OutgoingKind::Normal(server), SendingEventType::Edu(serialized.clone()))) + .collect::>(); + let _cork = services().globals.db.cork()?; + let keys = self.db.queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + )?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender + .send((outgoing_kind.clone(), event, key)) + .unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(Result::ok) + .filter(|server| &**server != services().globals.server_name()); + + self.flush_servers(servers) } #[tracing::instrument(skip(self, servers))] pub fn flush_servers>(&self, servers: I) -> Result<()> { - let requests = servers - .into_iter() - .filter(|server| server != services().globals.server_name()) - .map(OutgoingKind::Normal) - .collect::>(); + let requests = servers.into_iter().map(OutgoingKind::Normal); for outgoing_kind in requests { self.sender