Compare commits
7 Commits
master
...
feature/in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8196ef94a0 | ||
|
|
e58b8fe743 | ||
|
|
c0d27d0b9e | ||
|
|
8ecd3c6cf8 | ||
|
|
387ce23afd | ||
|
|
b31b4b903c | ||
|
|
dd0715c081 |
@@ -17,6 +17,7 @@ use crate::database::{
|
||||
};
|
||||
use crate::memories::extract_date_from_filename;
|
||||
use crate::otel::global_tracer;
|
||||
use crate::tags::TagDao;
|
||||
use crate::utils::normalize_path;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -45,6 +46,7 @@ pub struct InsightGenerator {
|
||||
calendar_dao: Arc<Mutex<Box<dyn CalendarEventDao>>>,
|
||||
location_dao: Arc<Mutex<Box<dyn LocationHistoryDao>>>,
|
||||
search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>>,
|
||||
tag_dao: Arc<Mutex<Box<dyn TagDao>>>,
|
||||
|
||||
base_path: String,
|
||||
}
|
||||
@@ -59,6 +61,7 @@ impl InsightGenerator {
|
||||
calendar_dao: Arc<Mutex<Box<dyn CalendarEventDao>>>,
|
||||
location_dao: Arc<Mutex<Box<dyn LocationHistoryDao>>>,
|
||||
search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>>,
|
||||
tag_dao: Arc<Mutex<Box<dyn TagDao>>>,
|
||||
base_path: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -70,6 +73,7 @@ impl InsightGenerator {
|
||||
calendar_dao,
|
||||
location_dao,
|
||||
search_dao,
|
||||
tag_dao,
|
||||
base_path,
|
||||
}
|
||||
}
|
||||
@@ -149,6 +153,7 @@ impl InsightGenerator {
|
||||
contact: Option<&str>,
|
||||
topics: Option<&[String]>,
|
||||
limit: usize,
|
||||
extra_context: Option<&str>,
|
||||
) -> Result<Vec<String>> {
|
||||
let tracer = global_tracer();
|
||||
let span = tracer.start_with_context("ai.rag.filter_historical", parent_cx);
|
||||
@@ -170,7 +175,7 @@ impl InsightGenerator {
|
||||
}
|
||||
|
||||
let query_results = self
|
||||
.find_relevant_messages_rag(date, location, contact, topics, limit * 2)
|
||||
.find_relevant_messages_rag(date, location, contact, topics, limit * 2, extra_context)
|
||||
.await?;
|
||||
|
||||
filter_cx.span().set_attribute(KeyValue::new(
|
||||
@@ -232,6 +237,7 @@ impl InsightGenerator {
|
||||
contact: Option<&str>,
|
||||
topics: Option<&[String]>,
|
||||
limit: usize,
|
||||
extra_context: Option<&str>,
|
||||
) -> Result<Vec<String>> {
|
||||
let tracer = global_tracer();
|
||||
let current_cx = opentelemetry::Context::current();
|
||||
@@ -246,7 +252,7 @@ impl InsightGenerator {
|
||||
}
|
||||
|
||||
// Build query string - prioritize topics if available (semantically meaningful)
|
||||
let query = if let Some(topics) = topics {
|
||||
let base_query = if let Some(topics) = topics {
|
||||
if !topics.is_empty() {
|
||||
// Use topics for semantic search - these are actual content keywords
|
||||
let topic_str = topics.join(", ");
|
||||
@@ -264,6 +270,12 @@ impl InsightGenerator {
|
||||
Self::build_metadata_query(date, location, contact)
|
||||
};
|
||||
|
||||
let query = if let Some(extra) = extra_context {
|
||||
format!("{}. {}", base_query, extra)
|
||||
} else {
|
||||
base_query
|
||||
};
|
||||
|
||||
span.set_attribute(KeyValue::new("query", query.clone()));
|
||||
|
||||
// Create context with this span for child operations
|
||||
@@ -508,22 +520,29 @@ impl InsightGenerator {
|
||||
timestamp: i64,
|
||||
location: Option<&str>,
|
||||
contact: Option<&str>,
|
||||
enrichment: Option<&str>,
|
||||
) -> Result<Option<String>> {
|
||||
let tracer = global_tracer();
|
||||
let span = tracer.start_with_context("ai.context.search", parent_cx);
|
||||
let search_cx = parent_cx.with_span(span);
|
||||
|
||||
// Build semantic query from metadata
|
||||
let query_text = format!(
|
||||
"searches about {} {} {}",
|
||||
DateTime::from_timestamp(timestamp, 0)
|
||||
.map(|dt| dt.format("%B %Y").to_string())
|
||||
.unwrap_or_default(),
|
||||
location.unwrap_or(""),
|
||||
contact
|
||||
.map(|c| format!("involving {}", c))
|
||||
.unwrap_or_default()
|
||||
);
|
||||
// Use enrichment (topics + photo description + tags) if available;
|
||||
// fall back to generic temporal query.
|
||||
let query_text = if let Some(enriched) = enrichment {
|
||||
enriched.to_string()
|
||||
} else {
|
||||
// Fallback: generic temporal query
|
||||
format!(
|
||||
"searches about {} {} {}",
|
||||
DateTime::from_timestamp(timestamp, 0)
|
||||
.map(|dt| dt.format("%B %Y").to_string())
|
||||
.unwrap_or_default(),
|
||||
location.unwrap_or(""),
|
||||
contact
|
||||
.map(|c| format!("involving {}", c))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
};
|
||||
|
||||
let query_embedding = match self.ollama.generate_embedding(&query_text).await {
|
||||
Ok(emb) => emb,
|
||||
@@ -585,6 +604,7 @@ impl InsightGenerator {
|
||||
calendar: Option<String>,
|
||||
location: Option<String>,
|
||||
search: Option<String>,
|
||||
tags: Option<String>,
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
@@ -600,6 +620,9 @@ impl InsightGenerator {
|
||||
if let Some(s) = search {
|
||||
parts.push(format!("## Searches\n{}", s));
|
||||
}
|
||||
if let Some(t) = tags {
|
||||
parts.push(format!("## Tags\n{}", t));
|
||||
}
|
||||
|
||||
if parts.is_empty() {
|
||||
"No additional context available".to_string()
|
||||
@@ -703,6 +726,20 @@ impl InsightGenerator {
|
||||
.set_attribute(KeyValue::new("contact", c.clone()));
|
||||
}
|
||||
|
||||
// Fetch file tags (used to enrich RAG and final context)
|
||||
let tag_names: Vec<String> = {
|
||||
let mut dao = self.tag_dao.lock().expect("Unable to lock TagDao");
|
||||
dao.get_tags_for_path(&insight_cx, &file_path)
|
||||
.unwrap_or_else(|e| {
|
||||
log::warn!("Failed to fetch tags for insight {}: {}", file_path, e);
|
||||
Vec::new()
|
||||
})
|
||||
.into_iter()
|
||||
.map(|t| t.name)
|
||||
.collect()
|
||||
};
|
||||
log::info!("Fetched {} tags for photo: {:?}", tag_names.len(), tag_names);
|
||||
|
||||
// 4. Get location name from GPS coordinates (needed for RAG query)
|
||||
let location = match exif {
|
||||
Some(ref exif) => {
|
||||
@@ -729,6 +766,90 @@ impl InsightGenerator {
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Check if the model has vision capabilities
|
||||
let model_to_check = ollama_client.primary_model.clone();
|
||||
let has_vision = match OllamaClient::check_model_capabilities(
|
||||
&ollama_client.primary_url,
|
||||
&model_to_check,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(capabilities) => {
|
||||
log::info!(
|
||||
"Model '{}' vision capability: {}",
|
||||
model_to_check,
|
||||
capabilities.has_vision
|
||||
);
|
||||
capabilities.has_vision
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Failed to check vision capabilities for model '{}', assuming no vision support: {}",
|
||||
model_to_check,
|
||||
e
|
||||
);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
insight_cx
|
||||
.span()
|
||||
.set_attribute(KeyValue::new("model_has_vision", has_vision));
|
||||
|
||||
// Load image and encode as base64 only if model supports vision
|
||||
let image_base64 = if has_vision {
|
||||
match self.load_image_as_base64(&file_path) {
|
||||
Ok(b64) => {
|
||||
log::info!(
|
||||
"Successfully loaded image for vision-capable model '{}'",
|
||||
model_to_check
|
||||
);
|
||||
Some(b64)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to load image for vision model: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"Model '{}' does not support vision, skipping image processing",
|
||||
model_to_check
|
||||
);
|
||||
None
|
||||
};
|
||||
|
||||
// Generate brief photo description for RAG enrichment (vision models only)
|
||||
let photo_description: Option<String> = if let Some(ref img_b64) = image_base64 {
|
||||
match ollama_client.generate_photo_description(img_b64).await {
|
||||
Ok(desc) => {
|
||||
log::info!("Photo description for RAG enrichment: {}", desc);
|
||||
Some(desc)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to generate photo description for RAG enrichment: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build enriched context string for RAG: photo description + tags
|
||||
// (SMS topics are passed separately to RAG functions)
|
||||
let enriched_query: Option<String> = {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
if let Some(ref desc) = photo_description {
|
||||
parts.push(desc.clone());
|
||||
}
|
||||
if !tag_names.is_empty() {
|
||||
parts.push(format!("tags: {}", tag_names.join(", ")));
|
||||
}
|
||||
if parts.is_empty() { None } else { Some(parts.join(". ")) }
|
||||
};
|
||||
|
||||
let mut search_enrichment: Option<String> = enriched_query.clone();
|
||||
|
||||
// 5. Intelligent retrieval: Hybrid approach for better context
|
||||
let mut sms_summary = None;
|
||||
let mut used_rag = false;
|
||||
@@ -767,6 +888,21 @@ impl InsightGenerator {
|
||||
|
||||
log::info!("Extracted topics for query enrichment: {:?}", topics);
|
||||
|
||||
// Build full search enrichment: SMS topics + photo description + tag names
|
||||
search_enrichment = {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
if !topics.is_empty() {
|
||||
parts.push(topics.join(", "));
|
||||
}
|
||||
if let Some(ref desc) = photo_description {
|
||||
parts.push(desc.clone());
|
||||
}
|
||||
if !tag_names.is_empty() {
|
||||
parts.push(format!("tags: {}", tag_names.join(", ")));
|
||||
}
|
||||
if parts.is_empty() { None } else { Some(parts.join(". ")) }
|
||||
};
|
||||
|
||||
// Step 3: Try historical RAG (>30 days ago) using extracted topics
|
||||
let topics_slice = if topics.is_empty() {
|
||||
None
|
||||
@@ -781,6 +917,7 @@ impl InsightGenerator {
|
||||
contact.as_deref(),
|
||||
topics_slice,
|
||||
10, // Top 10 historical matches
|
||||
enriched_query.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -843,7 +980,7 @@ impl InsightGenerator {
|
||||
log::info!("No immediate messages found, trying basic RAG as fallback");
|
||||
// Fallback to basic RAG even without strong query
|
||||
match self
|
||||
.find_relevant_messages_rag(date_taken, None, contact.as_deref(), None, 20)
|
||||
.find_relevant_messages_rag(date_taken, None, contact.as_deref(), None, 20, enriched_query.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(rag_messages) if !rag_messages.is_empty() => {
|
||||
@@ -940,17 +1077,25 @@ impl InsightGenerator {
|
||||
timestamp,
|
||||
location.as_deref(),
|
||||
contact.as_deref(),
|
||||
search_enrichment.as_deref(),
|
||||
)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
// 7. Combine all context sources with equal weight
|
||||
let tags_context = if tag_names.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tag_names.join(", "))
|
||||
};
|
||||
|
||||
let combined_context = Self::combine_contexts(
|
||||
sms_summary,
|
||||
calendar_context,
|
||||
location_context,
|
||||
search_context,
|
||||
tags_context,
|
||||
);
|
||||
|
||||
log::info!(
|
||||
@@ -958,59 +1103,6 @@ impl InsightGenerator {
|
||||
combined_context.len()
|
||||
);
|
||||
|
||||
// 8. Check if the model has vision capabilities
|
||||
let model_to_check = ollama_client.primary_model.clone();
|
||||
let has_vision = match OllamaClient::check_model_capabilities(
|
||||
&ollama_client.primary_url,
|
||||
&model_to_check,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(capabilities) => {
|
||||
log::info!(
|
||||
"Model '{}' vision capability: {}",
|
||||
model_to_check,
|
||||
capabilities.has_vision
|
||||
);
|
||||
capabilities.has_vision
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Failed to check vision capabilities for model '{}', assuming no vision support: {}",
|
||||
model_to_check,
|
||||
e
|
||||
);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
insight_cx
|
||||
.span()
|
||||
.set_attribute(KeyValue::new("model_has_vision", has_vision));
|
||||
|
||||
// 9. Load image and encode as base64 only if model supports vision
|
||||
let image_base64 = if has_vision {
|
||||
match self.load_image_as_base64(&file_path) {
|
||||
Ok(b64) => {
|
||||
log::info!(
|
||||
"Successfully loaded image for vision-capable model '{}'",
|
||||
model_to_check
|
||||
);
|
||||
Some(b64)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to load image for vision model: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"Model '{}' does not support vision, skipping image processing",
|
||||
model_to_check
|
||||
);
|
||||
None
|
||||
};
|
||||
|
||||
// 10. Generate summary first, then derive title from the summary
|
||||
let summary = ollama_client
|
||||
.generate_photo_summary(
|
||||
@@ -1019,7 +1111,7 @@ impl InsightGenerator {
|
||||
contact.as_deref(),
|
||||
Some(&combined_context),
|
||||
custom_system_prompt.as_deref(),
|
||||
image_base64,
|
||||
image_base64.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1297,3 +1389,38 @@ Return ONLY the summary, nothing else."#,
|
||||
data.display_name
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn combine_contexts_includes_tags_section_when_tags_present() {
|
||||
let result = InsightGenerator::combine_contexts(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some("vacation, hiking, mountains".to_string()),
|
||||
);
|
||||
assert!(result.contains("## Tags"), "Should include Tags section");
|
||||
assert!(result.contains("vacation, hiking, mountains"), "Should include tag names");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combine_contexts_omits_tags_section_when_no_tags() {
|
||||
let result = InsightGenerator::combine_contexts(
|
||||
Some("some messages".to_string()),
|
||||
None, None, None,
|
||||
None, // no tags
|
||||
);
|
||||
assert!(!result.contains("## Tags"), "Should not include Tags section when None");
|
||||
assert!(result.contains("## Messages"), "Should still include Messages");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combine_contexts_returns_no_context_message_when_all_none() {
|
||||
let result = InsightGenerator::combine_contexts(None, None, None, None, None);
|
||||
assert_eq!(result, "No additional context available");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,6 +480,22 @@ Analyze the image and use specific details from both the visual content and the
|
||||
.await
|
||||
}
|
||||
|
||||
/// Generate a brief visual description of a photo for use in RAG query enrichment.
|
||||
/// Returns 1-2 sentences describing people, location, and activity visible in the image.
|
||||
/// Only called when the model has vision capabilities.
|
||||
pub async fn generate_photo_description(&self, image_base64: &str) -> Result<String> {
|
||||
let prompt = "Briefly describe what you see in this image in 1-2 sentences. \
|
||||
Focus on the people, location, and activity.";
|
||||
let system = "You are a scene description assistant. Be concise and factual.";
|
||||
let images = vec![image_base64.to_string()];
|
||||
|
||||
let description = self
|
||||
.generate_with_images(prompt, Some(system), Some(images))
|
||||
.await?;
|
||||
|
||||
Ok(description.trim().to_string())
|
||||
}
|
||||
|
||||
/// Generate an embedding vector for text using nomic-embed-text:v1.5
|
||||
/// Returns a 768-dimensional vector as Vec<f32>
|
||||
pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
|
||||
@@ -664,3 +680,18 @@ struct OllamaBatchEmbedRequest {
|
||||
struct OllamaEmbedResponse {
|
||||
embeddings: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn generate_photo_description_prompt_is_concise() {
|
||||
// Verify the method exists and its prompt is sane by checking the
|
||||
// constant we'll use. This is a compile + smoke check; actual LLM
|
||||
// calls are integration-tested manually.
|
||||
let prompt = "Briefly describe what you see in this image in 1-2 sentences. \
|
||||
Focus on the people, location, and activity.";
|
||||
assert!(prompt.len() < 200, "Prompt should be concise");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1490,7 +1490,7 @@ mod tests {
|
||||
let request: Query<FilesRequest> =
|
||||
Query::from_query("path=&tag_ids=1,3&recursive=true").unwrap();
|
||||
|
||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
||||
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||
|
||||
let tag1 = tag_dao
|
||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||
@@ -1536,7 +1536,7 @@ mod tests {
|
||||
exp: 12345,
|
||||
};
|
||||
|
||||
let mut tag_dao = SqliteTagDao::new(in_memory_db_connection());
|
||||
let mut tag_dao = SqliteTagDao::new(std::sync::Arc::new(Mutex::new(in_memory_db_connection())));
|
||||
|
||||
let tag1 = tag_dao
|
||||
.create_tag(&opentelemetry::Context::current(), "tag1")
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::database::{
|
||||
SqliteLocationHistoryDao, SqliteSearchHistoryDao,
|
||||
};
|
||||
use crate::database::{PreviewDao, SqlitePreviewDao};
|
||||
use crate::tags::{SqliteTagDao, TagDao};
|
||||
use crate::video::actors::{
|
||||
PlaylistGenerator, PreviewClipGenerator, StreamActor, VideoPlaylistManager,
|
||||
};
|
||||
@@ -119,6 +120,8 @@ impl Default for AppState {
|
||||
Arc::new(Mutex::new(Box::new(SqliteLocationHistoryDao::new())));
|
||||
let search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>> =
|
||||
Arc::new(Mutex::new(Box::new(SqliteSearchHistoryDao::new())));
|
||||
let tag_dao: Arc<Mutex<Box<dyn TagDao>>> =
|
||||
Arc::new(Mutex::new(Box::new(SqliteTagDao::default())));
|
||||
|
||||
// Load base path
|
||||
let base_path = env::var("BASE_PATH").expect("BASE_PATH was not set in the env");
|
||||
@@ -133,6 +136,7 @@ impl Default for AppState {
|
||||
calendar_dao.clone(),
|
||||
location_dao.clone(),
|
||||
search_dao.clone(),
|
||||
tag_dao.clone(),
|
||||
base_path.clone(),
|
||||
);
|
||||
|
||||
@@ -196,6 +200,8 @@ impl AppState {
|
||||
Arc::new(Mutex::new(Box::new(SqliteLocationHistoryDao::new())));
|
||||
let search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>> =
|
||||
Arc::new(Mutex::new(Box::new(SqliteSearchHistoryDao::new())));
|
||||
let tag_dao: Arc<Mutex<Box<dyn TagDao>>> =
|
||||
Arc::new(Mutex::new(Box::new(SqliteTagDao::default())));
|
||||
|
||||
// Initialize test InsightGenerator with all data sources
|
||||
let base_path_str = base_path.to_string_lossy().to_string();
|
||||
@@ -208,6 +214,7 @@ impl AppState {
|
||||
calendar_dao.clone(),
|
||||
location_dao.clone(),
|
||||
search_dao.clone(),
|
||||
tag_dao.clone(),
|
||||
base_path_str.clone(),
|
||||
);
|
||||
|
||||
|
||||
91
src/tags.rs
91
src/tags.rs
@@ -14,8 +14,8 @@ use opentelemetry::KeyValue;
|
||||
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
|
||||
use schema::{tagged_photo, tags};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::BorrowMut;
|
||||
use std::sync::Mutex;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub fn add_tag_services<T, TagD: TagDao + 'static>(app: App<T>) -> App<T>
|
||||
where
|
||||
@@ -276,7 +276,7 @@ pub struct AddTagsRequest {
|
||||
pub tag_ids: Vec<i32>,
|
||||
}
|
||||
|
||||
pub trait TagDao {
|
||||
pub trait TagDao: Send + Sync {
|
||||
fn get_all_tags(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
@@ -330,18 +330,20 @@ pub trait TagDao {
|
||||
}
|
||||
|
||||
pub struct SqliteTagDao {
|
||||
connection: SqliteConnection,
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
}
|
||||
|
||||
impl SqliteTagDao {
|
||||
pub(crate) fn new(connection: SqliteConnection) -> Self {
|
||||
pub(crate) fn new(connection: Arc<Mutex<SqliteConnection>>) -> Self {
|
||||
SqliteTagDao { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SqliteTagDao {
|
||||
fn default() -> Self {
|
||||
SqliteTagDao::new(connect())
|
||||
SqliteTagDao {
|
||||
connection: Arc::new(Mutex::new(connect())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,6 +355,10 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<Vec<(i64, Tag)>> {
|
||||
// select name, count(*) from tags join tagged_photo ON tags.id = tagged_photo.tag_id GROUP BY tags.name ORDER BY COUNT(*);
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_all_tags", |span| {
|
||||
span.set_attribute(KeyValue::new("path", path.clone().unwrap_or_default()));
|
||||
|
||||
@@ -363,7 +369,7 @@ impl TagDao for SqliteTagDao {
|
||||
.group_by(tags::id)
|
||||
.select((count_star(), id, name, created_time))
|
||||
.filter(tagged_photo::photo_name.like(path))
|
||||
.get_results(&mut self.connection)
|
||||
.get_results(conn.deref_mut())
|
||||
.map::<Vec<(i64, Tag)>, _>(|tags_with_count: Vec<(i64, i32, String, i64)>| {
|
||||
tags_with_count
|
||||
.iter()
|
||||
@@ -388,6 +394,10 @@ impl TagDao for SqliteTagDao {
|
||||
context: &opentelemetry::Context,
|
||||
path: &str,
|
||||
) -> anyhow::Result<Vec<Tag>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_tags_for_path", |span| {
|
||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||
|
||||
@@ -396,12 +406,16 @@ impl TagDao for SqliteTagDao {
|
||||
.left_join(tagged_photo::table)
|
||||
.filter(tagged_photo::photo_name.eq(&path))
|
||||
.select((tags::id, tags::name, tags::created_time))
|
||||
.get_results::<Tag>(self.connection.borrow_mut())
|
||||
.get_results::<Tag>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tags from Sqlite")
|
||||
})
|
||||
}
|
||||
|
||||
fn create_tag(&mut self, context: &opentelemetry::Context, name: &str) -> anyhow::Result<Tag> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "insert", "create_tag", |span| {
|
||||
span.set_attribute(KeyValue::new("name", name.to_string()));
|
||||
|
||||
@@ -410,7 +424,7 @@ impl TagDao for SqliteTagDao {
|
||||
name: name.to_string(),
|
||||
created_time: Utc::now().timestamp(),
|
||||
})
|
||||
.execute(&mut self.connection)
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to insert tag {:?} in Sqlite", name))
|
||||
.and_then(|_| {
|
||||
info!("Inserted tag: {:?}", name);
|
||||
@@ -418,7 +432,7 @@ impl TagDao for SqliteTagDao {
|
||||
fn last_insert_rowid() -> Integer;
|
||||
}
|
||||
diesel::select(last_insert_rowid())
|
||||
.get_result::<i32>(&mut self.connection)
|
||||
.get_result::<i32>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||
})
|
||||
.and_then(|id| {
|
||||
@@ -426,7 +440,7 @@ impl TagDao for SqliteTagDao {
|
||||
tags::table
|
||||
.filter(tags::id.eq(id))
|
||||
.select((tags::id, tags::name, tags::created_time))
|
||||
.get_result::<Tag>(self.connection.borrow_mut())
|
||||
.get_result::<Tag>(conn.deref_mut())
|
||||
.with_context(|| {
|
||||
format!("Unable to get tagged photo with id: {:?} from Sqlite", id)
|
||||
})
|
||||
@@ -440,6 +454,10 @@ impl TagDao for SqliteTagDao {
|
||||
tag_name: &str,
|
||||
path: &str,
|
||||
) -> anyhow::Result<Option<()>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "delete", "remove_tag", |span| {
|
||||
span.set_attributes(vec![
|
||||
KeyValue::new("tag_name", tag_name.to_string()),
|
||||
@@ -448,7 +466,7 @@ impl TagDao for SqliteTagDao {
|
||||
|
||||
tags::table
|
||||
.filter(tags::name.eq(tag_name))
|
||||
.get_result::<Tag>(self.connection.borrow_mut())
|
||||
.get_result::<Tag>(conn.deref_mut())
|
||||
.optional()
|
||||
.with_context(|| format!("Unable to get tag '{}'", tag_name))
|
||||
.and_then(|tag| {
|
||||
@@ -458,7 +476,7 @@ impl TagDao for SqliteTagDao {
|
||||
.filter(tagged_photo::tag_id.eq(tag.id))
|
||||
.filter(tagged_photo::photo_name.eq(path)),
|
||||
)
|
||||
.execute(&mut self.connection)
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to delete tag: '{}'", &tag.name))
|
||||
.map(|_| Some(()))
|
||||
} else {
|
||||
@@ -475,6 +493,10 @@ impl TagDao for SqliteTagDao {
|
||||
path: &str,
|
||||
tag_id: i32,
|
||||
) -> anyhow::Result<TaggedPhoto> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "insert", "tag_file", |span| {
|
||||
span.set_attributes(vec![
|
||||
KeyValue::new("path", path.to_string()),
|
||||
@@ -487,7 +509,7 @@ impl TagDao for SqliteTagDao {
|
||||
photo_name: path.to_string(),
|
||||
created_time: Utc::now().timestamp(),
|
||||
})
|
||||
.execute(self.connection.borrow_mut())
|
||||
.execute(conn.deref_mut())
|
||||
.with_context(|| format!("Unable to tag file {:?} in sqlite", path))
|
||||
.and_then(|_| {
|
||||
info!("Inserted tagged photo: {:#} -> {:?}", tag_id, path);
|
||||
@@ -495,13 +517,13 @@ impl TagDao for SqliteTagDao {
|
||||
fn last_insert_rowid() -> diesel::sql_types::Integer;
|
||||
}
|
||||
diesel::select(last_insert_rowid())
|
||||
.get_result::<i32>(&mut self.connection)
|
||||
.get_result::<i32>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get last inserted tag from Sqlite")
|
||||
})
|
||||
.and_then(|tagged_id| {
|
||||
tagged_photo::table
|
||||
.find(tagged_id)
|
||||
.first(self.connection.borrow_mut())
|
||||
.first(conn.deref_mut())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Error getting inserted tagged photo with id: {:?}",
|
||||
@@ -518,6 +540,10 @@ impl TagDao for SqliteTagDao {
|
||||
exclude_tag_ids: Vec<i32>,
|
||||
context: &opentelemetry::Context,
|
||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_files_with_all_tags", |_| {
|
||||
use diesel::dsl::*;
|
||||
|
||||
@@ -564,7 +590,7 @@ impl TagDao for SqliteTagDao {
|
||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||
|
||||
query
|
||||
.load::<FileWithTagCount>(&mut self.connection)
|
||||
.load::<FileWithTagCount>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tagged photos with all specified tags")
|
||||
})
|
||||
}
|
||||
@@ -575,6 +601,10 @@ impl TagDao for SqliteTagDao {
|
||||
exclude_tag_ids: Vec<i32>,
|
||||
context: &opentelemetry::Context,
|
||||
) -> anyhow::Result<Vec<FileWithTagCount>> {
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_files_with_any_tags", |_| {
|
||||
use diesel::dsl::*;
|
||||
// Create the placeholders for the IN clauses
|
||||
@@ -616,7 +646,7 @@ impl TagDao for SqliteTagDao {
|
||||
.fold(query, |q, id| q.bind::<Integer, _>(id));
|
||||
|
||||
query
|
||||
.load::<FileWithTagCount>(&mut self.connection)
|
||||
.load::<FileWithTagCount>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get tagged photos")
|
||||
})
|
||||
}
|
||||
@@ -629,9 +659,13 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::database::schema::tagged_photo::dsl::*;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
diesel::update(tagged_photo.filter(photo_name.eq(old_name)))
|
||||
.set(photo_name.eq(new_name))
|
||||
.execute(&mut self.connection)?;
|
||||
.execute(conn.deref_mut())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -641,10 +675,14 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
use crate::database::schema::tagged_photo::dsl::*;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
tagged_photo
|
||||
.select(photo_name)
|
||||
.distinct()
|
||||
.load(&mut self.connection)
|
||||
.load(conn.deref_mut())
|
||||
.with_context(|| "Unable to get photo names")
|
||||
}
|
||||
|
||||
@@ -655,6 +693,10 @@ impl TagDao for SqliteTagDao {
|
||||
) -> anyhow::Result<std::collections::HashMap<String, i64>> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut conn = self
|
||||
.connection
|
||||
.lock()
|
||||
.expect("Unable to lock SqliteTagDao connection");
|
||||
trace_db_call(context, "query", "get_tag_counts_batch", |span| {
|
||||
span.set_attribute(KeyValue::new("file_count", file_paths.len() as i64));
|
||||
|
||||
@@ -697,7 +739,7 @@ impl TagDao for SqliteTagDao {
|
||||
|
||||
// Execute query and convert to HashMap
|
||||
query
|
||||
.load::<TagCountRow>(&mut self.connection)
|
||||
.load::<TagCountRow>(conn.deref_mut())
|
||||
.with_context(|| "Unable to get batch tag counts")
|
||||
.map(|rows| {
|
||||
rows.into_iter()
|
||||
@@ -735,6 +777,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: TestTagDao uses RefCell<T> fields which are !Send because they allow
|
||||
// multiple mutable borrows without coordination. This impl is sound because
|
||||
// TestTagDao is test-only, used within a single test function, and never moved
|
||||
// into spawned tasks or shared across threads.
|
||||
unsafe impl Send for TestTagDao {}
|
||||
unsafe impl Sync for TestTagDao {}
|
||||
|
||||
impl TagDao for TestTagDao {
|
||||
fn get_all_tags(
|
||||
&mut self,
|
||||
|
||||
Reference in New Issue
Block a user