refactor: use Arc<Mutex<SqliteConnection>> in SqliteTagDao, remove unsafe impl Sync
Aligns SqliteTagDao with the pattern used by SqliteExifDao and SqliteInsightDao. The unsafe impl Sync workaround is no longer needed since Arc<Mutex<>> provides safe interior mutability and automatic Sync derivation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1490,7 +1490,7 @@ mod tests {
|
||||
let request: Query<FilesRequest> =
|
||||
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
|
||||
|
||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
||||
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||
|
||||
let tag1 = tag_dao
|
||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||
@@ -1536,7 +1536,7 @@ mod tests {
|
||||
exp: 12345,
|
||||
};
|
||||
|
||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
||||
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||
|
||||
let tag1 = tag_dao
|
||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||
|
||||
91
src/tags.rs
91
src/tags.rs
@@ -14,8 +14,8 @@ use opentelemetry::KeyValue;
|
||||
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
|
||||
use schema::{tagged_photo, tags};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::BorrowMut;
|
||||
use std::sync::Mutex;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
|
||||
where
|
||||
@@ -330,25 +330,23 @@ pub trait TagDao: Send + Sync {
|
||||
}
|
||||
|
||||
pub struct SqliteTagDao {
|
||||
connection: SqliteConnection,
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
}
|
||||
|
||||
impl SqliteTagDao {
|
||||
pub(crate) fn new(connection: SqliteConnection) -> Self {
|
||||
pub(crate) fn new(connection: Arc<Mutex<SqliteConnection>>) -> Self {
|
||||
SqliteTagDao { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SqliteTagDao {
|
||||
fn default() -> Self {
|
||||
SqliteTagDao::new(connect())
|
||||
SqliteTagDao {
|
||||
connection: Arc::new(Mutex::new(connect())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: SqliteTagDao is always accessed through Arc<Mutex<...>>,
|
||||
// so concurrent access is prevented by the Mutex.
|
||||
unsafe impl Sync for SqliteTagDao {}
|
||||
|
||||
impl TagDao for SqliteTagDao {
|
||||
fn get_all_tags(
|
||||
&mut self,
|
||||
@@ -357,6 +355,10 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<Vec<(i64, Tag)>> {
|
||||
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_all_tags", |span| {
|
||||
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
|
||||
|
||||
@@ -367,7 +369,7 @@ impl TagDao for SqliteTagDao {
|
||||
.group_by(tags::id)
|
||||
.select((count_star(), id, name, created_time))
|
||||
.filter(tagged_photo::photo_name.like(path))
|
||||
.get_results(&mut self.connection)
|
||||
.get_results(conn.deref_mut())
|
||||
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
|
||||
tags_with_count
|
||||
.iter()
|
||||
@@ -392,6 +394,10 @@ impl TagDao for SqliteTagDao {
|
||||
context: &opentelemetry::Context,
|
||||
path: &str,
|
||||
) -> anyhow::Result<Vec<Tag>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_tags_for_path", |span| {
|
||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||
|
||||
@@ -400,12 +406,16 @@ impl TagDao for SqliteTagDao {
|
||||
.left_join(tagged_photo::table)
|
||||
.filter(tagged_photo::photo_name.eq(&path))
|
||||
.select((tags::id, tags::name, tags::created_time))
|
||||
.get_results::<Tag>(self.connection.borrow_mut())
|
||||
.get_results::<Tag>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tags from Sqlite")
|
||||
})
|
||||
}
|
||||
|
||||
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "insert", "create_tag", |span| {
|
||||
span.set_attribute(KeyValue::new("name", name.to_string()));
|
||||
|
||||
@@ -414,7 +424,7 @@ impl TagDao for SqliteTagDao {
|
||||
name: name.to_string(),
|
||||
created_time: Utc::now().timestamp(),
|
||||
})
|
||||
.execute(&mut self.connection)
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
|
||||
.and_then(|_| {
|
||||
info!("Inserted tag: {:?}", name);
|
||||
@@ -422,7 +432,7 @@ impl TagDao for SqliteTagDao {
|
||||
fn last_insert_rowid() -> Integer;
|
||||
}
|
||||
diesel::select(last_insert_rowid())
|
||||
.get_result::<i32>(&mut self.connection)
|
||||
.get_result::<i32>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||
})
|
||||
.and_then(|id| {
|
||||
@@ -430,7 +440,7 @@ impl TagDao for SqliteTagDao {
|
||||
tags::table
|
||||
.filter(tags::id.eq(id))
|
||||
.select((tags::id, tags::name, tags::created_time))
|
||||
.get_result::<Tag>(self.connection.borrow_mut())
|
||||
.get_result::<Tag>(conn.deref_mut())
|
||||
.with_context(|| {
|
||||
format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
|
||||
})
|
||||
@@ -444,6 +454,10 @@ impl TagDao for SqliteTagDao {
|
||||
tag_name: &str,
|
||||
path: &str,
|
||||
) -> anyhow::Result<Option<()>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "delete", "remove_tag", |span| {
|
||||
span.set_attributes(vec![
|
||||
KeyValue::new("tag_name", tag_name.to_string()),
|
||||
@@ -452,7 +466,7 @@ impl TagDao for SqliteTagDao {
|
||||
|
||||
tags::table
|
||||
.filter(tags::name.eq(tag_name))
|
||||
.get_result::<Tag>(self.connection.borrow_mut())
|
||||
.get_result::<Tag>(conn.deref_mut())
|
||||
.optional()
|
||||
.with_context(|| format!("Unable to get tag '{}'", tag_name))
|
||||
.and_then(|tag| {
|
||||
@@ -462,7 +476,7 @@ impl TagDao for SqliteTagDao {
|
||||
.filter(tagged_photo::tag_id.eq(tag.id))
|
||||
.filter(tagged_photo::photo_name.eq(path)),
|
||||
)
|
||||
.execute(&mut self.connection)
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
|
||||
.map(|_| Some(()))
|
||||
} else {
|
||||
@@ -479,6 +493,10 @@ impl TagDao for SqliteTagDao {
|
||||
path: &str,
|
||||
tag_id: i32,
|
||||
) -> anyhow::Result<TaggedPhoto> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "insert", "tag_file", |span| {
|
||||
span.set_attributes(vec![
|
||||
KeyValue::new("path", path.to_string()),
|
||||
@@ -491,7 +509,7 @@ impl TagDao for SqliteTagDao {
|
||||
photo_name: path.to_string(),
|
||||
created_time: Utc::now().timestamp(),
|
||||
})
|
||||
.execute(self.connection.borrow_mut())
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to tag file {:?} in sqlite", path))
|
||||
.and_then(|_| {
|
||||
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
|
||||
@@ -499,13 +517,13 @@ impl TagDao for SqliteTagDao {
|
||||
fn last_insert_rowid() -> diesel::sql_types::Integer;
|
||||
}
|
||||
diesel::select(last_insert_rowid())
|
||||
.get_result::<i32>(&mut self.connection)
|
||||
.get_result::<i32>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||
})
|
||||
.and_then(|tagged_id| {
|
||||
tagged_photo::table
|
||||
.find(tagged_id)
|
||||
.first(self.connection.borrow_mut())
|
||||
.first(conn.deref_mut())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Error getting inserted tagged photo with id: {:?}",
|
||||
@@ -522,6 +540,10 @@ impl TagDao for SqliteTagDao {
|
||||
exclude_tag_ids: Vec<i32>,
|
||||
context: &opentelemetry::Context,
|
||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_files_with_all_tags", |_| {
|
||||
use diesel::dsl::*;
|
||||
|
||||
@@ -568,7 +590,7 @@ impl TagDao for SqliteTagDao {
|
||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||
|
||||
query
|
||||
.load::<FileWithTagCount>(&mut self.connection)
|
||||
.load::<FileWithTagCount>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tagged photos with all specified tags")
|
||||
})
|
||||
}
|
||||
@@ -579,6 +601,10 @@ impl TagDao for SqliteTagDao {
|
||||
exclude_tag_ids: Vec<i32>,
|
||||
context: &opentelemetry::Context,
|
||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_files_with_any_tags", |_| {
|
||||
use diesel::dsl::*;
|
||||
// Create the placeholders for the IN clauses
|
||||
@@ -620,7 +646,7 @@ impl TagDao for SqliteTagDao {
|
||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||
|
||||
query
|
||||
.load::<FileWithTagCount>(&mut self.connection)
|
||||
.load::<FileWithTagCount>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tagged photos")
|
||||
})
|
||||
}
|
||||
@@ -633,9 +659,13 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::database::schema::tagged_photo::dsl::*;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
|
||||
.set(photo_name.eq(new_name))
|
||||
.execute(&mut self.connection)?;
|
||||
.execute(conn.deref_mut())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -645,10 +675,14 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
use crate::database::schema::tagged_photo::dsl::*;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
tagged_photo
|
||||
.select(photo_name)
|
||||
.distinct()
|
||||
.load(&mut self.connection)
|
||||
.load(conn.deref_mut())
|
||||
.with_context(|| "Unable to get photo names")
|
||||
}
|
||||
|
||||
@@ -659,6 +693,10 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<std::collections::HashMap<String, i64>> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_tag_counts_batch", |span| {
|
||||
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
|
||||
|
||||
@@ -701,7 +739,7 @@ impl TagDao for SqliteTagDao {
|
||||
|
||||
// Execute query and convert to HashMap
|
||||
query
|
||||
.load::<TagCountRow>(&mut self.connection)
|
||||
.load::<TagCountRow>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get batch tag counts")
|
||||
.map(|rows| {
|
||||
rows.into_iter()
|
||||
@@ -739,7 +777,10 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: TestTagDao is only used in single-threaded tests
|
||||
// SAFETY: TestTagDao uses RefCell<T> fields which are !Send because they allow
|
||||
// multiple mutable borrows without coordination. This impl is sound because
|
||||
// TestTagDao is test-only, used within a single test function, and never moved
|
||||
// into spawned tasks or shared across threads.
|
||||
unsafe impl Send for TestTagDao {}
|
||||
unsafe impl Sync for TestTagDao {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user