From 8ecd3c6cf897de67e76fdef3c7a80a7ec4b4c66b Mon Sep 17 00:00:00 2001 From: Cameron Date: Wed, 18 Mar 2026 17:10:11 -0400 Subject: [PATCH] refactor: use Arc> in SqliteTagDao, remove unsafe impl Sync Aligns SqliteTagDao with the pattern used by SqliteExifDao and SqliteInsightDao. The unsafe impl Sync workaround is no longer needed since Arc> provides safe interior mutability and automatic Sync derivation. Co-Authored-By: Claude Sonnet 4.6 --- src/files.rs | 4 +-- src/tags.rs | 91 +++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/files.rs b/src/files.rs index 107b16a..a7355fb 100644 --- a/src/files.rs +++ b/src/files.rs @@ -1490,7 +1490,7 @@ mod tests { let request: Query = Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap(); - let mut tag_dao = SqliteTagDao::new(in_memory_db_connection()); + let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection()))); let tag1 = tag_dao .create_tag(&opentelemetry::Context::current(), "tag1") @@ -1536,7 +1536,7 @@ mod tests { exp: 12345, }; - let mut tag_dao = SqliteTagDao::new(in_memory_db_connection()); + let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection()))); let tag1 = tag_dao .create_tag(&opentelemetry::Context::current(), "tag1") diff --git a/src/tags.rs b/src/tags.rs index 360d52b..5da6d6e 100644 --- a/src/tags.rs +++ b/src/tags.rs @@ -14,8 +14,8 @@ use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer}; use schema::{tagged_photo, tags}; use serde::{Deserialize, Serialize}; -use std::borrow::BorrowMut; -use std::sync::Mutex; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex}; pub fn add_tag_services(app: App) -> App where @@ -330,25 +330,23 @@ pub trait TagDao: Send + Sync { } pub struct SqliteTagDao { - connection: SqliteConnection, + connection: Arc>, } impl SqliteTagDao { - pub(crate) fn new(connection: SqliteConnection) -> Self { + pub(crate) fn new(connection: Arc>) -> Self { SqliteTagDao { connection } } } impl Default for SqliteTagDao { fn default() -> Self { - SqliteTagDao::new(connect()) + SqliteTagDao { + connection: Arc::new(Mutex::new(connect())), + } } } -// SAFETY: SqliteTagDao is always accessed through Arc>, -// so concurrent access is prevented by the Mutex. -unsafe impl Sync for SqliteTagDao {} - impl TagDao for SqliteTagDao { fn get_all_tags( &mut self, @@ -357,6 +355,10 @@ impl TagDao for SqliteTagDao { ) -> anyhow::Result> { // select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*); + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "query", "get_all_tags", |span| { span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default())); @@ -367,7 +369,7 @@ impl TagDao for SqliteTagDao { .group_by(tags::id) .select((count_star(), id, name, created_time)) .filter(tagged_photo::photo_name.like(path)) - .get_results(&mut self.connection) + .get_results(conn.deref_mut()) .map::, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| { tags_with_count .iter() @@ -392,6 +394,10 @@ impl TagDao for SqliteTagDao { context: &opentelemetry::Context, path: &str, ) -> anyhow::Result> { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "query", "get_tags_for_path", |span| { span.set_attribute(KeyValue::new("path", path.to_string())); @@ -400,12 +406,16 @@ impl TagDao for SqliteTagDao { .left_join(tagged_photo::table) .filter(tagged_photo::photo_name.eq(&path)) .select((tags::id, tags::name, tags::created_time)) - .get_results::(self.connection.borrow_mut()) + .get_results::(conn.deref_mut()) .with_context(|| "Unable to get tags from Sqlite") }) } fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "insert", "create_tag", |span| { span.set_attribute(KeyValue::new("name", name.to_string())); @@ -414,7 +424,7 @@ impl TagDao for SqliteTagDao { name: name.to_string(), created_time: Utc::now().timestamp(), }) - .execute(&mut self.connection) + .execute(conn.deref_mut()) .with_context(|| format!("Unable to insert tag {:?} in Sqlite", name)) .and_then(|_| { info!("Inserted tag: {:?}", name); @@ -422,7 +432,7 @@ impl TagDao for SqliteTagDao { fn last_insert_rowid() -> Integer; } diesel::select(last_insert_rowid()) - .get_result::(&mut self.connection) + .get_result::(conn.deref_mut()) .with_context(|| "Unable to get last inserted tag from Sqlite") }) .and_then(|id| { @@ -430,7 +440,7 @@ impl TagDao for SqliteTagDao { tags::table .filter(tags::id.eq(id)) .select((tags::id, tags::name, tags::created_time)) - .get_result::(self.connection.borrow_mut()) + .get_result::(conn.deref_mut()) .with_context(|| { format!("Unable to get tagged photo with id: {:?} from Sqlite", id) }) @@ -444,6 +454,10 @@ impl TagDao for SqliteTagDao { tag_name: &str, path: &str, ) -> anyhow::Result> { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "delete", "remove_tag", |span| { span.set_attributes(vec![ KeyValue::new("tag_name", tag_name.to_string()), @@ -452,7 +466,7 @@ impl TagDao for SqliteTagDao { tags::table .filter(tags::name.eq(tag_name)) - .get_result::(self.connection.borrow_mut()) + .get_result::(conn.deref_mut()) .optional() .with_context(|| format!("Unable to get tag '{}'", tag_name)) .and_then(|tag| { @@ -462,7 +476,7 @@ impl TagDao for SqliteTagDao { .filter(tagged_photo::tag_id.eq(tag.id)) .filter(tagged_photo::photo_name.eq(path)), ) - .execute(&mut self.connection) + .execute(conn.deref_mut()) .with_context(|| format!("Unable to delete tag: '{}'", &tag.name)) .map(|_| Some(())) } else { @@ -479,6 +493,10 @@ impl TagDao for SqliteTagDao { path: &str, tag_id: i32, ) -> anyhow::Result { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "insert", "tag_file", |span| { span.set_attributes(vec![ KeyValue::new("path", path.to_string()), @@ -491,7 +509,7 @@ impl TagDao for SqliteTagDao { photo_name: path.to_string(), created_time: Utc::now().timestamp(), }) - .execute(self.connection.borrow_mut()) + .execute(conn.deref_mut()) .with_context(|| format!("Unable to tag file {:?} in sqlite", path)) .and_then(|_| { info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path); @@ -499,13 +517,13 @@ impl TagDao for SqliteTagDao { fn last_insert_rowid() -> diesel::sql_types::Integer; } diesel::select(last_insert_rowid()) - .get_result::(&mut self.connection) + .get_result::(conn.deref_mut()) .with_context(|| "Unable to get last inserted tag from Sqlite") }) .and_then(|tagged_id| { tagged_photo::table .find(tagged_id) - .first(self.connection.borrow_mut()) + .first(conn.deref_mut()) .with_context(|| { format!( "Error getting inserted tagged photo with id: {:?}", @@ -522,6 +540,10 @@ impl TagDao for SqliteTagDao { exclude_tag_ids: Vec, context: &opentelemetry::Context, ) -> anyhow::Result> { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "query", "get_files_with_all_tags", |_| { use diesel::dsl::*; @@ -568,7 +590,7 @@ impl TagDao for SqliteTagDao { .fold(query, |q, id| q.bind::(id)); query - .load::(&mut self.connection) + .load::(conn.deref_mut()) .with_context(|| "Unable to get tagged photos with all specified tags") }) } @@ -579,6 +601,10 @@ impl TagDao for SqliteTagDao { exclude_tag_ids: Vec, context: &opentelemetry::Context, ) -> anyhow::Result> { + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "query", "get_files_with_any_tags", |_| { use diesel::dsl::*; // Create the placeholders for the IN clauses @@ -620,7 +646,7 @@ impl TagDao for SqliteTagDao { .fold(query, |q, id| q.bind::(id)); query - .load::(&mut self.connection) + .load::(conn.deref_mut()) .with_context(|| "Unable to get tagged photos") }) } @@ -633,9 +659,13 @@ impl TagDao for SqliteTagDao { ) -> anyhow::Result<()> { use crate::database::schema::tagged_photo::dsl::*; + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); diesel::update(tagged_photo.filter(photo_name.eq(old_name))) .set(photo_name.eq(new_name)) - .execute(&mut self.connection)?; + .execute(conn.deref_mut())?; Ok(()) } @@ -645,10 +675,14 @@ impl TagDao for SqliteTagDao { ) -> anyhow::Result> { use crate::database::schema::tagged_photo::dsl::*; + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); tagged_photo .select(photo_name) .distinct() - .load(&mut self.connection) + .load(conn.deref_mut()) .with_context(|| "Unable to get photo names") } @@ -659,6 +693,10 @@ impl TagDao for SqliteTagDao { ) -> anyhow::Result> { use std::collections::HashMap; + let mut conn = self + .connection + .lock() + .expect("Unable to lock SqliteTagDao connection"); trace_db_call(context, "query", "get_tag_counts_batch", |span| { span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64)); @@ -701,7 +739,7 @@ impl TagDao for SqliteTagDao { // Execute query and convert to HashMap query - .load::(&mut self.connection) + .load::(conn.deref_mut()) .with_context(|| "Unable to get batch tag counts") .map(|rows| { rows.into_iter() @@ -739,7 +777,10 @@ mod tests { } } - // SAFETY: TestTagDao is only used in single-threaded tests + // SAFETY: TestTagDao uses RefCell fields which are !Send because they allow + // multiple mutable borrows without coordination. This impl is sound because + // TestTagDao is test-only, used within a single test function, and never moved + // into spawned tasks or shared across threads. unsafe impl Send for TestTagDao {} unsafe impl Sync for TestTagDao {}