refactor: use Arc<Mutex<SqliteConnection>> in SqliteTagDao, remove unsafe impl Sync

Aligns SqliteTagDao with the pattern used by SqliteExifDao and SqliteInsightDao.
The unsafe impl Sync workaround is no longer needed since Arc<Mutex<>> provides
safe interior mutability and automatic Sync derivation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cameron
2026-03-18 17:10:11 -04:00
parent 387ce23afd
commit 8ecd3c6cf8
2 changed files with 68 additions and 27 deletions

View File

@@ -1490,7 +1490,7 @@ mod tests {
let request: Query<FilesRequest> = let request: Query<FilesRequest> =
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap(); Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection()); let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
let tag1 = tag_dao let tag1 = tag_dao
.create_tag(&opentelemetry::Context::current(), "tag1") .create_tag(&opentelemetry::Context::current(), "tag1")
@@ -1536,7 +1536,7 @@ mod tests {
exp: 12345, exp: 12345,
}; };
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection()); let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
let tag1 = tag_dao let tag1 = tag_dao
.create_tag(&opentelemetry::Context::current(), "tag1") .create_tag(&opentelemetry::Context::current(), "tag1")

View File

@@ -14,8 +14,8 @@ use opentelemetry::KeyValue;
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer}; use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
use schema::{tagged_photo, tags}; use schema::{tagged_photo, tags};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::ops::DerefMut;
use std::sync::Mutex; use std::sync::{Arc, Mutex};
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T> pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
where where
@@ -330,24 +330,22 @@ pub trait TagDao: Send + Sync {
} }
pub struct SqliteTagDao { pub struct SqliteTagDao {
connection: SqliteConnection, connection: Arc<Mutex<SqliteConnection>>,
} }
impl SqliteTagDao { impl SqliteTagDao {
pub(crate) fn new(connection: SqliteConnection) -> Self { pub(crate) fn new(connection: Arc<Mutex<SqliteConnection>>) -> Self {
SqliteTagDao { connection } SqliteTagDao { connection }
} }
} }
impl Default for SqliteTagDao { impl Default for SqliteTagDao {
fn default() -> Self { fn default() -> Self {
SqliteTagDao::new(connect()) SqliteTagDao {
connection: Arc::new(Mutex::new(connect())),
}
} }
} }
// SAFETY: SqliteTagDao is always accessed through Arc<Mutex<...>>,
// so concurrent access is prevented by the Mutex.
unsafe impl Sync for SqliteTagDao {}
impl TagDao for SqliteTagDao { impl TagDao for SqliteTagDao {
fn get_all_tags( fn get_all_tags(
@@ -357,6 +355,10 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<Vec<(i64, Tag)>> { ) -> anyhow::Result<Vec<(i64, Tag)>> {
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*); // select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_all_tags", |span| { trace_db_call(context, "query", "get_all_tags", |span| {
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default())); span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
@@ -367,7 +369,7 @@ impl TagDao for SqliteTagDao {
.group_by(tags::id) .group_by(tags::id)
.select((count_star(), id, name, created_time)) .select((count_star(), id, name, created_time))
.filter(tagged_photo::photo_name.like(path)) .filter(tagged_photo::photo_name.like(path))
.get_results(&mut self.connection) .get_results(conn.deref_mut())
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| { .map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
tags_with_count tags_with_count
.iter() .iter()
@@ -392,6 +394,10 @@ impl TagDao for SqliteTagDao {
context: &opentelemetry::Context, context: &opentelemetry::Context,
path: &str, path: &str,
) -> anyhow::Result<Vec<Tag>> { ) -> anyhow::Result<Vec<Tag>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_tags_for_path", |span| { trace_db_call(context, "query", "get_tags_for_path", |span| {
span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("path", path.to_string()));
@@ -400,12 +406,16 @@ impl TagDao for SqliteTagDao {
.left_join(tagged_photo::table) .left_join(tagged_photo::table)
.filter(tagged_photo::photo_name.eq(&path)) .filter(tagged_photo::photo_name.eq(&path))
.select((tags::id, tags::name, tags::created_time)) .select((tags::id, tags::name, tags::created_time))
.get_results::<Tag>(self.connection.borrow_mut()) .get_results::<Tag>(conn.deref_mut())
.with_context(|| "Unable to get tags from Sqlite") .with_context(|| "Unable to get tags from Sqlite")
}) })
} }
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> { fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "insert", "create_tag", |span| { trace_db_call(context, "insert", "create_tag", |span| {
span.set_attribute(KeyValue::new("name", name.to_string())); span.set_attribute(KeyValue::new("name", name.to_string()));
@@ -414,7 +424,7 @@ impl TagDao for SqliteTagDao {
name: name.to_string(), name: name.to_string(),
created_time: Utc::now().timestamp(), created_time: Utc::now().timestamp(),
}) })
.execute(&mut self.connection) .execute(conn.deref_mut())
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name)) .with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
.and_then(|_| { .and_then(|_| {
info!("Inserted tag: {:?}", name); info!("Inserted tag: {:?}", name);
@@ -422,7 +432,7 @@ impl TagDao for SqliteTagDao {
fn last_insert_rowid() -> Integer; fn last_insert_rowid() -> Integer;
} }
diesel::select(last_insert_rowid()) diesel::select(last_insert_rowid())
.get_result::<i32>(&mut self.connection) .get_result::<i32>(conn.deref_mut())
.with_context(|| "Unable to get last inserted tag from Sqlite") .with_context(|| "Unable to get last inserted tag from Sqlite")
}) })
.and_then(|id| { .and_then(|id| {
@@ -430,7 +440,7 @@ impl TagDao for SqliteTagDao {
tags::table tags::table
.filter(tags::id.eq(id)) .filter(tags::id.eq(id))
.select((tags::id, tags::name, tags::created_time)) .select((tags::id, tags::name, tags::created_time))
.get_result::<Tag>(self.connection.borrow_mut()) .get_result::<Tag>(conn.deref_mut())
.with_context(|| { .with_context(|| {
format!("Unable to get tagged photo with id: {:?} from Sqlite", id) format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
}) })
@@ -444,6 +454,10 @@ impl TagDao for SqliteTagDao {
tag_name: &str, tag_name: &str,
path: &str, path: &str,
) -> anyhow::Result<Option<()>> { ) -> anyhow::Result<Option<()>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "delete", "remove_tag", |span| { trace_db_call(context, "delete", "remove_tag", |span| {
span.set_attributes(vec![ span.set_attributes(vec![
KeyValue::new("tag_name", tag_name.to_string()), KeyValue::new("tag_name", tag_name.to_string()),
@@ -452,7 +466,7 @@ impl TagDao for SqliteTagDao {
tags::table tags::table
.filter(tags::name.eq(tag_name)) .filter(tags::name.eq(tag_name))
.get_result::<Tag>(self.connection.borrow_mut()) .get_result::<Tag>(conn.deref_mut())
.optional() .optional()
.with_context(|| format!("Unable to get tag '{}'", tag_name)) .with_context(|| format!("Unable to get tag '{}'", tag_name))
.and_then(|tag| { .and_then(|tag| {
@@ -462,7 +476,7 @@ impl TagDao for SqliteTagDao {
.filter(tagged_photo::tag_id.eq(tag.id)) .filter(tagged_photo::tag_id.eq(tag.id))
.filter(tagged_photo::photo_name.eq(path)), .filter(tagged_photo::photo_name.eq(path)),
) )
.execute(&mut self.connection) .execute(conn.deref_mut())
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name)) .with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
.map(|_| Some(())) .map(|_| Some(()))
} else { } else {
@@ -479,6 +493,10 @@ impl TagDao for SqliteTagDao {
path: &str, path: &str,
tag_id: i32, tag_id: i32,
) -> anyhow::Result<TaggedPhoto> { ) -> anyhow::Result<TaggedPhoto> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "insert", "tag_file", |span| { trace_db_call(context, "insert", "tag_file", |span| {
span.set_attributes(vec![ span.set_attributes(vec![
KeyValue::new("path", path.to_string()), KeyValue::new("path", path.to_string()),
@@ -491,7 +509,7 @@ impl TagDao for SqliteTagDao {
photo_name: path.to_string(), photo_name: path.to_string(),
created_time: Utc::now().timestamp(), created_time: Utc::now().timestamp(),
}) })
.execute(self.connection.borrow_mut()) .execute(conn.deref_mut())
.with_context(|| format!("Unable to tag file {:?} in sqlite", path)) .with_context(|| format!("Unable to tag file {:?} in sqlite", path))
.and_then(|_| { .and_then(|_| {
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path); info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
@@ -499,13 +517,13 @@ impl TagDao for SqliteTagDao {
fn last_insert_rowid() -> diesel::sql_types::Integer; fn last_insert_rowid() -> diesel::sql_types::Integer;
} }
diesel::select(last_insert_rowid()) diesel::select(last_insert_rowid())
.get_result::<i32>(&mut self.connection) .get_result::<i32>(conn.deref_mut())
.with_context(|| "Unable to get last inserted tag from Sqlite") .with_context(|| "Unable to get last inserted tag from Sqlite")
}) })
.and_then(|tagged_id| { .and_then(|tagged_id| {
tagged_photo::table tagged_photo::table
.find(tagged_id) .find(tagged_id)
.first(self.connection.borrow_mut()) .first(conn.deref_mut())
.with_context(|| { .with_context(|| {
format!( format!(
"Error getting inserted tagged photo with id: {:?}", "Error getting inserted tagged photo with id: {:?}",
@@ -522,6 +540,10 @@ impl TagDao for SqliteTagDao {
exclude_tag_ids: Vec<i32>, exclude_tag_ids: Vec<i32>,
context: &opentelemetry::Context, context: &opentelemetry::Context,
) -> anyhow::Result<Vec<FileWithTagCount>> { ) -> anyhow::Result<Vec<FileWithTagCount>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_files_with_all_tags", |_| { trace_db_call(context, "query", "get_files_with_all_tags", |_| {
use diesel::dsl::*; use diesel::dsl::*;
@@ -568,7 +590,7 @@ impl TagDao for SqliteTagDao {
.fold(query, |q, id| q.bind::<Integer, _>(id)); .fold(query, |q, id| q.bind::<Integer, _>(id));
query query
.load::<FileWithTagCount>(&mut self.connection) .load::<FileWithTagCount>(conn.deref_mut())
.with_context(|| "Unable to get tagged photos with all specified tags") .with_context(|| "Unable to get tagged photos with all specified tags")
}) })
} }
@@ -579,6 +601,10 @@ impl TagDao for SqliteTagDao {
exclude_tag_ids: Vec<i32>, exclude_tag_ids: Vec<i32>,
context: &opentelemetry::Context, context: &opentelemetry::Context,
) -> anyhow::Result<Vec<FileWithTagCount>> { ) -> anyhow::Result<Vec<FileWithTagCount>> {
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_files_with_any_tags", |_| { trace_db_call(context, "query", "get_files_with_any_tags", |_| {
use diesel::dsl::*; use diesel::dsl::*;
// Create the placeholders for the IN clauses // Create the placeholders for the IN clauses
@@ -620,7 +646,7 @@ impl TagDao for SqliteTagDao {
.fold(query, |q, id| q.bind::<Integer, _>(id)); .fold(query, |q, id| q.bind::<Integer, _>(id));
query query
.load::<FileWithTagCount>(&mut self.connection) .load::<FileWithTagCount>(conn.deref_mut())
.with_context(|| "Unable to get tagged photos") .with_context(|| "Unable to get tagged photos")
}) })
} }
@@ -633,9 +659,13 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use crate::database::schema::tagged_photo::dsl::*; use crate::database::schema::tagged_photo::dsl::*;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
diesel::update(tagged_photo.filter(photo_name.eq(old_name))) diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
.set(photo_name.eq(new_name)) .set(photo_name.eq(new_name))
.execute(&mut self.connection)?; .execute(conn.deref_mut())?;
Ok(()) Ok(())
} }
@@ -645,10 +675,14 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
use crate::database::schema::tagged_photo::dsl::*; use crate::database::schema::tagged_photo::dsl::*;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
tagged_photo tagged_photo
.select(photo_name) .select(photo_name)
.distinct() .distinct()
.load(&mut self.connection) .load(conn.deref_mut())
.with_context(|| "Unable to get photo names") .with_context(|| "Unable to get photo names")
} }
@@ -659,6 +693,10 @@ impl TagDao for SqliteTagDao {
) -> anyhow::Result<std::collections::HashMap<String, i64>> { ) -> anyhow::Result<std::collections::HashMap<String, i64>> {
use std::collections::HashMap; use std::collections::HashMap;
let mut conn = self
.connection
.lock()
.expect("Unable to lock SqliteTagDao connection");
trace_db_call(context, "query", "get_tag_counts_batch", |span| { trace_db_call(context, "query", "get_tag_counts_batch", |span| {
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64)); span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
@@ -701,7 +739,7 @@ impl TagDao for SqliteTagDao {
// Execute query and convert to HashMap // Execute query and convert to HashMap
query query
.load::<TagCountRow>(&mut self.connection) .load::<TagCountRow>(conn.deref_mut())
.with_context(|| "Unable to get batch tag counts") .with_context(|| "Unable to get batch tag counts")
.map(|rows| { .map(|rows| {
rows.into_iter() rows.into_iter()
@@ -739,7 +777,10 @@ mod tests {
} }
} }
// SAFETY: TestTagDao is only used in single-threaded tests // SAFETY: TestTagDao uses RefCell<T> fields which are !Send because they allow
// multiple mutable borrows without coordination. This impl is sound because
// TestTagDao is test-only, used within a single test function, and never moved
// into spawned tasks or shared across threads.
unsafe impl Send for TestTagDao {} unsafe impl Send for TestTagDao {}
unsafe impl Sync for TestTagDao {} unsafe impl Sync for TestTagDao {}