use mutex_map for url preview lock

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-17 02:48:43 +00:00
parent b0ac5255c8
commit b116984e46
2 changed files with 10 additions and 16 deletions

View file

@ -1,6 +1,6 @@
#![allow(deprecated)]
use std::{io::Cursor, sync::Arc, time::Duration};
use std::{io::Cursor, time::Duration};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
@ -656,16 +656,7 @@ async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewDat
}
// ensure that only one request is made per URL
let mutex_request = Arc::clone(
services
.media
.url_preview_mutex
.write()
.await
.entry(url.to_owned())
.or_default(),
);
let _request_lock = mutex_request.lock().await;
let _request_lock = services.media.url_preview_mutex.lock(url).await;
match services.media.get_url_preview(url).await {
Some(preview) => Ok(preview),

View file

@ -2,18 +2,17 @@ mod data;
mod tests;
mod thumbnail;
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::SystemTime};
use std::{path::PathBuf, sync::Arc, time::SystemTime};
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use conduit::{debug, debug_error, err, error, utils, Err, Result, Server};
use conduit::{debug, debug_error, err, error, utils, utils::MutexMap, Err, Result, Server};
use data::{Data, Metadata};
use ruma::{OwnedMxcUri, OwnedUserId};
use serde::Serialize;
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::{Mutex, RwLock},
};
use crate::services;
@ -44,7 +43,7 @@ pub struct UrlPreviewData {
pub struct Service {
server: Arc<Server>,
pub(crate) db: Data,
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
pub url_preview_mutex: MutexMap<String, ()>,
}
#[async_trait]
@ -53,7 +52,7 @@ impl crate::Service for Service {
Ok(Arc::new(Self {
server: args.server.clone(),
db: Data::new(args.db),
url_preview_mutex: RwLock::new(HashMap::new()),
url_preview_mutex: MutexMap::new(),
}))
}
@ -274,10 +273,12 @@ impl Service {
}
#[inline]
#[must_use]
pub fn get_media_file(&self, key: &[u8]) -> PathBuf { self.get_media_file_sha256(key) }
/// new SHA256 file name media function. requires database migrated. uses
/// SHA256 hash of the base64 key as the file name
#[must_use]
pub fn get_media_file_sha256(&self, key: &[u8]) -> PathBuf {
let mut r = self.get_media_dir();
// Using the hash of the base64 key as the filename
@ -292,6 +293,7 @@ impl Service {
/// old base64 file name media function
/// This is the old version of `get_media_file` that uses the full base64
/// key as the filename.
#[must_use]
pub fn get_media_file_b64(&self, key: &[u8]) -> PathBuf {
let mut r = self.get_media_dir();
let encoded = encode_key(key);
@ -299,6 +301,7 @@ impl Service {
r
}
#[must_use]
pub fn get_media_dir(&self) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.server.config.database_path.clone());