split sending/send base functions

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-04-23 15:31:40 -07:00 committed by June
parent f919fa879b
commit 57e6af6e21
2 changed files with 36 additions and 18 deletions

View file

@ -234,7 +234,7 @@ impl Service {
let permit = self.maximum_requests.acquire().await;
let timeout = Duration::from_secs(self.timeout);
let client = &services().globals.client.federation;
let response = tokio::time::timeout(timeout, send::send_request(client, dest, request))
let response = tokio::time::timeout(timeout, send::send(client, dest, request))
.await
.map_err(|_| {
warn!("Timeout after 300 seconds waiting for server response of {dest}");
@ -795,7 +795,7 @@ async fn send_events_dest_normal(
let permit = services().sending.maximum_requests.acquire().await;
let client = &services().globals.client.sender;
let response = send::send_request(
let response = send::send(
client,
server_name,
send_transaction_message::v1::Request {

View file

@ -51,7 +51,7 @@ struct ActualDest {
}
#[tracing::instrument(skip_all, name = "send")]
pub(crate) async fn send_request<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
pub(crate) async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
@ -59,22 +59,19 @@ where
return Err(Error::bad_config("Federation is disabled."));
}
trace!("Preparing to send request");
validate_dest(dest)?;
let actual = get_actual_dest(dest).await?;
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5])
.map_err(|e| {
debug_warn!("Failed to find destination {}: {}", actual.string, e);
Error::BadServerResponse("Invalid destination")
})?;
let request = prepare::<T>(dest, &actual, req).await?;
execute::<T>(client, dest, &actual, request).await
}
sign_request::<T>(dest, &mut http_request);
let request = Request::try_from(http_request)?;
async fn execute<T>(
client: &Client, dest: &ServerName, actual: &ActualDest, request: Request,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
let method = request.method().clone();
let url = request.url().clone();
validate_url(&url)?;
debug!(
method = ?method,
url = ?url,
@ -82,12 +79,32 @@ where
);
match client.execute(request).await {
Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(dest, &actual, &method, &url, e),
Err(e) => handle_error::<T>(dest, actual, &method, &url, e),
}
}
async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
where
T: OutgoingRequest + Debug,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5];
trace!("Preparing request");
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS)
.map_err(|_e| Error::BadServerResponse("Invalid destination"))?;
sign_request::<T>(dest, &mut http_request);
let request = Request::try_from(http_request)?;
validate_url(request.url())?;
Ok(request)
}
async fn handle_response<T>(
dest: &ServerName, actual: ActualDest, method: &Method, url: &Url, mut response: Response,
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
@ -126,7 +143,7 @@ where
.actual_destinations()
.write()
.await
.insert(OwnedServerName::from(dest), (actual.dest, actual.host));
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
}
match response {
@ -176,6 +193,7 @@ async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
result
} else {
cached = false;
validate_dest(server_name)?;
resolve_actual_dest(server_name).await?
};