refactor: use Arc<Mutex<SqliteConnection>> in SqliteTagDao, remove unsafe impl Sync
Aligns SqliteTagDao with the pattern used by SqliteExifDao and SqliteInsightDao. The unsafe impl Sync workaround is no longer needed since Arc<Mutex<>> provides safe interior mutability and automatic Sync derivation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1490,7 +1490,7 @@ mod tests {
|
|||||||
let request: Query<FilesRequest> =
|
let request: Query<FilesRequest> =
|
||||||
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
|
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
|
||||||
|
|
||||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||||
|
|
||||||
let tag1 = tag_dao
|
let tag1 = tag_dao
|
||||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||||
@@ -1536,7 +1536,7 @@ mod tests {
|
|||||||
exp: 12345,
|
exp: 12345,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||||
|
|
||||||
let tag1 = tag_dao
|
let tag1 = tag_dao
|
||||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||||
|
|||||||
91
src/tags.rs
91
src/tags.rs
@@ -14,8 +14,8 @@ use opentelemetry::KeyValue;
|
|||||||
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
|
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
|
||||||
use schema::{tagged_photo, tags};
|
use schema::{tagged_photo, tags};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::borrow::BorrowMut;
|
use std::ops::DerefMut;
|
||||||
use std::sync::Mutex;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
|
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
|
||||||
where
|
where
|
||||||
@@ -330,24 +330,22 @@ pub trait TagDao: Send + Sync {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct SqliteTagDao {
|
pub struct SqliteTagDao {
|
||||||
connection: SqliteConnection,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SqliteTagDao {
|
impl SqliteTagDao {
|
||||||
pub(crate) fn new(connection: SqliteConnection) -> Self {
|
pub(crate) fn new(connection: Arc<Mutex<SqliteConnection>>) -> Self {
|
||||||
SqliteTagDao { connection }
|
SqliteTagDao { connection }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SqliteTagDao {
|
impl Default for SqliteTagDao {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
SqliteTagDao::new(connect())
|
SqliteTagDao {
|
||||||
|
connection: Arc::new(Mutex::new(connect())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFETY: SqliteTagDao is always accessed through Arc<Mutex<...>>,
|
|
||||||
// so concurrent access is prevented by the Mutex.
|
|
||||||
unsafe impl Sync for SqliteTagDao {}
|
|
||||||
|
|
||||||
impl TagDao for SqliteTagDao {
|
impl TagDao for SqliteTagDao {
|
||||||
fn get_all_tags(
|
fn get_all_tags(
|
||||||
@@ -357,6 +355,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
) -> anyhow::Result<Vec<(i64, Tag)>> {
|
) -> anyhow::Result<Vec<(i64, Tag)>> {
|
||||||
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
|
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
|
||||||
|
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "query", "get_all_tags", |span| {
|
trace_db_call(context, "query", "get_all_tags", |span| {
|
||||||
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
|
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
|
||||||
|
|
||||||
@@ -367,7 +369,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
.group_by(tags::id)
|
.group_by(tags::id)
|
||||||
.select((count_star(), id, name, created_time))
|
.select((count_star(), id, name, created_time))
|
||||||
.filter(tagged_photo::photo_name.like(path))
|
.filter(tagged_photo::photo_name.like(path))
|
||||||
.get_results(&mut self.connection)
|
.get_results(conn.deref_mut())
|
||||||
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
|
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
|
||||||
tags_with_count
|
tags_with_count
|
||||||
.iter()
|
.iter()
|
||||||
@@ -392,6 +394,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
context: &opentelemetry::Context,
|
context: &opentelemetry::Context,
|
||||||
path: &str,
|
path: &str,
|
||||||
) -> anyhow::Result<Vec<Tag>> {
|
) -> anyhow::Result<Vec<Tag>> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "query", "get_tags_for_path", |span| {
|
trace_db_call(context, "query", "get_tags_for_path", |span| {
|
||||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||||
|
|
||||||
@@ -400,12 +406,16 @@ impl TagDao for SqliteTagDao {
|
|||||||
.left_join(tagged_photo::table)
|
.left_join(tagged_photo::table)
|
||||||
.filter(tagged_photo::photo_name.eq(&path))
|
.filter(tagged_photo::photo_name.eq(&path))
|
||||||
.select((tags::id, tags::name, tags::created_time))
|
.select((tags::id, tags::name, tags::created_time))
|
||||||
.get_results::<Tag>(self.connection.borrow_mut())
|
.get_results::<Tag>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get tags from Sqlite")
|
.with_context(|| "Unable to get tags from Sqlite")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
|
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "insert", "create_tag", |span| {
|
trace_db_call(context, "insert", "create_tag", |span| {
|
||||||
span.set_attribute(KeyValue::new("name", name.to_string()));
|
span.set_attribute(KeyValue::new("name", name.to_string()));
|
||||||
|
|
||||||
@@ -414,7 +424,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
created_time: Utc::now().timestamp(),
|
created_time: Utc::now().timestamp(),
|
||||||
})
|
})
|
||||||
.execute(&mut self.connection)
|
.execute(conn.deref_mut())
|
||||||
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
|
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
|
||||||
.and_then(|_| {
|
.and_then(|_| {
|
||||||
info!("Inserted tag: {:?}", name);
|
info!("Inserted tag: {:?}", name);
|
||||||
@@ -422,7 +432,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
fn last_insert_rowid() -> Integer;
|
fn last_insert_rowid() -> Integer;
|
||||||
}
|
}
|
||||||
diesel::select(last_insert_rowid())
|
diesel::select(last_insert_rowid())
|
||||||
.get_result::<i32>(&mut self.connection)
|
.get_result::<i32>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||||
})
|
})
|
||||||
.and_then(|id| {
|
.and_then(|id| {
|
||||||
@@ -430,7 +440,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
tags::table
|
tags::table
|
||||||
.filter(tags::id.eq(id))
|
.filter(tags::id.eq(id))
|
||||||
.select((tags::id, tags::name, tags::created_time))
|
.select((tags::id, tags::name, tags::created_time))
|
||||||
.get_result::<Tag>(self.connection.borrow_mut())
|
.get_result::<Tag>(conn.deref_mut())
|
||||||
.with_context(|| {
|
.with_context(|| {
|
||||||
format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
|
format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
|
||||||
})
|
})
|
||||||
@@ -444,6 +454,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
tag_name: &str,
|
tag_name: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
) -> anyhow::Result<Option<()>> {
|
) -> anyhow::Result<Option<()>> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "delete", "remove_tag", |span| {
|
trace_db_call(context, "delete", "remove_tag", |span| {
|
||||||
span.set_attributes(vec![
|
span.set_attributes(vec![
|
||||||
KeyValue::new("tag_name", tag_name.to_string()),
|
KeyValue::new("tag_name", tag_name.to_string()),
|
||||||
@@ -452,7 +466,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
|
|
||||||
tags::table
|
tags::table
|
||||||
.filter(tags::name.eq(tag_name))
|
.filter(tags::name.eq(tag_name))
|
||||||
.get_result::<Tag>(self.connection.borrow_mut())
|
.get_result::<Tag>(conn.deref_mut())
|
||||||
.optional()
|
.optional()
|
||||||
.with_context(|| format!("Unable to get tag '{}'", tag_name))
|
.with_context(|| format!("Unable to get tag '{}'", tag_name))
|
||||||
.and_then(|tag| {
|
.and_then(|tag| {
|
||||||
@@ -462,7 +476,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
.filter(tagged_photo::tag_id.eq(tag.id))
|
.filter(tagged_photo::tag_id.eq(tag.id))
|
||||||
.filter(tagged_photo::photo_name.eq(path)),
|
.filter(tagged_photo::photo_name.eq(path)),
|
||||||
)
|
)
|
||||||
.execute(&mut self.connection)
|
.execute(conn.deref_mut())
|
||||||
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
|
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
|
||||||
.map(|_| Some(()))
|
.map(|_| Some(()))
|
||||||
} else {
|
} else {
|
||||||
@@ -479,6 +493,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
path: &str,
|
path: &str,
|
||||||
tag_id: i32,
|
tag_id: i32,
|
||||||
) -> anyhow::Result<TaggedPhoto> {
|
) -> anyhow::Result<TaggedPhoto> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "insert", "tag_file", |span| {
|
trace_db_call(context, "insert", "tag_file", |span| {
|
||||||
span.set_attributes(vec![
|
span.set_attributes(vec![
|
||||||
KeyValue::new("path", path.to_string()),
|
KeyValue::new("path", path.to_string()),
|
||||||
@@ -491,7 +509,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
photo_name: path.to_string(),
|
photo_name: path.to_string(),
|
||||||
created_time: Utc::now().timestamp(),
|
created_time: Utc::now().timestamp(),
|
||||||
})
|
})
|
||||||
.execute(self.connection.borrow_mut())
|
.execute(conn.deref_mut())
|
||||||
.with_context(|| format!("Unable to tag file {:?} in sqlite", path))
|
.with_context(|| format!("Unable to tag file {:?} in sqlite", path))
|
||||||
.and_then(|_| {
|
.and_then(|_| {
|
||||||
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
|
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
|
||||||
@@ -499,13 +517,13 @@ impl TagDao for SqliteTagDao {
|
|||||||
fn last_insert_rowid() -> diesel::sql_types::Integer;
|
fn last_insert_rowid() -> diesel::sql_types::Integer;
|
||||||
}
|
}
|
||||||
diesel::select(last_insert_rowid())
|
diesel::select(last_insert_rowid())
|
||||||
.get_result::<i32>(&mut self.connection)
|
.get_result::<i32>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||||
})
|
})
|
||||||
.and_then(|tagged_id| {
|
.and_then(|tagged_id| {
|
||||||
tagged_photo::table
|
tagged_photo::table
|
||||||
.find(tagged_id)
|
.find(tagged_id)
|
||||||
.first(self.connection.borrow_mut())
|
.first(conn.deref_mut())
|
||||||
.with_context(|| {
|
.with_context(|| {
|
||||||
format!(
|
format!(
|
||||||
"Error getting inserted tagged photo with id: {:?}",
|
"Error getting inserted tagged photo with id: {:?}",
|
||||||
@@ -522,6 +540,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
exclude_tag_ids: Vec<i32>,
|
exclude_tag_ids: Vec<i32>,
|
||||||
context: &opentelemetry::Context,
|
context: &opentelemetry::Context,
|
||||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "query", "get_files_with_all_tags", |_| {
|
trace_db_call(context, "query", "get_files_with_all_tags", |_| {
|
||||||
use diesel::dsl::*;
|
use diesel::dsl::*;
|
||||||
|
|
||||||
@@ -568,7 +590,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||||
|
|
||||||
query
|
query
|
||||||
.load::<FileWithTagCount>(&mut self.connection)
|
.load::<FileWithTagCount>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get tagged photos with all specified tags")
|
.with_context(|| "Unable to get tagged photos with all specified tags")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -579,6 +601,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
exclude_tag_ids: Vec<i32>,
|
exclude_tag_ids: Vec<i32>,
|
||||||
context: &opentelemetry::Context,
|
context: &opentelemetry::Context,
|
||||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "query", "get_files_with_any_tags", |_| {
|
trace_db_call(context, "query", "get_files_with_any_tags", |_| {
|
||||||
use diesel::dsl::*;
|
use diesel::dsl::*;
|
||||||
// Create the placeholders for the IN clauses
|
// Create the placeholders for the IN clauses
|
||||||
@@ -620,7 +646,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||||
|
|
||||||
query
|
query
|
||||||
.load::<FileWithTagCount>(&mut self.connection)
|
.load::<FileWithTagCount>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get tagged photos")
|
.with_context(|| "Unable to get tagged photos")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -633,9 +659,13 @@ impl TagDao for SqliteTagDao {
|
|||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use crate::database::schema::tagged_photo::dsl::*;
|
use crate::database::schema::tagged_photo::dsl::*;
|
||||||
|
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
|
diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
|
||||||
.set(photo_name.eq(new_name))
|
.set(photo_name.eq(new_name))
|
||||||
.execute(&mut self.connection)?;
|
.execute(conn.deref_mut())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -645,10 +675,14 @@ impl TagDao for SqliteTagDao {
|
|||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
use crate::database::schema::tagged_photo::dsl::*;
|
use crate::database::schema::tagged_photo::dsl::*;
|
||||||
|
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
tagged_photo
|
tagged_photo
|
||||||
.select(photo_name)
|
.select(photo_name)
|
||||||
.distinct()
|
.distinct()
|
||||||
.load(&mut self.connection)
|
.load(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get photo names")
|
.with_context(|| "Unable to get photo names")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -659,6 +693,10 @@ impl TagDao for SqliteTagDao {
|
|||||||
) -> anyhow::Result<std::collections::HashMap<String, i64>> {
|
) -> anyhow::Result<std::collections::HashMap<String, i64>> {
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
let mut conn = self
|
||||||
|
.connection
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock SqliteTagDao connection");
|
||||||
trace_db_call(context, "query", "get_tag_counts_batch", |span| {
|
trace_db_call(context, "query", "get_tag_counts_batch", |span| {
|
||||||
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
|
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
|
||||||
|
|
||||||
@@ -701,7 +739,7 @@ impl TagDao for SqliteTagDao {
|
|||||||
|
|
||||||
// Execute query and convert to HashMap
|
// Execute query and convert to HashMap
|
||||||
query
|
query
|
||||||
.load::<TagCountRow>(&mut self.connection)
|
.load::<TagCountRow>(conn.deref_mut())
|
||||||
.with_context(|| "Unable to get batch tag counts")
|
.with_context(|| "Unable to get batch tag counts")
|
||||||
.map(|rows| {
|
.map(|rows| {
|
||||||
rows.into_iter()
|
rows.into_iter()
|
||||||
@@ -739,7 +777,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFETY: TestTagDao is only used in single-threaded tests
|
// SAFETY: TestTagDao uses RefCell<T> fields which are !Send because they allow
|
||||||
|
// multiple mutable borrows without coordination. This impl is sound because
|
||||||
|
// TestTagDao is test-only, used within a single test function, and never moved
|
||||||
|
// into spawned tasks or shared across threads.
|
||||||
unsafe impl Send for TestTagDao {}
|
unsafe impl Send for TestTagDao {}
|
||||||
unsafe impl Sync for TestTagDao {}
|
unsafe impl Sync for TestTagDao {}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user