refactor: use Arc<Mutex<SqliteConnection>> in SqliteTagDao, remove unsafe impl Sync

Aligns SqliteTagDao with the pattern used by SqliteExifDao and SqliteInsightDao.
The unsafe impl Sync workaround is no longer needed since Arc<Mutex<>> provides
safe interior mutability and automatic Sync derivation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cameron
2026-03-18 17:10:11 -04:00
parent 387ce23afd
commit 8ecd3c6cf8
2 changed files with 68 additions and 27 deletions

View File

@@ -1490,7 +1490,7 @@ mod tests {
let request: Query<FilesRequest> =
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
let tag1 = tag_dao
.create_tag(&opentelemetry::Context::current(), "tag1")
@@ -1536,7 +1536,7 @@ mod tests {
exp: 12345,
};
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
let tag1 = tag_dao
.create_tag(&opentelemetry::Context::current(), "tag1")

View File

@@ -14,8 +14,8 @@ use opentelemetry::KeyValue;
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
use schema::{tagged_photo, tags};
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::sync::Mutex;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex};
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
where
@@ -330,24 +330,22 @@ pub trait TagDao: Send + Sync {
}
pub struct SqliteTagDao {
connection: SqliteConnection,
connection: Arc<Mutex<SqliteConnection>>,
}
impl SqliteTagDao {
pub(crate) fn new(connection: SqliteConnection) -> Self {
pub(crate) fn new(connection: Arc<Mutex<SqliteConnection>>) -> Self {
SqliteTagDao { connection }
}
}
impl Default for SqliteTagDao {
fn default() -> Self {
SqliteTagDao::new(connect())
SqliteTagDao {
connection: Arc::new(Mutex::new(connect())),
}
}
}
// SAFETY: SqliteTagDao is always accessed through Arc<Mutex<...>>,
// so concurrent access is prevented by the Mutex.
unsafe impl Sync for SqliteTagDao {}
impl TagDao for SqliteTagDao {
fn get_all_tags(
@@ -357,6 +355,10 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<Vec<(i64, Tag)>> {
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_all_tags", |span| {
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
@@ -367,7 +369,7 @@ impl TagDao for SqliteTagDao {
.group_by(tags::id)
.select((count_star(), id, name, created_time))
.filter(tagged_photo::photo_name.like(path))
.get_results(&mut self.connection)
.get_results(conn.deref_mut())
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
tags_with_count
.iter()
@@ -392,6 +394,10 @@ impl TagDao for SqliteTagDao {
context: &opentelemetry::Context,
path: &str,
) -> anyhow::Result<Vec<Tag>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_tags_for_path", |span| {
span.set_attribute(KeyValue::new("path", path.to_string()));
@@ -400,12 +406,16 @@ impl TagDao for SqliteTagDao {
.left_join(tagged_photo::table)
.filter(tagged_photo::photo_name.eq(&path))
.select((tags::id, tags::name, tags::created_time))
.get_results::<Tag>(self.connection.borrow_mut())
.get_results::<Tag>(conn.deref_mut())
.with_context(|| "Unable to get tags from Sqlite")
})
}
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "insert", "create_tag", |span| {
span.set_attribute(KeyValue::new("name", name.to_string()));
@@ -414,7 +424,7 @@ impl TagDao for SqliteTagDao {
name: name.to_string(),
created_time: Utc::now().timestamp(),
})
.execute(&mut self.connection)
.execute(conn.deref_mut())
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
.and_then(|_| {
info!("Inserted tag: {:?}", name);
@@ -422,7 +432,7 @@ impl TagDao for SqliteTagDao {
fn last_insert_rowid() -> Integer;
}
diesel::select(last_insert_rowid())
.get_result::<i32>(&mut self.connection)
.get_result::<i32>(conn.deref_mut())
.with_context(|| "Unable to get last inserted tag from Sqlite")
})
.and_then(|id| {
@@ -430,7 +440,7 @@ impl TagDao for SqliteTagDao {
tags::table
.filter(tags::id.eq(id))
.select((tags::id, tags::name, tags::created_time))
.get_result::<Tag>(self.connection.borrow_mut())
.get_result::<Tag>(conn.deref_mut())
.with_context(|| {
format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
})
@@ -444,6 +454,10 @@ impl TagDao for SqliteTagDao {
tag_name: &str,
path: &str,
) -> anyhow::Result<Option<()>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "delete", "remove_tag", |span| {
span.set_attributes(vec![
KeyValue::new("tag_name", tag_name.to_string()),
@@ -452,7 +466,7 @@ impl TagDao for SqliteTagDao {
tags::table
.filter(tags::name.eq(tag_name))
.get_result::<Tag>(self.connection.borrow_mut())
.get_result::<Tag>(conn.deref_mut())
.optional()
.with_context(|| format!("Unable to get tag '{}'", tag_name))
.and_then(|tag| {
@@ -462,7 +476,7 @@ impl TagDao for SqliteTagDao {
.filter(tagged_photo::tag_id.eq(tag.id))
.filter(tagged_photo::photo_name.eq(path)),
)
.execute(&mut self.connection)
.execute(conn.deref_mut())
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
.map(|_| Some(()))
} else {
@@ -479,6 +493,10 @@ impl TagDao for SqliteTagDao {
path: &str,
tag_id: i32,
) -> anyhow::Result<TaggedPhoto> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "insert", "tag_file", |span| {
span.set_attributes(vec![
KeyValue::new("path", path.to_string()),
@@ -491,7 +509,7 @@ impl TagDao for SqliteTagDao {
photo_name: path.to_string(),
created_time: Utc::now().timestamp(),
})
.execute(self.connection.borrow_mut())
.execute(conn.deref_mut())
.with_context(|| format!("Unable to tag file {:?} in sqlite", path))
.and_then(|_| {
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
@@ -499,13 +517,13 @@ impl TagDao for SqliteTagDao {
fn last_insert_rowid() -> diesel::sql_types::Integer;
}
diesel::select(last_insert_rowid())
.get_result::<i32>(&mut self.connection)
.get_result::<i32>(conn.deref_mut())
.with_context(|| "Unable to get last inserted tag from Sqlite")
})
.and_then(|tagged_id| {
tagged_photo::table
.find(tagged_id)
.first(self.connection.borrow_mut())
.first(conn.deref_mut())
.with_context(|| {
format!(
"Error getting inserted tagged photo with id: {:?}",
@@ -522,6 +540,10 @@ impl TagDao for SqliteTagDao {
exclude_tag_ids: Vec<i32>,
context: &opentelemetry::Context,
) -> anyhow::Result<Vec<FileWithTagCount>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_files_with_all_tags", |_| {
use diesel::dsl::*;
@@ -568,7 +590,7 @@ impl TagDao for SqliteTagDao {
.fold(query, |q, id| q.bind::<Integer, _>(id));
query
.load::<FileWithTagCount>(&mut self.connection)
.load::<FileWithTagCount>(conn.deref_mut())
.with_context(|| "Unable to get tagged photos with all specified tags")
})
}
@@ -579,6 +601,10 @@ impl TagDao for SqliteTagDao {
exclude_tag_ids: Vec<i32>,
context: &opentelemetry::Context,
) -> anyhow::Result<Vec<FileWithTagCount>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_files_with_any_tags", |_| {
use diesel::dsl::*;
// Create the placeholders for the IN clauses
@@ -620,7 +646,7 @@ impl TagDao for SqliteTagDao {
.fold(query, |q, id| q.bind::<Integer, _>(id));
query
.load::<FileWithTagCount>(&mut self.connection)
.load::<FileWithTagCount>(conn.deref_mut())
.with_context(|| "Unable to get tagged photos")
})
}
@@ -633,9 +659,13 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<()> {
use crate::database::schema::tagged_photo::dsl::*;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
.set(photo_name.eq(new_name))
.execute(&mut self.connection)?;
.execute(conn.deref_mut())?;
Ok(())
}
@@ -645,10 +675,14 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<Vec<String>> {
use crate::database::schema::tagged_photo::dsl::*;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
tagged_photo
.select(photo_name)
.distinct()
.load(&mut self.connection)
.load(conn.deref_mut())
.with_context(|| "Unable to get photo names")
}
@@ -659,6 +693,10 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<std::collections::HashMap<String, i64>> {
use std::collections::HashMap;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_tag_counts_batch", |span| {
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
@@ -701,7 +739,7 @@ impl TagDao for SqliteTagDao {
// Execute query and convert to HashMap
query
.load::<TagCountRow>(&mut self.connection)
.load::<TagCountRow>(conn.deref_mut())
.with_context(|| "Unable to get batch tag counts")
.map(|rows| {
rows.into_iter()
@@ -739,7 +777,10 @@ mod tests {
}
}
// SAFETY: TestTagDao is only used in single-threaded tests
// SAFETY: TestTagDao uses RefCell<T> fields which are !Send because they allow
// multiple mutable borrows without coordination. This impl is sound because
// TestTagDao is test-only, used within a single test function, and never moved
// into spawned tasks or shared across threads.
unsafe impl Send for TestTagDao {}
unsafe impl Sync for TestTagDao {}