diff --git a/Cargo.lock b/Cargo.lock index 5d3e4ce..d35048c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2104,6 +2104,7 @@ dependencies = [ "tokio", "tokio-util", "urlencoding", + "uuid", "walkdir", "zerocopy", ] @@ -4391,7 +4392,9 @@ version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ + "getrandom 0.4.2", "js-sys", + "serde_core", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 7324001..6807778 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ image_hasher = "3.0" bk-tree = "0.5" async-trait = "0.1" indicatif = "0.17" +uuid = { version = "1.10", features = ["v4", "serde"] } # Windows lacks system sqlite3, so re-enable the bundled C build there. # Linux/macOS use the system library (faster builds, smaller binary). diff --git a/migrations/2026-05-27-000000_add_insight_generation_jobs/down.sql b/migrations/2026-05-27-000000_add_insight_generation_jobs/down.sql new file mode 100644 index 0000000..2c9a2a7 --- /dev/null +++ b/migrations/2026-05-27-000000_add_insight_generation_jobs/down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS idx_insight_gen_jobs_status_cleanup; +DROP INDEX IF EXISTS idx_insight_gen_jobs_file; +DROP TABLE IF EXISTS insight_generation_jobs; diff --git a/migrations/2026-05-27-000000_add_insight_generation_jobs/up.sql b/migrations/2026-05-27-000000_add_insight_generation_jobs/up.sql new file mode 100644 index 0000000..1ad6aab --- /dev/null +++ b/migrations/2026-05-27-000000_add_insight_generation_jobs/up.sql @@ -0,0 +1,23 @@ +-- Track async insight generation jobs so the client can poll for +-- completion after the server returns 202 Accepted. Each generation +-- creates a new row; the application layer cancels prior running +-- jobs before inserting. +CREATE TABLE insight_generation_jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + library_id INTEGER NOT NULL DEFAULT 1, + file_path TEXT NOT NULL, + generation_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'running', + started_at INTEGER NOT NULL, + completed_at INTEGER, + result_insight_id INTEGER, + error_message TEXT +); + +-- For the status endpoint: fast lookup by (library_id, file_path) +CREATE INDEX idx_insight_gen_jobs_file + ON insight_generation_jobs(library_id, file_path); + +-- For startup cleanup (future): prune old completed/failed jobs +CREATE INDEX idx_insight_gen_jobs_status_cleanup + ON insight_generation_jobs(status, started_at); diff --git a/migrations/2026-05-27-000001_remove_insight_jobs_unique/down.sql b/migrations/2026-05-27-000001_remove_insight_jobs_unique/down.sql new file mode 100644 index 0000000..7a5cf40 --- /dev/null +++ b/migrations/2026-05-27-000001_remove_insight_jobs_unique/down.sql @@ -0,0 +1,28 @@ +-- Restore UNIQUE constraint + +CREATE TABLE insight_generation_jobs_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + library_id INTEGER NOT NULL DEFAULT 1, + file_path TEXT NOT NULL, + generation_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'running', + started_at INTEGER NOT NULL, + completed_at INTEGER, + result_insight_id INTEGER, + error_message TEXT, + UNIQUE(library_id, file_path, generation_type) +); + +INSERT INTO insight_generation_jobs_new + SELECT id, library_id, file_path, generation_type, status, started_at, completed_at, result_insight_id, error_message + FROM insight_generation_jobs; + +DROP TABLE insight_generation_jobs; + +ALTER TABLE insight_generation_jobs_new RENAME TO insight_generation_jobs; + +CREATE INDEX idx_insight_gen_jobs_file + ON insight_generation_jobs(library_id, file_path); + +CREATE INDEX idx_insight_gen_jobs_status_cleanup + ON insight_generation_jobs(status, started_at); diff --git a/migrations/2026-05-27-000001_remove_insight_jobs_unique/up.sql b/migrations/2026-05-27-000001_remove_insight_jobs_unique/up.sql new file mode 100644 index 0000000..939c592 --- /dev/null +++ b/migrations/2026-05-27-000001_remove_insight_jobs_unique/up.sql @@ -0,0 +1,30 @@ +-- Remove UNIQUE(library_id, file_path, generation_type) constraint to allow +-- multiple job rows per file. This enables proper cancel/regenerate semantics: +-- a new job is always inserted on regenerate, and the old job is cancelled +-- independently. The application layer prevents concurrent running jobs. + +CREATE TABLE insight_generation_jobs_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + library_id INTEGER NOT NULL DEFAULT 1, + file_path TEXT NOT NULL, + generation_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'running', + started_at INTEGER NOT NULL, + completed_at INTEGER, + result_insight_id INTEGER, + error_message TEXT +); + +INSERT INTO insight_generation_jobs_new + SELECT id, library_id, file_path, generation_type, status, started_at, completed_at, result_insight_id, error_message + FROM insight_generation_jobs; + +DROP TABLE insight_generation_jobs; + +ALTER TABLE insight_generation_jobs_new RENAME TO insight_generation_jobs; + +CREATE INDEX idx_insight_gen_jobs_file + ON insight_generation_jobs(library_id, file_path); + +CREATE INDEX idx_insight_gen_jobs_status_cleanup + ON insight_generation_jobs(status, started_at); diff --git a/migrations/2026-05-27-000002_add_insight_generation_params/down.sql b/migrations/2026-05-27-000002_add_insight_generation_params/down.sql new file mode 100644 index 0000000..217d330 --- /dev/null +++ b/migrations/2026-05-27-000002_add_insight_generation_params/down.sql @@ -0,0 +1,11 @@ +-- SQLite doesn't support DROP COLUMN before 3.35.0; recreate the table +-- without the new columns. This is only needed for rollback. +CREATE TABLE photo_insights_old AS + SELECT id, library_id, rel_path, title, summary, generated_at, + model_version, is_current, training_messages, approved, + backend, fewshot_source_ids, content_hash + FROM photo_insights; + +DROP TABLE photo_insights; + +ALTER TABLE photo_insights_old RENAME TO photo_insights; diff --git a/migrations/2026-05-27-000002_add_insight_generation_params/up.sql b/migrations/2026-05-27-000002_add_insight_generation_params/up.sql new file mode 100644 index 0000000..1313fde --- /dev/null +++ b/migrations/2026-05-27-000002_add_insight_generation_params/up.sql @@ -0,0 +1,8 @@ +-- Persist generation parameters on each insight row for auditing. +ALTER TABLE photo_insights ADD COLUMN num_ctx INTEGER; +ALTER TABLE photo_insights ADD COLUMN temperature REAL; +ALTER TABLE photo_insights ADD COLUMN top_p REAL; +ALTER TABLE photo_insights ADD COLUMN top_k INTEGER; +ALTER TABLE photo_insights ADD COLUMN min_p REAL; +ALTER TABLE photo_insights ADD COLUMN system_prompt TEXT; +ALTER TABLE photo_insights ADD COLUMN persona_id TEXT; diff --git a/migrations/2026-05-27-000003_add_insight_token_counts/down.sql b/migrations/2026-05-27-000003_add_insight_token_counts/down.sql new file mode 100644 index 0000000..f680ad2 --- /dev/null +++ b/migrations/2026-05-27-000003_add_insight_token_counts/down.sql @@ -0,0 +1,13 @@ +-- SQLite doesn't support DROP COLUMN before 3.35.0; recreate the table +-- without the token-count columns. This is only needed for rollback. +CREATE TABLE photo_insights_old AS + SELECT id, library_id, rel_path, title, summary, generated_at, + model_version, is_current, training_messages, approved, + backend, fewshot_source_ids, content_hash, + num_ctx, temperature, top_p, top_k, min_p, + system_prompt, persona_id + FROM photo_insights; + +DROP TABLE photo_insights; + +ALTER TABLE photo_insights_old RENAME TO photo_insights; diff --git a/migrations/2026-05-27-000003_add_insight_token_counts/up.sql b/migrations/2026-05-27-000003_add_insight_token_counts/up.sql new file mode 100644 index 0000000..ce8890e --- /dev/null +++ b/migrations/2026-05-27-000003_add_insight_token_counts/up.sql @@ -0,0 +1,6 @@ +-- Persist token usage on each insight row. Split from +-- 2026-05-27-000002_add_insight_generation_params because that +-- migration was already applied on some environments before these +-- columns were added. +ALTER TABLE photo_insights ADD COLUMN prompt_eval_count INTEGER; +ALTER TABLE photo_insights ADD COLUMN eval_count INTEGER; diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs index a7a3720..9fbe6b7 100644 --- a/src/ai/handlers.rs +++ b/src/ai/handlers.rs @@ -1,12 +1,14 @@ use actix_web::{HttpRequest, HttpResponse, Responder, delete, get, post, web}; +use futures::StreamExt; use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, Tracer}; use serde::{Deserialize, Serialize}; use crate::ai::insight_chat::{ChatStreamEvent, ChatTurnRequest}; use crate::ai::ollama::ChatMessage; -use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient}; +use crate::ai::{ModelCapabilities, OllamaClient}; use crate::data::Claims; +use crate::database::models::{InsightGenerationType, InsightJobStatus}; use crate::database::{ExifDao, InsightDao}; use crate::libraries; use crate::otel::{extract_context_from_request, global_tracer}; @@ -64,6 +66,209 @@ pub struct GetPhotoInsightQuery { pub library: Option, } +#[derive(Debug, Deserialize)] +pub struct GenerationStatusQuery { + /// If provided, look up the job by id. + #[serde(default)] + pub job_id: Option, + /// If provided with `library`, look up the latest running job for this + /// file. Used when the client doesn't have a persisted job_id. + #[serde(default)] + pub path: Option, + #[serde(default)] + pub library: Option, +} + +/// GET /insights/generation/status - Check status of a generation job. +/// Accepts either `?job_id=` or `?path=&library=`. +#[get("/insights/generation/status")] +pub async fn generation_status_handler( + _claims: Claims, + query: web::Query, + app_state: web::Data, +) -> impl Responder { + let ctx = opentelemetry::Context::new(); + + if let Some(jid) = query.job_id { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + match dao.get_job_by_id(&ctx, jid) { + Ok(Some(job)) => { + return HttpResponse::Ok().json(GenerationStatusResponse { + job_id: job.id, + status: InsightJobStatus::parse(&job.status), + started_at: job.started_at, + completed_at: job.completed_at, + result_insight_id: job.result_insight_id, + error_message: job.error_message, + }); + } + Ok(None) => { + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("Job {} not found", jid) + })); + } + Err(e) => { + log::error!("Failed to look up job {}: {:?}", jid, e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to look up job" + })); + } + } + } + + if let Some(ref fp) = query.path { + let library = libraries::resolve_library_param(&app_state, query.library.as_deref()) + .ok() + .flatten() + .unwrap_or_else(|| app_state.primary_library()); + let normalized = normalize_path(fp); + + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + match dao.get_active_job(&ctx, library.id, &normalized) { + Ok(Some(job)) => { + return HttpResponse::Ok().json(GenerationStatusResponse { + job_id: job.id, + status: InsightJobStatus::parse(&job.status), + started_at: job.started_at, + completed_at: job.completed_at, + result_insight_id: job.result_insight_id, + error_message: job.error_message, + }); + } + Ok(None) => { + return HttpResponse::Ok().json(serde_json::json!({ + "status": "idle", + "message": "No running generation job for this file" + })); + } + Err(e) => { + log::error!("Failed to look up active job for {}: {:?}", normalized, e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to look up active job" + })); + } + } + } + + HttpResponse::BadRequest().json(serde_json::json!({ + "error": "Provide either job_id or path query parameter" + })) +} + +#[derive(Debug, Deserialize)] +pub struct CancelGenerationRequest { + /// If provided, cancel the specific job by id. + #[serde(default)] + pub job_id: Option, + /// If provided with `library`, cancel all running jobs for this file. + #[serde(default)] + pub file_path: Option, + #[serde(default)] + pub library: Option, +} + +/// POST /insights/generation/cancel - Cancel a running generation job. +/// Accepts either `job_id` or `file_path` + optional `library` in the body. +#[post("/insights/generation/cancel")] +pub async fn cancel_generation_handler( + _claims: Claims, + request: web::Json, + app_state: web::Data, +) -> impl Responder { + let ctx = opentelemetry::Context::new(); + + if let Some(jid) = request.job_id { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + match dao.cancel_job(&ctx, jid) { + Ok(true) => { + let mut handles = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles"); + if let Some(handle) = handles.remove(&jid) { + handle.abort(); + } + return HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": format!("Job {} cancelled", jid) + })); + } + Ok(false) => { + return HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": format!("Job {} was not running", jid) + })); + } + Err(e) => { + log::error!("Failed to cancel job {}: {:?}", jid, e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to cancel job" + })); + } + } + } + + if let Some(ref fp) = request.file_path { + let library = libraries::resolve_library_param(&app_state, request.library.as_deref()) + .ok() + .flatten() + .unwrap_or_else(|| app_state.primary_library()); + let normalized = normalize_path(fp); + + // Get active job ids first, then cancel in DB, then abort tasks + let active_ids: Vec = { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + let ids = dao + .get_active_job(&ctx, library.id, &normalized) + .ok() + .flatten() + .map(|j| vec![j.id]) + .unwrap_or_default(); + let _ = dao.cancel_active_jobs(&ctx, library.id, &normalized); + ids + }; + + if active_ids.is_empty() { + return HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": "No running generation job for this file" + })); + } + + for jid in &active_ids { + if let Some(handle) = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles") + .remove(jid) + { + handle.abort(); + } + } + + return HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": format!("Cancelled {} running job(s) for {}", active_ids.len(), normalized) + })); + } + + HttpResponse::BadRequest().json(serde_json::json!({ + "error": "Provide either job_id or file_path in the request body" + })) +} + #[derive(Debug, Deserialize)] pub struct RateInsightRequest { pub file_path: String, @@ -76,6 +281,24 @@ pub struct ExportTrainingDataQuery { pub approved_only: Option, } +#[derive(Debug, Serialize)] +pub struct JobIdResponse { + pub job_id: i32, +} + +#[derive(Debug, Serialize)] +pub struct GenerationStatusResponse { + pub job_id: i32, + pub status: InsightJobStatus, + pub started_at: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub result_insight_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, +} + #[derive(Debug, Serialize)] pub struct PhotoInsightResponse { pub id: i32, @@ -94,6 +317,20 @@ pub struct PhotoInsightResponse { /// True when the insight was generated agentically and a chat /// continuation can be started against it. Drives the mobile chat button. pub has_training_messages: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_ctx: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub persona_id: Option, } #[derive(Debug, Serialize)] @@ -110,70 +347,182 @@ pub struct ServerModels { pub default_model: String, } -/// POST /insights/generate - Generate insight for a specific photo +/// POST /insights/generate - Generate insight for a specific photo (async) #[post("/insights/generate")] pub async fn generate_insight_handler( http_request: HttpRequest, _claims: Claims, request: web::Json, - insight_generator: web::Data, + app_state: web::Data, ) -> impl Responder { let parent_context = extract_context_from_request(&http_request); let tracer = global_tracer(); let mut span = tracer.start_with_context("http.insights.generate", &parent_context); let normalized_path = normalize_path(&request.file_path); + let library = app_state.primary_library(); + let gen_type = InsightGenerationType::Standard; span.set_attribute(KeyValue::new("file_path", normalized_path.clone())); if let Some(ref model) = request.model { span.set_attribute(KeyValue::new("model", model.clone())); } - if let Some(ref prompt) = request.system_prompt { - span.set_attribute(KeyValue::new("has_custom_prompt", true)); - span.set_attribute(KeyValue::new("prompt_length", prompt.len() as i64)); - } - if let Some(ctx) = request.num_ctx { - span.set_attribute(KeyValue::new("num_ctx", ctx as i64)); - } log::info!( - "Manual insight generation triggered for photo: {} with model: {:?}, custom_prompt: {}, num_ctx: {:?}", + "Manual insight generation triggered for photo: {} with model: {:?}", normalized_path, - request.model, - request.system_prompt.is_some(), - request.num_ctx + request.model ); - // Generate insight with optional custom model, system prompt, and context size - let result = insight_generator - .generate_insight_for_photo_with_config( - &normalized_path, - request.model.clone(), - request.system_prompt.clone(), - request.num_ctx, - request.temperature, - request.top_p, - request.top_k, - request.min_p, - ) - .await; - - match result { - Ok(()) => { - span.set_status(Status::Ok); - HttpResponse::Ok().json(serde_json::json!({ - "success": true, - "message": "Insight generated successfully" - })) - } - Err(e) => { - log::error!("Failed to generate insight: {:?}", e); - span.set_status(Status::error(e.to_string())); - HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to generate insight: {:?}", e) - })) + // Look up and abort any running job for this file, then cancel in DB + let old_job_ids: Vec = { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + let ctx = opentelemetry::Context::new(); + let ids = dao + .get_active_job(&ctx, library.id, &normalized_path) + .ok() + .flatten() + .map(|j| vec![j.id]) + .unwrap_or_default(); + let _ = dao.cancel_active_jobs(&ctx, library.id, &normalized_path); + ids + }; + for jid in &old_job_ids { + if let Some(handle) = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles") + .remove(jid) + { + handle.abort(); } } + + let job_id = { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + match dao.create_job( + &opentelemetry::Context::new(), + library.id, + &normalized_path, + gen_type, + ) { + Ok(id) => id, + Err(e) => { + log::error!("Failed to create generation job: {:?}", e); + span.set_status(Status::error("Failed to create generation job")); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to create generation job" + })); + } + } + }; + + // Spawn background task with timeout + let generator = app_state.insight_generator.clone(); + let job_dao = app_state.insight_job_dao.clone(); + let job_handles = app_state.insight_job_handles.clone(); + let path = normalized_path.clone(); + + let handle = tokio::spawn(async move { + let timeout_secs: u64 = std::env::var("INSIGHT_GENERATION_TIMEOUT_SECS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(120); + + let path_for_task = path.clone(); + let generator_for_task = generator.clone(); + let result = tokio::task::spawn(async move { + tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + generator_for_task.generate_insight_for_photo_with_config( + &path_for_task, + request.model.clone(), + request.system_prompt.clone(), + request.num_ctx, + request.temperature, + request.top_p, + request.top_k, + request.min_p, + ), + ) + .await + }) + .await; + + let ctx = opentelemetry::Context::new(); + let mut dao = job_dao.lock().expect("Unable to lock InsightJobDao"); + + match result { + Ok(Ok(Ok(()))) => { + let mut insight_dao = generator + .insight_dao() + .lock() + .expect("Unable to lock InsightDao"); + let insight_id = insight_dao + .get_insight(&ctx, &path) + .ok() + .flatten() + .map(|i| i.id); + if let Some(id) = insight_id { + if let Err(e) = dao.complete_job(&ctx, job_id, id) { + log::error!("Failed to mark job {} as completed: {:?}", job_id, e); + } + } else if let Err(e) = dao.fail_job(&ctx, job_id, "generation returned no insight") + { + log::error!("Failed to mark job {} as failed: {:?}", job_id, e); + } + } + Ok(Ok(Err(e))) => { + log::error!("Insight generation failed for {}: {:?}", path, e); + if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } + } + Ok(Err(_)) => { + log::error!( + "Insight generation timed out for {} after {}s", + path, + timeout_secs + ); + if let Err(err) = + dao.fail_job(&ctx, job_id, &format!("timeout after {}s", timeout_secs)) + { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } + } + Err(_) => { + log::error!("Insight generation task panicked for {}", path); + if let Err(err) = dao.fail_job(&ctx, job_id, "generation task panicked") { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } + } + } + + // Remove handle from map on completion + let mut handles = job_handles + .lock() + .expect("Unable to lock InsightJobHandles"); + handles.remove(&job_id); + }); + + // Store abort handle + { + let mut handles = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles"); + handles.insert(job_id, handle.abort_handle()); + } + + span.set_attribute(KeyValue::new("job_id", job_id as i64)); + span.set_status(Status::Ok); + HttpResponse::Accepted().json(JobIdResponse { job_id }) } /// GET /insights?path=/path/to/photo.jpg - Fetch insight for specific photo @@ -213,11 +562,18 @@ pub async fn get_insight_handler( summary: insight.summary, generated_at: insight.generated_at, model_version: insight.model_version, - prompt_eval_count: None, - eval_count: None, + prompt_eval_count: insight.prompt_eval_count, + eval_count: insight.eval_count, approved: insight.approved, has_training_messages: insight.training_messages.is_some(), backend: insight.backend, + num_ctx: insight.num_ctx, + temperature: insight.temperature, + top_p: insight.top_p, + top_k: insight.top_k, + min_p: insight.min_p, + system_prompt: insight.system_prompt, + persona_id: insight.persona_id, }; HttpResponse::Ok().json(response) } @@ -282,11 +638,18 @@ pub async fn get_all_insights_handler( summary: insight.summary, generated_at: insight.generated_at, model_version: insight.model_version, - prompt_eval_count: None, - eval_count: None, + prompt_eval_count: insight.prompt_eval_count, + eval_count: insight.eval_count, approved: insight.approved, has_training_messages: insight.training_messages.is_some(), backend: insight.backend, + num_ctx: insight.num_ctx, + temperature: insight.temperature, + top_p: insight.top_p, + top_k: insight.top_k, + min_p: insight.min_p, + system_prompt: insight.system_prompt, + persona_id: insight.persona_id, }) .collect(); @@ -301,56 +664,86 @@ pub async fn get_all_insights_handler( } } -/// POST /insights/generate/agentic - Generate insight using agentic tool-calling loop +/// POST /insights/generate/agentic - Generate insight using agentic tool-calling loop (async) #[post("/insights/generate/agentic")] pub async fn generate_agentic_insight_handler( http_request: HttpRequest, claims: Claims, request: web::Json, - insight_generator: web::Data, - insight_dao: web::Data>>, + app_state: web::Data, ) -> impl Responder { - // Service tokens (sub: "service:apollo") fall through to user_id=1 - // — the operator convention. Mobile/web clients have a numeric sub. - let user_id = claims.sub.parse::().unwrap_or(1); let parent_context = extract_context_from_request(&http_request); let tracer = global_tracer(); let mut span = tracer.start_with_context("http.insights.generate_agentic", &parent_context); let normalized_path = normalize_path(&request.file_path); + let library = app_state.primary_library(); + let gen_type = InsightGenerationType::Agentic; span.set_attribute(KeyValue::new("file_path", normalized_path.clone())); if let Some(ref model) = request.model { span.set_attribute(KeyValue::new("model", model.clone())); } - if let Some(ref prompt) = request.system_prompt { - span.set_attribute(KeyValue::new("has_custom_prompt", true)); - span.set_attribute(KeyValue::new("prompt_length", prompt.len() as i64)); + if let Some(ref backend) = request.backend { + span.set_attribute(KeyValue::new("backend", backend.clone())); } - if let Some(ctx) = request.num_ctx { - span.set_attribute(KeyValue::new("num_ctx", ctx as i64)); - } - - let max_iterations: usize = std::env::var("AGENTIC_MAX_ITERATIONS") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(12); - - span.set_attribute(KeyValue::new("max_iterations", max_iterations as i64)); log::info!( - "Agentic insight generation triggered for photo: {} with model: {:?}, max_iterations: {}", + "Agentic insight generation triggered for photo: {} with model: {:?}", normalized_path, - request.model, - max_iterations + request.model ); - if let Some(ref b) = request.backend { - span.set_attribute(KeyValue::new("backend", b.clone())); + // Look up and abort any running job for this file, then cancel in DB + let old_job_ids: Vec = { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + let ctx = opentelemetry::Context::new(); + let ids = dao + .get_active_job(&ctx, library.id, &normalized_path) + .ok() + .flatten() + .map(|j| vec![j.id]) + .unwrap_or_default(); + let _ = dao.cancel_active_jobs(&ctx, library.id, &normalized_path); + ids + }; + for jid in &old_job_ids { + if let Some(handle) = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles") + .remove(jid) + { + handle.abort(); + } } - // Resolve few-shot ids: request-provided ids take precedence when - // non-empty; otherwise fall back to the hardcoded defaults. + let job_id = { + let mut dao = app_state + .insight_job_dao + .lock() + .expect("Unable to lock InsightJobDao"); + match dao.create_job( + &opentelemetry::Context::new(), + library.id, + &normalized_path, + gen_type, + ) { + Ok(id) => id, + Err(e) => { + log::error!("Failed to create agentic generation job: {:?}", e); + span.set_status(Status::error("Failed to create generation job")); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to create generation job" + })); + } + } + }; + + // Resolve few-shot ids for the background task let fewshot_ids: Vec = match request.fewshot_insight_ids.as_deref() { Some(ids) if !ids.is_empty() => ids.iter().take(2).copied().collect(), _ => DEFAULT_FEWSHOT_INSIGHT_IDS @@ -359,11 +752,14 @@ pub async fn generate_agentic_insight_handler( .copied() .collect(), }; - span.set_attribute(KeyValue::new("fewshot_count", fewshot_ids.len() as i64)); let fewshot_examples: Vec> = { let otel_context = opentelemetry::Context::new(); - let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); + let mut dao = app_state + .insight_chat + .insight_dao() + .lock() + .expect("Unable to lock InsightDao"); fewshot_ids .iter() .filter_map(|id| { @@ -384,90 +780,116 @@ pub async fn generate_agentic_insight_handler( .collect() }; + let user_id = claims.sub.parse::().unwrap_or(1); let persona_id = request .persona_id .clone() .filter(|s| !s.trim().is_empty()) .unwrap_or_else(|| "default".to_string()); - span.set_attribute(KeyValue::new("persona_id", persona_id.clone())); - let result = insight_generator - .generate_agentic_insight_for_photo( - &normalized_path, - request.model.clone(), - request.system_prompt.clone(), - request.num_ctx, - request.temperature, - request.top_p, - request.top_k, - request.min_p, - max_iterations, - request.backend.clone(), - fewshot_examples, - fewshot_ids, - user_id, - persona_id, - ) + let max_iterations: usize = std::env::var("AGENTIC_MAX_ITERATIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(12); + + // Spawn background task with timeout + let generator = app_state.insight_generator.clone(); + let job_dao = app_state.insight_job_dao.clone(); + let job_handles = app_state.insight_job_handles.clone(); + let path = normalized_path.clone(); + + let handle = tokio::spawn(async move { + let timeout_secs: u64 = std::env::var("INSIGHT_GENERATION_TIMEOUT_SECS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(180); + + let path_for_task = path.clone(); + let generator_for_task = generator.clone(); + let result = tokio::task::spawn(async move { + tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + generator_for_task.generate_agentic_insight_for_photo( + &path_for_task, + request.model.clone(), + request.system_prompt.clone(), + request.num_ctx, + request.temperature, + request.top_p, + request.top_k, + request.min_p, + max_iterations, + request.backend.clone(), + fewshot_examples, + fewshot_ids, + user_id, + persona_id, + ), + ) + .await + }) .await; - match result { - Ok((prompt_eval_count, eval_count)) => { - span.set_status(Status::Ok); - // Fetch the stored insight to return it - let otel_context = opentelemetry::Context::new(); - let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); - match dao.get_insight(&otel_context, &normalized_path) { - Ok(Some(insight)) => { - let response = PhotoInsightResponse { - id: insight.id, - file_path: insight.file_path, - title: insight.title, - summary: insight.summary, - generated_at: insight.generated_at, - model_version: insight.model_version, - prompt_eval_count, - eval_count, - approved: insight.approved, - has_training_messages: insight.training_messages.is_some(), - backend: insight.backend, - }; - HttpResponse::Ok().json(response) - } - Ok(None) => HttpResponse::Ok().json(serde_json::json!({ - "success": true, - "message": "Agentic insight generated successfully" - })), - Err(e) => { - log::warn!("Insight stored but failed to retrieve: {:?}", e); - HttpResponse::Ok().json(serde_json::json!({ - "success": true, - "message": "Agentic insight generated successfully" - })) - } - } - } - Err(e) => { - let error_msg = format!("{:?}", e); - log::error!("Failed to generate agentic insight: {}", error_msg); - span.set_status(Status::error(error_msg.clone())); + let ctx = opentelemetry::Context::new(); + let mut dao = job_dao.lock().expect("Unable to lock InsightJobDao"); - if error_msg.contains("tool calling not supported") - || error_msg.contains("model not available") - { - HttpResponse::BadRequest().json(serde_json::json!({ - "error": format!("Failed to generate agentic insight: {}", error_msg) - })) - } else if error_msg.contains("error parsing tool call") { - HttpResponse::BadRequest().json(serde_json::json!({ - "error": "Model is not compatible with Ollama's tool calling protocol. Try a model known to support native tool calling (e.g. llama3.1, llama3.2, qwen2.5, mistral-nemo)." - })) - } else { - HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to generate agentic insight: {}", error_msg) - })) + match result { + Ok(Ok(Ok((Some(insight_id), _, _)))) => { + if let Err(e) = dao.complete_job(&ctx, job_id, insight_id) { + log::error!("Failed to mark job {} as completed: {:?}", job_id, e); + } + } + Ok(Ok(Ok((None, _, _)))) => { + if let Err(e) = dao.fail_job(&ctx, job_id, "agentic generation returned no insight") + { + log::error!("Failed to mark job {} as failed: {:?}", job_id, e); + } + } + Ok(Ok(Err(e))) => { + log::error!("Agentic insight generation failed for {}: {:?}", path, e); + if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } + } + Ok(Err(_)) => { + log::error!( + "Agentic insight generation timed out for {} after {}s", + path, + timeout_secs + ); + if let Err(err) = + dao.fail_job(&ctx, job_id, &format!("timeout after {}s", timeout_secs)) + { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } + } + Err(_) => { + log::error!("Agentic insight generation task panicked for {}", path); + if let Err(err) = dao.fail_job(&ctx, job_id, "generation task panicked") { + log::error!("Failed to mark job {} as failed: {:?}", job_id, err); + } } } + + // Remove handle from map on completion + let mut handles = job_handles + .lock() + .expect("Unable to lock InsightJobHandles"); + handles.remove(&job_id); + }); + + // Store abort handle + { + let mut handles = app_state + .insight_job_handles + .lock() + .expect("Unable to lock InsightJobHandles"); + handles.insert(job_id, handle.abort_handle()); } + + span.set_attribute(KeyValue::new("job_id", job_id as i64)); + span.set_status(Status::Ok); + HttpResponse::Accepted().json(JobIdResponse { job_id }) } /// GET /insights/models - Local-backend models with capabilities. Returns @@ -1012,7 +1434,26 @@ pub async fn chat_stream_handler( } fn render_sse_frame(ev: &ChatStreamEvent) -> String { - let (event_name, payload) = match ev { + let (event_name, payload) = sse_event_payload(ev); + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", event_name, data) +} + +/// Like `render_sse_frame`, but stamps the event's absolute sequence number +/// (`seq`) into the payload so reconnecting replay clients can compute +/// `skip_before` precisely. `seq` is distinct from the tool-pairing `index` +/// already carried by `tool_call`/`tool_result`. +fn render_indexed_frame(ev: &ChatStreamEvent, seq: u32) -> String { + let (event_name, mut payload) = sse_event_payload(ev); + if let serde_json::Value::Object(map) = &mut payload { + map.insert("seq".to_string(), serde_json::json!(seq)); + } + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", event_name, data) +} + +fn sse_event_payload(ev: &ChatStreamEvent) -> (&'static str, serde_json::Value) { + match ev { ChatStreamEvent::IterationStart { n, max } => { ("iteration_start", serde_json::json!({ "n": n, "max": max })) } @@ -1050,6 +1491,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { amended_insight_id, backend_used, model_used, + cancelled, } => ( "done", serde_json::json!({ @@ -1062,6 +1504,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { "amended_insight_id": amended_insight_id, "backend": backend_used, "model": model_used, + "cancelled": cancelled, }), ), // Apollo's frontend SSE consumer (and its free-chat backend, which @@ -1070,7 +1513,491 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { // "no insight found for path") was silently dropped, leaving an // empty assistant bubble with no clue why the turn died. ChatStreamEvent::Error(msg) => ("error_message", serde_json::json!({ "message": msg })), - }; - let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); - format!("event: {}\ndata: {}\n\n", event_name, data) + } +} + +/// POST /insights/chat/turn — async turn dispatch. Returns turn_id immediately, +/// client then polls GET /insights/chat/turn/{turn_id} for SSE replay. +#[post("/insights/chat/turn")] +pub async fn turn_async_handler( + http_request: HttpRequest, + claims: Claims, + request: web::Json, + app_state: web::Data, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("http.insights.chat_turn_async", &parent_context); + span.set_attribute(KeyValue::new("file_path", request.file_path.clone())); + + let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) { + Ok(Some(lib)) => lib, + Ok(None) => app_state.primary_library(), + Err(e) => { + return HttpResponse::BadRequest().json(serde_json::json!({ + "error": format!("invalid library: {}", e) + })); + } + }; + + let user_id = claims.sub.parse::().unwrap_or(1); + + let chat_req = ChatTurnRequest { + library_id: library.id, + user_id, + file_path: request.file_path.clone(), + user_message: request.user_message.clone(), + model: request.model.clone(), + backend: request.backend.clone(), + num_ctx: request.num_ctx, + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + min_p: request.min_p, + max_iterations: request.max_iterations, + system_prompt: request.system_prompt.clone(), + persona_id: request.persona_id.clone(), + amend: request.amend, + regenerate: request.regenerate, + }; + + let service = app_state.insight_chat.clone(); + let registry = app_state.turn_registry.clone(); + + let turn_id = service.chat_turn_async(registry, chat_req).await; + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + span.set_status(Status::Ok); + HttpResponse::Accepted().json(serde_json::json!({ + "turn_id": turn_id, + "status": "running" + })) +} + +/// Query params for the SSE replay stream. +#[derive(Debug, Deserialize)] +pub struct ReplayQuery { + /// Replay events from this absolute sequence number (`seq`) onward. + /// Absent or 0 replays from the beginning. On reconnect the client sends + /// the `seq` of the last event it applied, plus one. + pub skip_before: Option, +} + +/// GET /insights/chat/turn/{turn_id} — SSE replay stream. +#[get("/insights/chat/turn/{turn_id}")] +pub async fn turn_replay_handler( + http_request: HttpRequest, + path: web::Path, + query: web::Query, + app_state: web::Data, +) -> HttpResponse { + use crate::ai::turn_registry::ReplayOutcome; + + let turn_id = path.into_inner(); + let skip_before = query.skip_before.unwrap_or(0); + + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("ai.chat.turn.replay", &parent_context); + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + span.set_attribute(KeyValue::new("skip_before", skip_before as i64)); + + let registry = app_state.turn_registry.clone(); + let entry = match registry.get(&turn_id).await { + Some(e) => e, + None => { + span.set_status(Status::error("turn not found")); + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("turn {} not found", turn_id) + })); + } + }; + + let info = entry.info().await; + span.set_attribute(KeyValue::new("status", info.status.as_str())); + span.set_attribute(KeyValue::new( + "event_count", + info.total_events_pushed as i64, + )); + let turn_info_frame = render_turn_info_frame(&info); + + // Initial buffered batch: events produced before this connection attached. + // Stamp each frame with its absolute `seq` so the client can track + // `skip_before` precisely across reconnects. + let (initial_frames, start_skip) = match entry.replay_from(skip_before).await { + ReplayOutcome::Gone => { + span.set_status(Status::error("buffer evicted")); + return HttpResponse::Gone().json(serde_json::json!({ + "error": "turn history has expired (buffer evicted)" + })); + } + ReplayOutcome::CaughtUp { next_skip } => (Vec::new(), next_skip), + ReplayOutcome::Events { events, next_skip } => { + let frames: Vec = events + .into_iter() + .enumerate() + .map(|(i, ev)| { + actix_web::web::Bytes::from(render_indexed_frame(&ev, skip_before + i as u32)) + }) + .collect(); + (frames, next_skip) + } + }; + + span.set_status(Status::Ok); + let running = entry.is_running(); + + // Head: the `turn_info` event followed by any already-buffered events. + let head = futures::stream::once(async move { + Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(turn_info_frame)) + }) + .chain(futures::stream::iter( + initial_frames.into_iter().map(Ok::<_, actix_web::Error>), + )); + + if !running { + // Completed turn: every event — including the terminal Done/Error — is + // already in the buffered batch above. Emit it and close. + return HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(head); + } + + // In-progress turn: after the head, wait for new events. `next_batch` + // drains every buffered event (including the terminal one) before it + // reports the turn finished, so the final Done/Error is never dropped; + // CaughtUp then closes the stream by returning None. + let tail = futures::stream::unfold( + ( + entry, + start_skip, + Vec::::new(), + false, + ), + |(entry, skip, pending, finished)| async move { + // Flush queued frames from a previous multi-event batch first. + if let Some((first, rest)) = pending.split_first() { + return Some((Ok(first.clone()), (entry, skip, rest.to_vec(), finished))); + } + if finished { + return None; + } + + match entry.next_batch(skip).await { + ReplayOutcome::Events { events, next_skip } => { + let frames: Vec = events + .into_iter() + .enumerate() + .map(|(i, ev)| { + actix_web::web::Bytes::from(render_indexed_frame(&ev, skip + i as u32)) + }) + .collect(); + // next_batch only returns Events for a non-empty batch. + let (first, rest) = frames.split_first().expect("non-empty batch"); + Some((Ok(first.clone()), (entry, next_skip, rest.to_vec(), false))) + } + // Terminal reached and fully drained — close the connection. + ReplayOutcome::CaughtUp { .. } => None, + ReplayOutcome::Gone => { + // Evicted mid-stream: emit one error frame, then close. + let gone = + actix_web::web::Bytes::from(render_sse_frame(&ChatStreamEvent::Error( + "turn history has expired (buffer evicted)".to_string(), + ))); + Some((Ok(gone), (entry, skip, Vec::new(), true))) + } + } + }, + ); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(head.chain(tail)) +} + +fn render_turn_info_frame(info: &crate::ai::turn_registry::TurnInfo) -> String { + let payload = serde_json::json!({ + "turn_id": info.turn_id, + "file_path": info.file_path, + "library_id": info.library_id, + "status": info.status.as_str(), + "total_events_pushed": info.total_events_pushed, + "buffered_count": info.buffered_count, + }); + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: turn_info\ndata: {}\n\n", data) +} + +/// DELETE /insights/chat/turn/{turn_id} — cancel a running turn. +#[delete("/insights/chat/turn/{turn_id}")] +pub async fn cancel_turn_handler( + http_request: HttpRequest, + path: web::Path, + app_state: web::Data, +) -> impl Responder { + let turn_id = path.into_inner(); + + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("ai.chat.turn.cancel", &parent_context); + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + + let registry = app_state.turn_registry.clone(); + let entry = match registry.get(&turn_id).await { + Some(e) => e, + None => { + span.set_status(Status::error("turn not found")); + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("turn {} not found", turn_id) + })); + } + }; + + // Abort the spawned task so it stops producing events promptly. The loop + // also checks `is_running()` at each iteration boundary as a graceful + // backstop in case the abort lands between await points. + let aborted = entry.abort(); + span.set_attribute(KeyValue::new("aborted", aborted)); + + // Push the terminal event BEFORE flipping status: a replay reader treats a + // terminal status with no buffered tail as "closed", so the Done must be + // buffered first for in-progress connections to receive it. + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 0, + truncated: false, + prompt_tokens: None, + eval_tokens: None, + num_ctx: None, + amended_insight_id: None, + backend_used: "cancelled".to_string(), + model_used: "cancelled".to_string(), + cancelled: true, + }) + .await; + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Cancelled); + span.set_status(Status::Ok); + + HttpResponse::Ok().json(serde_json::json!({ + "cancelled": true + })) +} + +#[cfg(test)] +mod turn_replay_tests { + use super::{cancel_turn_handler, render_indexed_frame, turn_replay_handler}; + use crate::ai::insight_chat::ChatStreamEvent; + use crate::ai::turn_registry::{TurnEntry, TurnStatus}; + use crate::state::AppState; + use actix_web::test as actix_test; + use actix_web::{App, web::Data}; + use std::sync::Arc; + + /// Serialize `AppState::test_state()` construction across the parallel + /// tests in this module: each build opens ~10 DAO connections to the one + /// shared `DATABASE_URL` file, and doing several at once races the WAL + /// `journal_mode` switch into a spurious "database is locked". The test + /// bodies themselves still run in parallel; only the open is gated. + static DB_INIT: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + fn build_state() -> Data { + let _guard = DB_INIT.lock().unwrap_or_else(|p| p.into_inner()); + Data::new(AppState::test_state()) + } + + fn done(cancelled: bool) -> ChatStreamEvent { + ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 1, + truncated: false, + prompt_tokens: Some(10), + eval_tokens: Some(20), + num_ctx: None, + amended_insight_id: None, + backend_used: "local".into(), + model_used: "m".into(), + cancelled, + } + } + + /// Seed a completed turn (events + terminal Done) directly in the registry. + async fn seed_completed(state: &AppState, id: &str, text_events: usize) { + let entry = Arc::new(TurnEntry::new(id.into(), "/p.jpg".into(), 1)); + for i in 0..text_events { + entry + .push_event(ChatStreamEvent::TextDelta(format!("d{i}"))) + .await; + } + entry.push_event(done(false)).await; + entry.set_terminal_status(TurnStatus::Done); + state.turn_registry.insert(entry).await; + } + + #[test] + fn indexed_frame_stamps_seq_without_clobbering_tool_index() { + // tool_call carries its own pairing `index`; `seq` must be additive. + let frame = render_indexed_frame( + &ChatStreamEvent::ToolCall { + index: 3, + name: "geo".into(), + arguments: serde_json::json!({}), + }, + 42, + ); + assert!(frame.contains("event: tool_call")); + assert!(frame.contains("\"index\":3")); + assert!(frame.contains("\"seq\":42")); + } + + #[actix_rt::test] + async fn replay_unknown_turn_is_404() { + let state = build_state(); + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/nope") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 404); + } + + #[actix_rt::test] + async fn replay_completed_turn_emits_turn_info_and_done_with_seq() { + let state = build_state(); + seed_completed(&state, "t1", 2).await; + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t1") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + let body = String::from_utf8(actix_test::read_body(resp).await.to_vec()).unwrap(); + assert!(body.contains("event: turn_info")); + assert!(body.contains("event: text")); + assert!(body.contains("event: done")); + // Events are seq-stamped 0,1 (text) and 2 (done). + assert!(body.contains("\"seq\":0")); + assert!(body.contains("\"seq\":2")); + // Done payload carries the renamed token fields the client reads. + assert!(body.contains("\"prompt_tokens\":10")); + } + + #[actix_rt::test] + async fn replay_skip_before_query_skips_applied_events() { + let state = build_state(); + seed_completed(&state, "t2", 3).await; // seqs 0,1,2 text; 3 done + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t2?skip_before=2") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + let body = String::from_utf8(actix_test::read_body(resp).await.to_vec()).unwrap(); + // Only seq 2 (last text) and seq 3 (done) should be present. + assert!(body.contains("\"seq\":2")); + assert!(body.contains("\"seq\":3")); + assert!(!body.contains("\"seq\":0")); + assert!(!body.contains("\"seq\":1")); + } + + #[actix_rt::test] + async fn replay_evicted_index_is_410() { + let state = build_state(); + let entry = Arc::new(TurnEntry::new("t3".into(), "/p.jpg".into(), 1)); + // Push past the cap so the front is evicted and base advances. + for i in 0..600 { + entry + .push_event(ChatStreamEvent::TextDelta(format!("d{i}"))) + .await; + } + entry.set_terminal_status(TurnStatus::Done); + state.turn_registry.insert(entry).await; + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t3?skip_before=0") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 410); + } + + #[actix_rt::test] + async fn cancel_unknown_turn_is_404() { + let state = build_state(); + let app = actix_test::init_service( + App::new() + .service(cancel_turn_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::delete() + .uri("/insights/chat/turn/nope") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 404); + } + + #[actix_rt::test] + async fn cancel_running_turn_marks_cancelled_and_buffers_terminal() { + let state = build_state(); + let entry = Arc::new(TurnEntry::new("t4".into(), "/p.jpg".into(), 1)); + entry + .push_event(ChatStreamEvent::TextDelta("partial".into())) + .await; + state.turn_registry.insert(entry.clone()).await; + + let app = actix_test::init_service( + App::new() + .service(cancel_turn_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::delete() + .uri("/insights/chat/turn/t4") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + // Status flipped to Cancelled and a terminal Done(cancelled) buffered + // after the existing event, so a late replay reader still completes. + assert_eq!( + TurnStatus::from(entry.status.load(std::sync::atomic::Ordering::Relaxed)), + TurnStatus::Cancelled + ); + let info = entry.info().await; + assert_eq!(info.total_events_pushed, 2); + } } diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs index 9837733..f4d478b 100644 --- a/src/ai/insight_chat.rs +++ b/src/ai/insight_chat.rs @@ -9,11 +9,14 @@ use tokio::sync::Mutex as TokioMutex; use crate::ai::backend::{BackendKind, ResolvedBackend, SamplingOverrides}; use crate::ai::insight_generator::InsightGenerator; use crate::ai::llm_client::{ChatMessage, LlmStreamEvent, Tool}; +use crate::ai::turn_registry::TurnEntry; +use crate::ai::turn_registry::TurnRegistry; use crate::database::InsightDao; use crate::database::models::InsertPhotoInsight; use crate::otel::global_tracer; use crate::utils::normalize_path; use futures::stream::{BoxStream, StreamExt}; +use uuid::Uuid; const DEFAULT_MAX_ITERATIONS: usize = 6; const DEFAULT_NUM_CTX: i32 = 8192; @@ -24,6 +27,12 @@ const RESPONSE_HEADROOM_TOKENS: usize = 2048; /// tokenization is model-specific; this avoids carrying tiktoken just for a /// soft bound. const BYTES_PER_TOKEN: usize = 4; +/// Flat token cost charged per inlined image in the truncation budget. A +/// 1024px-longest-edge JPEG (see `load_image_as_base64`) costs vision models on +/// the order of ~1.3K tokens. Crucially, the raw base64 (hundreds of KB of +/// characters) must NOT be counted as text bytes — doing so dwarfs the entire +/// text budget and forces spurious truncation on every turn. +const IMAGE_TOKENS_EACH: usize = 1300; pub type ChatLockMap = Arc>>>>; @@ -107,6 +116,11 @@ impl InsightChatService { } } + /// Accessor for the insight DAO (used by async job completion). + pub fn insight_dao(&self) -> &Arc>> { + &self.insight_dao + } + /// Load the rendered transcript for chat-UI display. Filters internal /// scaffolding (system message, tool turns, tool-dispatch-only assistant /// messages) and drops base64 images from user turns to keep payloads @@ -512,6 +526,15 @@ impl InsightChatService { backend: kind.as_str().to_string(), fewshot_source_ids: None, content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, }; let cx = opentelemetry::Context::new(); let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); @@ -522,8 +545,17 @@ impl InsightChatService { } else { let cx = opentelemetry::Context::new(); let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); - dao.update_training_messages(&cx, req.library_id, &normalized, &json) + let rows = dao + .update_training_messages(&cx, req.library_id, &normalized, &json) .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + if rows == 0 { + log::warn!( + "update_training_messages updated 0 rows for {} (lib {}), \ + concurrent regenerate likely flipped is_current", + normalized, + req.library_id + ); + } } Ok(ChatTurnResult { @@ -590,8 +622,17 @@ impl InsightChatService { let cx = opentelemetry::Context::new(); let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); - dao.update_training_messages(&cx, library_id, &normalized, &json) + let rows = dao + .update_training_messages(&cx, library_id, &normalized, &json) .map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?; + if rows == 0 { + log::warn!( + "update_training_messages (rewind) updated 0 rows for {} (lib {}), \ + concurrent regenerate likely flipped is_current", + normalized, + library_id + ); + } Ok(()) } @@ -646,6 +687,626 @@ impl InsightChatService { Ok(rx) } + /// Async turn dispatch: creates a TurnEntry in the registry, spawns the + /// agentic loop on a Tokio task, and returns the turn_id immediately. + /// Events are buffered in the TurnEntry for SSE replay. + pub async fn chat_turn_async( + self: Arc, + registry: Arc, + req: ChatTurnRequest, + ) -> String { + let turn_id = Uuid::new_v4().to_string(); + let entry = Arc::new(TurnEntry::new( + turn_id.clone(), + req.file_path.clone(), + req.library_id, + )); + registry.insert(entry.clone()).await; + + let svc = self.clone(); + let entry_clone = entry.clone(); + let turn_id_for_span = turn_id.clone(); + let library_id = req.library_id; + let handle = tokio::spawn(async move { + // Span covering the whole spawned turn execution. Created here (not + // in the HTTP handler) because the dispatch span ends at the 202 + // response, long before this work runs. + let tracer = global_tracer(); + let mut span = tracer.start("ai.chat.turn.execute"); + span.set_attribute(KeyValue::new("turn_id", turn_id_for_span)); + span.set_attribute(KeyValue::new("library_id", library_id as i64)); + + let result = svc + .run_streaming_turn_with_entry(req, entry_clone.clone()) + .await; + if let Err(ref e) = result { + span.set_attribute(KeyValue::new("status", "error")); + span.set_status(Status::error(format!("{e}"))); + // Push the terminal event BEFORE flipping status: a replay + // reader treats a terminal status with no buffered tail as + // "closed", so the Error must be in the buffer first. + let _ = entry_clone + .push_event(ChatStreamEvent::Error(format!("{}", e))) + .await; + entry_clone.set_terminal_status(crate::ai::turn_registry::TurnStatus::Error); + } else { + span.set_attribute(KeyValue::new("status", "done")); + span.set_status(Status::Ok); + } + }); + + // Install the abort handle so DELETE can actually stop the task. + entry.set_abort_handle(handle.abort_handle()); + + turn_id + } + + /// Variant of `run_streaming_turn` that pushes events to a `TurnEntry` + /// buffer instead of an `mpsc::Sender`. + async fn run_streaming_turn_with_entry( + self: Arc, + req: ChatTurnRequest, + entry: Arc, + ) -> Result<()> { + if req.user_message.trim().is_empty() { + bail!("user_message must not be empty"); + } + if req.user_message.len() > 8192 { + bail!("user_message exceeds 8192 chars"); + } + let normalized = normalize_path(&req.file_path); + + let lock_key = (req.library_id, normalized.clone()); + let entry_lock = { + let mut locks = self.chat_locks.lock().await; + locks + .entry(lock_key.clone()) + .or_insert_with(|| Arc::new(TokioMutex::new(()))) + .clone() + }; + let _guard = entry_lock.lock().await; + + // Look up existing insight scoped to this turn's library_id. + let existing_insight = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.get_current_insight_for_library(&cx, req.library_id, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + }; + + if req.regenerate || existing_insight.is_none() { + return self + .run_bootstrap_streaming_with_entry(req, normalized, entry) + .await; + } + let insight = existing_insight.expect("just checked Some above"); + self.run_continuation_streaming_with_entry(req, normalized, insight, entry) + .await + } + + /// Continuation path with TurnEntry buffer. + async fn run_continuation_streaming_with_entry( + &self, + req: ChatTurnRequest, + normalized: String, + insight: crate::database::models::PhotoInsight, + entry: Arc, + ) -> Result<()> { + let active_persona = req + .persona_id + .clone() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| "default".to_string()); + let raw_history = insight.training_messages.as_ref().ok_or_else(|| { + anyhow!("insight has no chat history; regenerate this insight in agentic mode") + })?; + let mut messages: Vec = serde_json::from_str(raw_history) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + let stored_backend = insight.backend.clone(); + let effective_backend = req + .backend + .as_deref() + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| stored_backend.clone()); + let kind = BackendKind::parse(&effective_backend)?; + validate_cross_replay(&stored_backend, kind.as_str())?; + + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + + let stored_model = insight.model_version.clone(); + let overrides = SamplingOverrides { + model: req + .model + .clone() + .or_else(|| Some(stored_model.clone())) + .filter(|m| !m.is_empty()), + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + }; + let backend = self.generator.resolve_backend(kind, &overrides).await?; + let model_used = backend.model().to_string(); + + let local_first_user_has_image = messages + .iter() + .find(|m| m.role == "user") + .and_then(|m| m.images.as_ref()) + .map(|imgs| !imgs.is_empty()) + .unwrap_or(false); + let offer_describe_tool = backend.images_inline && local_first_user_has_image; + let gate_opts = self.generator.current_gate_opts_for_persona( + offer_describe_tool, + Some((req.user_id, &active_persona)), + ); + let tools = InsightGenerator::build_tool_definitions(gate_opts); + + let image_base64: Option = if offer_describe_tool { + self.generator.load_image_as_base64(&normalized).ok() + } else { + None + }; + + let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize) + .saturating_sub(RESPONSE_HEADROOM_TOKENS); + let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN); + let truncated = apply_context_budget(&mut messages, budget_bytes); + if truncated { + let _ = entry.push_event(ChatStreamEvent::Truncated).await; + } + + messages.push(ChatMessage::user(req.user_message.clone())); + + let override_stash = + apply_system_prompt_override(&mut messages, req.system_prompt.as_deref()); + let original_system_content = annotate_system_with_budget(&mut messages, max_iterations); + + let outcome = self + .run_streaming_agentic_loop_with_entry( + &backend, + &mut messages, + tools, + &image_base64, + &normalized, + req.user_id, + &active_persona, + max_iterations, + &entry, + ) + .await?; + let AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled, + } = outcome; + + // Turn was cancelled mid-flight: the DELETE handler already pushed the + // terminal event and flipped status. Don't persist a partial turn or + // push a second terminal event. + if cancelled { + return Ok(()); + } + + restore_system_content(&mut messages, original_system_content); + + if !req.amend { + restore_system_prompt_override(&mut messages, override_stash); + } + + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + + let mut amended_insight_id: Option = None; + if req.amend { + let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); + let final_content = body; + + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: final_content.clone(), + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: kind.as_str().to_string(), + fewshot_source_ids: None, + content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, + }; + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let stored = dao + .store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?; + amended_insight_id = Some(stored.id); + } else { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let rows = dao + .update_training_messages(&cx, req.library_id, &normalized, &json) + .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + if rows == 0 { + log::warn!( + "update_training_messages (stream) updated 0 rows for {} (lib {}), \ + concurrent regenerate likely flipped is_current", + normalized, + req.library_id + ); + } + } + + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated, + prompt_tokens: last_prompt_eval_count, + eval_tokens: last_eval_count, + num_ctx: req.num_ctx, + amended_insight_id, + backend_used: kind.as_str().to_string(), + model_used, + cancelled: false, + }) + .await; + + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done); + Ok(()) + } + + /// Bootstrap path with TurnEntry buffer. + async fn run_bootstrap_streaming_with_entry( + &self, + req: ChatTurnRequest, + normalized: String, + entry: Arc, + ) -> Result<()> { + let active_persona = req + .persona_id + .clone() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| "default".to_string()); + let effective_backend = resolve_bootstrap_backend(req.backend.as_deref())?; + let kind = BackendKind::parse(&effective_backend)?; + + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + + let overrides = SamplingOverrides { + model: req.model.clone().filter(|m| !m.is_empty()), + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + }; + let backend = self.generator.resolve_backend(kind, &overrides).await?; + let model_used = backend.model().to_string(); + + let image_base64: Option = self.generator.load_image_as_base64(&normalized).ok(); + + let exif = self.generator.fetch_exif(&normalized); + let date_taken_str = resolve_date_taken_for_context(&exif, &normalized); + let gps = exif + .as_ref() + .and_then(|e| match (e.gps_latitude, e.gps_longitude) { + (Some(lat), Some(lon)) => Some((lat as f64, lon as f64)), + _ => None, + }); + + let visual_block = if !backend.images_inline { + match image_base64.as_deref() { + Some(b64) => match backend.local().describe_image(b64).await { + Ok(desc) => { + format!("Visual description (from local vision model):\n{}\n", desc) + } + Err(e) => { + log::warn!("{} bootstrap: describe_image failed: {}", kind.as_str(), e); + String::new() + } + }, + None => String::new(), + } + } else { + String::new() + }; + + let offer_describe_tool = backend.images_inline && image_base64.is_some(); + let gate_opts = self.generator.current_gate_opts_for_persona( + offer_describe_tool, + Some((req.user_id, &active_persona)), + ); + let tools = InsightGenerator::build_tool_definitions(gate_opts); + + let persona = resolve_bootstrap_system_prompt(req.system_prompt.as_deref()); + let system_content = build_bootstrap_system_message( + &persona, + &normalized, + date_taken_str.as_deref(), + gps, + &visual_block, + ); + let system_msg = ChatMessage::system(system_content); + let mut user_msg = ChatMessage::user(req.user_message.clone()); + if backend.images_inline + && let Some(ref img) = image_base64 + { + user_msg.images = Some(vec![img.clone()]); + } + let mut messages = vec![system_msg, user_msg]; + + let outcome = self + .run_streaming_agentic_loop_with_entry( + &backend, + &mut messages, + tools, + &image_base64, + &normalized, + req.user_id, + &active_persona, + max_iterations, + &entry, + ) + .await?; + let AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled, + } = outcome; + + // Turn was cancelled mid-flight: the DELETE handler already pushed the + // terminal event and flipped status. Don't persist a partial turn or + // push a second terminal event. + if cancelled { + return Ok(()); + } + + let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); + + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: body, + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: kind.as_str().to_string(), + fewshot_source_ids: None, + content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, + }; + let stored = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store bootstrap insight: {:?}", e))? + }; + + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated: false, + prompt_tokens: last_prompt_eval_count, + eval_tokens: last_eval_count, + num_ctx: req.num_ctx, + amended_insight_id: Some(stored.id), + backend_used: kind.as_str().to_string(), + model_used, + cancelled: false, + }) + .await; + + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done); + Ok(()) + } + + /// Agentic loop variant that pushes events to a `TurnEntry` buffer. + async fn run_streaming_agentic_loop_with_entry( + &self, + backend: &ResolvedBackend, + messages: &mut Vec, + tools: Vec, + image_base64: &Option, + normalized: &str, + user_id: i32, + active_persona: &str, + max_iterations: usize, + entry: &Arc, + ) -> Result { + let mut tool_calls_made = 0usize; + let mut iterations_used = 0usize; + let mut last_prompt_eval_count: Option = None; + let mut last_eval_count: Option = None; + let mut final_content = String::new(); + + for iteration in 0..max_iterations { + // Cooperative cancellation: a DELETE flips status out of Running + // (and aborts this task). Check at the iteration boundary so an + // in-flight tool round finishes cleanly rather than mid-write. + if !entry.is_running() { + return Ok(AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled: true, + }); + } + + iterations_used = iteration + 1; + let _ = entry + .push_event(ChatStreamEvent::IterationStart { + n: iterations_used, + max: max_iterations, + }) + .await; + + let mut stream = backend + .chat() + .chat_with_tools_stream(messages.clone(), tools.clone()) + .await?; + + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let mut response = + final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?; + + if let Some(ref mut tcs) = response.tool_calls { + for tc in tcs.iter_mut() { + if !tc.function.arguments.is_object() { + tc.function.arguments = serde_json::Value::Object(Default::default()); + } + } + } + + messages.push(response.clone()); + + if let Some(ref tool_calls) = response.tool_calls + && !tool_calls.is_empty() + { + for tool_call in tool_calls { + tool_calls_made += 1; + let call_index = tool_calls_made - 1; + let _ = entry + .push_event(ChatStreamEvent::ToolCall { + index: call_index, + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }) + .await; + let cx = opentelemetry::Context::new(); + let result = self + .generator + .execute_tool( + &tool_call.function.name, + &tool_call.function.arguments, + backend, + image_base64, + normalized, + user_id, + active_persona, + &cx, + ) + .await; + let (result_preview, result_truncated) = truncate_tool_result(&result); + let _ = entry + .push_event(ChatStreamEvent::ToolResult { + index: call_index, + name: tool_call.function.name.clone(), + result: result_preview, + result_truncated, + }) + .await; + messages.push(ChatMessage::tool_result(result)); + } + continue; + } + + final_content = response.content; + break; + } + + // No-tools fallback + if final_content.is_empty() { + let synthetic_idx = messages.len(); + messages.push(ChatMessage::user( + "Please write your final answer now without calling any more tools.", + )); + let mut stream = backend + .chat() + .chat_with_tools_stream(messages.clone(), vec![]) + .await?; + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let final_response = + final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?; + final_content = final_response.content.clone(); + messages.push(final_response); + messages.remove(synthetic_idx); + } + + Ok(AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled: false, + }) + } + async fn run_streaming_turn( self: Arc, req: ChatTurnRequest, @@ -804,6 +1465,8 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + // The mpsc (legacy) path has no cancellation channel. + cancelled: _, } = outcome; // Drop the per-turn iteration-budget note before persisting so it @@ -841,6 +1504,15 @@ impl InsightChatService { backend: kind.as_str().to_string(), fewshot_source_ids: None, content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, }; let cx = opentelemetry::Context::new(); let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); @@ -851,8 +1523,17 @@ impl InsightChatService { } else { let cx = opentelemetry::Context::new(); let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); - dao.update_training_messages(&cx, req.library_id, &normalized, &json) + let rows = dao + .update_training_messages(&cx, req.library_id, &normalized, &json) .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + if rows == 0 { + log::warn!( + "update_training_messages (stream) updated 0 rows for {} (lib {}), \ + concurrent regenerate likely flipped is_current", + normalized, + req.library_id + ); + } } let _ = tx @@ -866,6 +1547,7 @@ impl InsightChatService { amended_insight_id, backend_used: kind.as_str().to_string(), model_used, + cancelled: false, }) .await; @@ -976,10 +1658,10 @@ impl InsightChatService { ); let system_msg = ChatMessage::system(system_content); let mut user_msg = ChatMessage::user(req.user_message.clone()); - if backend.images_inline { - if let Some(ref img) = image_base64 { - user_msg.images = Some(vec![img.clone()]); - } + if backend.images_inline + && let Some(ref img) = image_base64 + { + user_msg.images = Some(vec![img.clone()]); } let mut messages = vec![system_msg, user_msg]; @@ -1002,6 +1684,8 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + // The mpsc (legacy) path has no cancellation channel. + cancelled: _, } = outcome; let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); @@ -1020,6 +1704,15 @@ impl InsightChatService { backend: kind.as_str().to_string(), fewshot_source_ids: None, content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, }; let stored = { let cx = opentelemetry::Context::new(); @@ -1042,6 +1735,7 @@ impl InsightChatService { amended_insight_id: Some(stored.id), backend_used: kind.as_str().to_string(), model_used, + cancelled: false, }) .await; @@ -1215,6 +1909,7 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + cancelled: false, }) } } @@ -1343,6 +2038,10 @@ struct AgenticLoopOutcome { last_prompt_eval_count: Option, last_eval_count: Option, final_content: String, + /// True when the loop exited early because the turn was cancelled + /// (status flipped out of `Running`). Callers skip persistence and the + /// terminal `Done` push — the cancel handler owns the terminal event. + cancelled: bool, } /// Events emitted by `chat_turn_stream`. One stream per turn; ends after @@ -1397,6 +2096,10 @@ pub enum ChatStreamEvent { amended_insight_id: Option, backend_used: String, model_used: String, + /// True only for the synthetic terminal event emitted by the cancel + /// handler, so clients can distinguish a user-cancelled turn from a + /// natural completion. Always false on the normal success path. + cancelled: bool, }, /// Terminal failure event. No further events follow. Error(String), @@ -1662,10 +2365,32 @@ pub(crate) fn apply_context_budget(messages: &mut Vec, budget_bytes dropped_any } +/// Estimate the serialized byte size of `messages` for the truncation budget, +/// EXCLUDING inlined base64 image payloads. Images are charged a flat +/// `IMAGE_TOKENS_EACH` instead: their base64 is hundreds of KB of characters +/// that have no relation to the text token pressure we're budgeting against, +/// and counting them verbatim makes a single photo exceed the entire budget, +/// spuriously trimming all history on every turn. fn estimate_bytes(messages: &[ChatMessage]) -> usize { - serde_json::to_string(messages) + let mut image_count = 0usize; + // Clone with image payloads stripped so they don't inflate the byte count. + // We still account for the (small) non-image fields verbatim. + let stripped: Vec = messages + .iter() + .map(|m| { + if let Some(imgs) = m.images.as_ref() { + image_count += imgs.len(); + } + ChatMessage { + images: None, + ..m.clone() + } + }) + .collect(); + let text_bytes = serde_json::to_string(&stripped) .map(|s| s.len()) - .unwrap_or(0) + .unwrap_or(0); + text_bytes + image_count * IMAGE_TOKENS_EACH * BYTES_PER_TOKEN } #[cfg(test)] @@ -1725,6 +2450,35 @@ mod tests { assert_eq!(msgs.len(), 2); } + #[test] + fn image_payload_excluded_from_budget() { + // First user message carries a ~400KB base64 image but only a little + // text. Counting the base64 verbatim (old behavior) dwarfs the budget + // and forces all tool history to be dropped on every turn. The image + // must instead be charged a flat per-image cost so a short + // conversation comfortably fits. + let mut user = ChatMessage::user("describe this"); + user.images = Some(vec!["A".repeat(400_000)]); + let mut msgs = vec![ + ChatMessage::system("sys"), + user, + assistant_with_tool_call("get_x"), + ChatMessage::tool_result("small x result"), + assistant_text("here is the answer"), + ]; + + // Default budget: (8192 - 2048) * 4 bytes ≈ 24KB. The text easily fits; + // only the (excluded) image bytes could blow it. + let budget_bytes = (DEFAULT_NUM_CTX as usize - RESPONSE_HEADROOM_TOKENS) * BYTES_PER_TOKEN; + let original_len = msgs.len(); + let dropped = apply_context_budget(&mut msgs, budget_bytes); + + assert!(!dropped, "short conversation with one image must not truncate"); + assert_eq!(msgs.len(), original_len, "no messages should be dropped"); + // Sanity: the flat image charge is accounted for but stays well under budget. + assert!(estimate_bytes(&msgs) <= budget_bytes); + } + #[test] fn truncation_returns_false_with_no_droppable_pairs() { // Only system + user, no tool-call turns to drop. diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs index 8e39c59..4f15ef4 100644 --- a/src/ai/insight_generator.rs +++ b/src/ai/insight_generator.rs @@ -196,6 +196,12 @@ impl InsightGenerator { } } + /// Accessor for the insight DAO (used by async job completion to + /// look up the stored insight id). + pub fn insight_dao(&self) -> &Arc>> { + &self.insight_dao + } + /// Whether the optional Apollo Places integration is wired up. Drives /// tool-definition gating (no point offering `get_personal_place_at` /// when Apollo is unreachable) — exposed publicly so `insight_chat` @@ -1426,6 +1432,15 @@ impl InsightGenerator { backend: "local".to_string(), fewshot_source_ids: None, content_hash: None, + num_ctx, + temperature, + top_p, + top_k, + min_p, + system_prompt: custom_system_prompt.clone(), + persona_id: None, + prompt_eval_count: None, + eval_count: None, }; let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); @@ -3795,7 +3810,7 @@ Return ONLY the summary, nothing else."#, fewshot_source_ids: Vec, user_id: i32, persona_id: String, - ) -> Result<(Option, Option)> { + ) -> Result<(Option, Option, Option)> { let tracer = global_tracer(); let current_cx = opentelemetry::Context::current(); let mut span = tracer.start_with_context("ai.insight.generate_agentic", ¤t_cx); @@ -4026,10 +4041,10 @@ Return ONLY the summary, nothing else."#, // user message; describe-then-inline → text was already injected. let system_msg = ChatMessage::system(system_content); let mut user_msg = ChatMessage::user(user_content); - if backend.images_inline { - if let Some(ref img) = image_base64 { - user_msg.images = Some(vec![img.clone()]); - } + if backend.images_inline + && let Some(ref img) = image_base64 + { + user_msg.images = Some(vec![img.clone()]); } let mut messages = vec![system_msg, user_msg]; @@ -4170,6 +4185,15 @@ Return ONLY the summary, nothing else."#, backend: kind.as_str().to_string(), fewshot_source_ids: fewshot_source_ids_json, content_hash: None, + num_ctx, + temperature, + top_p, + top_k, + min_p, + system_prompt: custom_system_prompt.clone(), + persona_id: Some(persona_id.clone()), + prompt_eval_count: last_prompt_eval_count, + eval_count: last_eval_count, }; let stored = { @@ -4207,7 +4231,11 @@ Return ONLY the summary, nothing else."#, } } - Ok((last_prompt_eval_count, last_eval_count)) + Ok(( + Some(stored_insight.id), + last_prompt_eval_count, + last_eval_count, + )) } /// Reverse geocode GPS coordinates to human-readable place names diff --git a/src/ai/mod.rs b/src/ai/mod.rs index c991c71..e9bec09 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -11,6 +11,7 @@ pub mod llm_client; pub mod ollama; pub mod openrouter; pub mod sms_client; +pub mod turn_registry; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library #[allow(unused_imports)] @@ -19,10 +20,11 @@ pub use daily_summary_job::{ generate_daily_summaries, strip_summary_boilerplate, }; pub use handlers::{ - chat_history_handler, chat_rewind_handler, chat_stream_handler, chat_turn_handler, - delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler, - generate_insight_handler, get_all_insights_handler, get_available_models_handler, - get_insight_handler, get_openrouter_models_handler, rate_insight_handler, + cancel_generation_handler, cancel_turn_handler, chat_history_handler, chat_rewind_handler, + chat_stream_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler, + generate_agentic_insight_handler, generate_insight_handler, generation_status_handler, + get_all_insights_handler, get_available_models_handler, get_insight_handler, + get_openrouter_models_handler, rate_insight_handler, turn_async_handler, turn_replay_handler, }; pub use insight_generator::InsightGenerator; pub use llamacpp::LlamaCppClient; @@ -77,8 +79,9 @@ pub async fn embed_one( .pop() .ok_or_else(|| anyhow::anyhow!("llama-swap returned no embeddings")); } - log::warn!( - "LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured; falling back to Ollama embeddings" + anyhow::bail!( + "LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \ + set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama" ); } ollama.generate_embedding(text).await diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs index c56e1e7..680668f 100644 --- a/src/ai/ollama.rs +++ b/src/ai/ollama.rs @@ -424,10 +424,7 @@ impl OllamaClient { self.generate_with_images(prompt, system, None).await } - /// Variant of `generate` that sets Ollama's top-level `think: false`. - /// Used by latency-sensitive callers like the rerank pass, where the - /// task has nothing to reason about and chain-of-thought tokens are - /// wasted wall time. Server-side no-op on non-reasoning models. + #[allow(dead_code)] pub async fn generate_no_think(&self, prompt: &str, system: Option<&str>) -> Result { self.generate_with_options(prompt, system, None, Some(false)) .await diff --git a/src/ai/turn_registry.rs b/src/ai/turn_registry.rs new file mode 100644 index 0000000..2a5d432 --- /dev/null +++ b/src/ai/turn_registry.rs @@ -0,0 +1,748 @@ +use crate::ai::insight_chat::ChatStreamEvent; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Instant; +use tokio::sync::{Mutex, Notify}; +use tokio::task::AbortHandle; + +/// Maximum number of events buffered per turn. Agentic turns typically +/// produce ~120 events; 500 provides 4× headroom. When exceeded, oldest +/// events are evicted from the front. +const MAX_BUFFERED_EVENTS: usize = 500; + +/// Turn status codes used by `TurnEntry::status`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TurnStatus { + Running = 0, + Done = 1, + Error = 2, + Cancelled = 3, +} + +impl From for TurnStatus { + fn from(v: u32) -> Self { + match v { + 0 => TurnStatus::Running, + 1 => TurnStatus::Done, + 2 => TurnStatus::Error, + 3 => TurnStatus::Cancelled, + _ => TurnStatus::Running, + } + } +} + +impl TurnStatus { + pub fn as_str(&self) -> &'static str { + match self { + TurnStatus::Running => "running", + TurnStatus::Done => "done", + TurnStatus::Error => "error", + TurnStatus::Cancelled => "cancelled", + } + } +} + +/// Shared metadata about a turn, read by the SSE replay handler to emit +/// the initial `turn_info` event and to decide whether to wait for new +/// events or close immediately. +#[derive(Debug, Clone)] +pub struct TurnInfo { + pub turn_id: String, + pub file_path: String, + pub library_id: i32, + pub status: TurnStatus, + pub total_events_pushed: u32, + pub buffered_count: u32, +} + +/// Result of reading events at or after an absolute `skip_before` index. +#[derive(Debug)] +pub enum ReplayOutcome { + /// New events are available. `next_skip` is the absolute index to pass + /// on the next read (i.e. one past the last event returned). + Events { + events: Vec, + next_skip: u32, + }, + /// The reader is caught up to the live edge — no events past `skip_before` + /// yet. `next_skip` is the current high-water mark. + CaughtUp { next_skip: u32 }, + /// `skip_before` points below the buffer's base index: the requested + /// events were evicted. Maps to HTTP 410 Gone. + Gone, +} + +/// Per-turn state shared between the agentic loop (writer) and all SSE +/// replay connections (readers). +pub struct TurnEntry { + pub turn_id: String, + pub file_path: String, + pub library_id: i32, + /// Shared event buffer — multiple SSE connections can read independently. + /// Each connection tracks its own `skip_before` offset. + events: Mutex>, + /// Monotonic counter: total events pushed (may exceed events.len() + /// due to eviction). Used for skip_before indexing. + total_events_pushed: AtomicU32, + /// The event index that this entry started with. Adjusts on eviction + /// so that `skip_before` stays absolute across connections. + base_index: AtomicU32, + pub status: AtomicU32, + /// Abort handle for the spawned agentic task, set once after spawn. + /// Behind a std `Mutex` because the entry is shared via `Arc` and the + /// handle is installed after the entry is already in the registry. + abort_handle: StdMutex>, + pub created_at: Instant, + notify: Arc, +} + +impl TurnEntry { + pub fn new(turn_id: String, file_path: String, library_id: i32) -> Self { + Self { + turn_id, + file_path, + library_id, + events: Mutex::new(Vec::new()), + total_events_pushed: AtomicU32::new(0), + base_index: AtomicU32::new(0), + status: AtomicU32::new(TurnStatus::Running as u32), + abort_handle: StdMutex::new(None), + created_at: Instant::now(), + notify: Arc::new(Notify::new()), + } + } + + /// Install the abort handle for the spawned agentic task. Called once, + /// right after the task is spawned. + pub fn set_abort_handle(&self, handle: AbortHandle) { + *self.abort_handle.lock().expect("abort_handle poisoned") = Some(handle); + } + + /// Abort the spawned agentic task, if a handle was installed. Returns + /// `true` if a task was aborted. + pub fn abort(&self) -> bool { + if let Some(handle) = self + .abort_handle + .lock() + .expect("abort_handle poisoned") + .take() + { + handle.abort(); + true + } else { + false + } + } + + /// Push an event into the buffer. Evicts oldest events if the buffer + /// exceeds `MAX_BUFFERED_EVENTS`. Notifies all waiting SSE connections. + pub async fn push_event(&self, event: ChatStreamEvent) { + { + let mut events = self.events.lock().await; + + // Evict oldest events if we've hit the cap. + if events.len() >= MAX_BUFFERED_EVENTS { + // Drop the oldest event to make room and advance the base + // index so skip_before stays absolute across connections. + events.remove(0); + self.base_index.fetch_add(1, Ordering::Relaxed); + } + + events.push(event); + // Increment while holding the buffer lock so the counter stays in + // lock-step with the buffer even if multiple writers ever exist. + self.total_events_pushed.fetch_add(1, Ordering::Relaxed); + } + + self.notify.notify_waiters(); + } + + /// Get a snapshot of turn metadata for the `turn_info` SSE event. + pub async fn info(&self) -> TurnInfo { + let events = self.events.lock().await; + let buffered = events.len() as u32; + let total = self.total_events_pushed.load(Ordering::Relaxed); + drop(events); + + TurnInfo { + turn_id: self.turn_id.clone(), + file_path: self.file_path.clone(), + library_id: self.library_id, + status: self.status.load(Ordering::Relaxed).into(), + total_events_pushed: total, + buffered_count: buffered, + } + } + + /// Set the terminal status and notify all waiters. + pub fn set_terminal_status(&self, status: TurnStatus) { + self.status.store(status as u32, Ordering::Relaxed); + self.notify.notify_waiters(); + } + + /// Read buffered events at or after absolute index `skip_before` without + /// waiting. Distinguishes "evicted" (Gone) from "caught up" (no new + /// events yet) — the previous boolean/`Option` API conflated the two. + pub async fn replay_from(&self, skip_before: u32) -> ReplayOutcome { + let events = self.events.lock().await; + let base = self.base_index.load(Ordering::Relaxed); + + // The buffer holds absolute indices [base, base + len). A request + // below `base` asked for events that have been evicted. + if skip_before < base { + return ReplayOutcome::Gone; + } + + let offset = (skip_before - base) as usize; + let next_skip = base + events.len() as u32; + if offset >= events.len() { + // Caught up to (or past) the live edge — nothing new yet. + return ReplayOutcome::CaughtUp { next_skip }; + } + + ReplayOutcome::Events { + events: events[offset..].to_vec(), + next_skip, + } + } + + /// Wait for the next batch of events past `skip_before`, the turn to + /// finish, or eviction. Returns: + /// - `Events` when new events are available (drained before any terminal + /// signal so the final `Done`/`Error` is never dropped), + /// - `CaughtUp` only when the turn has reached a terminal status and the + /// reader is fully drained (the caller should close the stream), + /// - `Gone` when `skip_before` points into evicted territory. + pub async fn next_batch(&self, skip_before: u32) -> ReplayOutcome { + loop { + // Register interest BEFORE inspecting state so a push/terminal that + // races between our read and our await can't be lost (Notify's + // `notify_waiters` does not store a permit). + let notified = self.notify.notified(); + tokio::pin!(notified); + notified.as_mut().enable(); + + match self.replay_from(skip_before).await { + ReplayOutcome::CaughtUp { next_skip } => { + // No new events. If the turn is finished, every event + // (including the terminal one) has already been drained + // above on a prior call, so signal the caller to close. + if !self.is_running() { + return ReplayOutcome::CaughtUp { next_skip }; + } + // Still running — wait for the next push or terminal. + } + other => return other, // Events or Gone + } + + notified.await; + } + } + + /// Check if this turn is still running. + pub fn is_running(&self) -> bool { + self.status.load(Ordering::Relaxed) == TurnStatus::Running as u32 + } +} + +/// In-memory registry of all active chat turns. Injected into `AppState` +/// and shared across all handlers. +pub struct TurnRegistry { + entries: Mutex>>, + timeout_secs: u64, +} + +impl TurnRegistry { + pub fn new(timeout_secs: u64) -> Self { + Self { + entries: Mutex::new(HashMap::new()), + timeout_secs, + } + } + + /// Returns the cleanup timeout in seconds. + pub fn timeout_secs(&self) -> u64 { + self.timeout_secs + } + + /// Insert a new turn entry. Returns the turn_id. + pub async fn insert(&self, entry: Arc) -> String { + let turn_id = entry.turn_id.clone(); + let mut entries = self.entries.lock().await; + entries.insert(turn_id.clone(), entry); + turn_id + } + + /// Look up a turn by id. Returns None if not found or expired. + pub async fn get(&self, turn_id: &str) -> Option> { + let entries = self.entries.lock().await; + entries.get(turn_id).cloned() + } + + /// Clean up stale entries older than the timeout. Returns the count of + /// entries removed. + pub async fn cleanup_stale(&self) -> usize { + let mut entries = self.entries.lock().await; + let _now = Instant::now(); + let stale: Vec = entries + .iter() + .filter(|(_, entry)| entry.created_at.elapsed().as_secs() > self.timeout_secs) + .map(|(id, _)| id.clone()) + .collect(); + + for id in &stale { + entries.remove(id); + } + + if !stale.is_empty() { + log::info!( + "TurnRegistry: cleaned up {} stale entries (timeout={}s)", + stale.len(), + self.timeout_secs + ); + } + + stale.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::insight_chat::ChatStreamEvent; + use std::time::Duration; + + /// Unwrap the events from a `ReplayOutcome::Events`, panicking otherwise. + fn events_of(outcome: ReplayOutcome) -> Vec { + match outcome { + ReplayOutcome::Events { events, .. } => events, + other => panic!("expected Events, got {other:?}"), + } + } + + // ── TurnStatus ────────────────────────────────────────────────── + + #[test] + fn turn_status_from_u32_valid_values() { + assert_eq!(TurnStatus::from(0), TurnStatus::Running); + assert_eq!(TurnStatus::from(1), TurnStatus::Done); + assert_eq!(TurnStatus::from(2), TurnStatus::Error); + assert_eq!(TurnStatus::from(3), TurnStatus::Cancelled); + } + + #[test] + fn turn_status_from_u32_unknown_defaults_to_running() { + assert_eq!(TurnStatus::from(4), TurnStatus::Running); + assert_eq!(TurnStatus::from(u32::MAX), TurnStatus::Running); + } + + #[test] + fn turn_status_as_str() { + assert_eq!(TurnStatus::Running.as_str(), "running"); + assert_eq!(TurnStatus::Done.as_str(), "done"); + assert_eq!(TurnStatus::Error.as_str(), "error"); + assert_eq!(TurnStatus::Cancelled.as_str(), "cancelled"); + } + + // ── TurnEntry ─────────────────────────────────────────────────── + + #[tokio::test] + async fn turn_entry_push_and_replay() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + entry + .push_event(ChatStreamEvent::TextDelta("hello".to_string())) + .await; + entry + .push_event(ChatStreamEvent::TextDelta(" world".to_string())) + .await; + + let events = events_of(entry.replay_from(0).await); + assert_eq!(events.len(), 2); + } + + #[tokio::test] + async fn turn_entry_replay_with_skip() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + for i in 0..5 { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // skip_before=0 → all 5 events + let all = events_of(entry.replay_from(0).await); + assert_eq!(all.len(), 5); + + // skip_before=2 → events 2,3,4 (3 events) + let skipped = events_of(entry.replay_from(2).await); + assert_eq!(skipped.len(), 3); + + // skip_before=5 → caught up to the live edge (not Gone). + assert!(matches!( + entry.replay_from(5).await, + ReplayOutcome::CaughtUp { next_skip: 5 } + )); + } + + #[tokio::test] + async fn turn_entry_replay_empty_by_default() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + // Empty buffer with skip_before=0 → caught up (nothing to replay yet). + assert!(matches!( + entry.replay_from(0).await, + ReplayOutcome::CaughtUp { next_skip: 0 } + )); + } + + #[tokio::test] + async fn turn_entry_is_running_initially() { + let entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1); + assert!(entry.is_running()); + } + + #[tokio::test] + async fn turn_entry_set_terminal_status() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + assert!(entry.is_running()); + entry.set_terminal_status(TurnStatus::Done); + assert!(!entry.is_running()); + } + + #[tokio::test] + async fn turn_entry_info() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 42, + )); + + entry + .push_event(ChatStreamEvent::TextDelta("x".to_string())) + .await; + entry.set_terminal_status(TurnStatus::Done); + + let info = entry.info().await; + assert_eq!(info.turn_id, "t1"); + assert_eq!(info.file_path, "/photo.jpg"); + assert_eq!(info.library_id, 42); + assert_eq!(info.status, TurnStatus::Done); + assert_eq!(info.total_events_pushed, 1); + assert_eq!(info.buffered_count, 1); + } + + #[tokio::test] + async fn turn_entry_eviction_caps_buffer() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + // Push MAX_BUFFERED_EVENTS + 10 events. + for i in 0..(MAX_BUFFERED_EVENTS + 10) { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // Asking from absolute 0 after eviction is Gone (0-9 were dropped). + assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone)); + + // Reading from the new base (10) returns the full capped buffer. + let events = events_of(entry.replay_from(10).await); + assert_eq!(events.len(), MAX_BUFFERED_EVENTS); + + // First event should be at index 10 (0-9 were evicted). + if let ChatStreamEvent::TextDelta(s) = &events[0] { + assert_eq!(s, "e10"); + } else { + panic!("expected TextDelta"); + } + + // Last event should be at index MAX_BUFFERED_EVENTS + 9. + if let ChatStreamEvent::TextDelta(s) = &events[events.len() - 1] { + assert_eq!(s, &format!("e{}", MAX_BUFFERED_EVENTS + 9)); + } else { + panic!("expected TextDelta"); + } + } + + #[tokio::test] + async fn turn_entry_replay_evicted_index_is_gone() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + // Push one past the cap so exactly one event (index 0) is evicted. + for i in 0..=MAX_BUFFERED_EVENTS { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // Base is now 1; asking from absolute 0 is evicted territory → Gone. + assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone)); + + // skip_before = MAX_BUFFERED_EVENTS → last event only (index valid). + let last = events_of(entry.replay_from(MAX_BUFFERED_EVENTS as u32).await); + assert_eq!(last.len(), 1); + + // skip_before = MAX_BUFFERED_EVENTS + 1 → caught up to the live edge. + assert!(matches!( + entry.replay_from((MAX_BUFFERED_EVENTS + 1) as u32).await, + ReplayOutcome::CaughtUp { .. } + )); + } + + // ── TurnRegistry ──────────────────────────────────────────────── + + #[tokio::test] + async fn turn_registry_insert_and_get() { + let registry = TurnRegistry::new(300); + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + let id = registry.insert(entry).await; + assert_eq!(id, "t1"); + + let retrieved = registry.get("t1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().turn_id, "t1"); + } + + #[tokio::test] + async fn turn_registry_get_nonexistent_returns_none() { + let registry = TurnRegistry::new(300); + assert!(registry.get("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_removes_old_entries() { + let registry = TurnRegistry::new(0); + let mut entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1); + entry.created_at = Instant::now() - Duration::from_secs(1); + registry.insert(Arc::new(entry)).await; + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 1); + assert!(registry.get("t1").await.is_none()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_preserves_recent() { + let registry = TurnRegistry::new(3600); // 1 hour + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + registry.insert(entry).await; + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 0); + assert!(registry.get("t1").await.is_some()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_multiple() { + let registry = TurnRegistry::new(0); + + for i in 0..5 { + let mut entry = TurnEntry::new(format!("t{i}"), "/photo.jpg".to_string(), 1); + entry.created_at = Instant::now() - Duration::from_secs(1); + registry.insert(Arc::new(entry)).await; + } + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 5); + } + + #[tokio::test] + async fn turn_registry_timeout_secs() { + let registry = TurnRegistry::new(600); + assert_eq!(registry.timeout_secs(), 600); + } + + // ── next_batch / live replay ──────────────────────────────────── + + /// Drain a turn the way the SSE replay handler does: pull batches via + /// `next_batch` until the turn is finished and fully drained. + async fn drain_to_end(entry: Arc) -> Vec { + let mut out = Vec::new(); + let mut skip = 0u32; + while let ReplayOutcome::Events { events, next_skip } = entry.next_batch(skip).await { + out.extend(events); + skip = next_skip; + } + out + } + + fn is_terminal(ev: &ChatStreamEvent) -> bool { + matches!(ev, ChatStreamEvent::Done { .. } | ChatStreamEvent::Error(_)) + } + + /// The core guarantee behind the replay rewrite: a reader waiting on + /// `next_batch` always receives the terminal event, even though the + /// writer flips status to terminal immediately after pushing it. + #[tokio::test] + async fn next_batch_always_delivers_terminal_event() { + for _ in 0..50 { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + let writer = entry.clone(); + let w = tokio::spawn(async move { + writer + .push_event(ChatStreamEvent::IterationStart { n: 1, max: 6 }) + .await; + writer + .push_event(ChatStreamEvent::TextDelta("hi".into())) + .await; + // Push terminal then flip status with no await between — the + // race that previously dropped the Done on the reader side. + writer + .push_event(ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 1, + truncated: false, + prompt_tokens: None, + eval_tokens: None, + num_ctx: None, + amended_insight_id: None, + backend_used: "local".into(), + model_used: "m".into(), + cancelled: false, + }) + .await; + writer.set_terminal_status(TurnStatus::Done); + }); + + let events = drain_to_end(entry).await; + w.await.unwrap(); + + assert!( + events.last().is_some_and(is_terminal), + "terminal event missing; got {} events", + events.len() + ); + assert_eq!(events.len(), 3, "expected IterationStart, TextDelta, Done"); + } + } + + /// A reader that connects before any event is pushed blocks in + /// `next_batch` and then receives events as the writer produces them. + #[tokio::test] + async fn next_batch_waits_for_late_events() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + let writer = entry.clone(); + tokio::spawn(async move { + tokio::task::yield_now().await; + writer + .push_event(ChatStreamEvent::TextDelta("late".into())) + .await; + writer.set_terminal_status(TurnStatus::Done); + }); + + // First call blocks until the writer pushes, rather than returning + // CaughtUp on the empty buffer of a running turn. + match entry.next_batch(0).await { + ReplayOutcome::Events { events, next_skip } => { + assert_eq!(events.len(), 1); + assert_eq!(next_skip, 1); + } + other => panic!("expected Events, got {other:?}"), + } + } + + #[tokio::test] + async fn next_batch_closes_on_terminal_when_caught_up() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + entry + .push_event(ChatStreamEvent::TextDelta("x".into())) + .await; + entry.set_terminal_status(TurnStatus::Done); + + // Caught up (skip past the one buffered event) on a finished turn → + // CaughtUp so the handler closes the stream rather than hanging. + assert!(matches!( + entry.next_batch(1).await, + ReplayOutcome::CaughtUp { .. } + )); + } + + #[tokio::test] + async fn next_batch_reports_gone_for_evicted_index() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + for i in 0..=MAX_BUFFERED_EVENTS { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + // Index 0 was evicted (base advanced to 1). + assert!(matches!(entry.next_batch(0).await, ReplayOutcome::Gone)); + } + + // ── abort handle (#1 cancellation) ────────────────────────────── + + #[tokio::test] + async fn abort_handle_aborts_task_once() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + // No handle installed yet → abort is a no-op. + assert!(!entry.abort()); + + let handle = tokio::spawn(async { + // Long-lived task that only ends via abort. + futures::future::pending::<()>().await; + }); + entry.set_abort_handle(handle.abort_handle()); + + assert!(entry.abort(), "first abort should fire"); + assert!(!entry.abort(), "handle is taken; second abort is a no-op"); + + // The aborted task resolves to a cancellation JoinError. + let join = handle.await; + assert!(join.unwrap_err().is_cancelled()); + } + + #[tokio::test] + async fn base_index_tracks_eviction() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + for i in 0..(MAX_BUFFERED_EVENTS + 5) { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + let info = entry.info().await; + // 5 events evicted; total keeps climbing, buffer stays capped. + assert_eq!(info.total_events_pushed, (MAX_BUFFERED_EVENTS + 5) as u32); + assert_eq!(info.buffered_count, MAX_BUFFERED_EVENTS as u32); + // First live index is 5: reading from there yields the full buffer. + let from_base = events_of(entry.replay_from(5).await); + assert_eq!(from_base.len(), MAX_BUFFERED_EVENTS); + } +} diff --git a/src/bin/probe_clip_search.rs b/src/bin/probe_clip_search.rs index 80d5e7f..c9e652f 100644 --- a/src/bin/probe_clip_search.rs +++ b/src/bin/probe_clip_search.rs @@ -219,7 +219,7 @@ async fn main() -> anyhow::Result<()> { } let sim = dot(&vec, &query_vec); scores.push((sim, rel_path.clone())); - if encoded % 10 == 0 { + if encoded.is_multiple_of(10) { info!( "progress: {} encoded, {:.1}s elapsed", encoded, diff --git a/src/clip_search.rs b/src/clip_search.rs index d91e490..98ea96e 100644 --- a/src/clip_search.rs +++ b/src/clip_search.rs @@ -109,7 +109,7 @@ struct SearchError { /// `None` on malformed bytes — those rows get skipped rather than /// failing the whole query. fn decode_embedding(bytes: &[u8]) -> Option> { - if bytes.is_empty() || bytes.len() % 4 != 0 { + if bytes.is_empty() || !bytes.len().is_multiple_of(4) { return None; } let mut out = Vec::with_capacity(bytes.len() / 4); diff --git a/src/database/calendar_dao.rs b/src/database/calendar_dao.rs index b70a9f6..4ebd21c 100644 --- a/src/database/calendar_dao.rs +++ b/src/database/calendar_dao.rs @@ -274,7 +274,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { source_file: event.source_file, }) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn store_events_batch( @@ -348,7 +348,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { Ok(inserted) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn find_events_in_range( @@ -373,7 +373,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { .map(|rows| rows.into_iter().map(|r| r.to_calendar_event()).collect()) .map_err(|e| anyhow::anyhow!("Query error: {:?}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_similar_events( @@ -429,7 +429,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { Ok(scored_events.into_iter().take(limit).map(|(_, event)| event).collect()) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_relevant_events_hybrid( @@ -500,7 +500,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { Ok(events_in_range.into_iter().take(limit).map(|r| r.to_calendar_event()).collect()) } }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn event_exists( @@ -528,7 +528,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { Ok(result.count > 0) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_event_count(&mut self, context: &opentelemetry::Context) -> Result { @@ -551,6 +551,6 @@ impl CalendarEventDao for SqliteCalendarEventDao { Ok(result.count) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/database/daily_summary_dao.rs b/src/database/daily_summary_dao.rs index ec2a161..521c1a5 100644 --- a/src/database/daily_summary_dao.rs +++ b/src/database/daily_summary_dao.rs @@ -190,7 +190,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { model_version: summary.model_version, }) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn find_similar_summaries( @@ -286,7 +286,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { Ok(top_results) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_similar_summaries_with_time_weight( @@ -408,7 +408,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { Ok(top_results) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn summary_exists( @@ -435,7 +435,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { Ok(count > 0) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_summary_count( @@ -457,7 +457,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { .map(|r| r.count) .map_err(|e| anyhow::anyhow!("Count query error: {:?}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn has_any_summaries(&mut self, context: &opentelemetry::Context) -> Result { @@ -481,7 +481,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { Ok(!rows.is_empty()) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/database/insight_generation_job_dao.rs b/src/database/insight_generation_job_dao.rs new file mode 100644 index 0000000..1ff2eef --- /dev/null +++ b/src/database/insight_generation_job_dao.rs @@ -0,0 +1,681 @@ +use diesel::prelude::*; +use diesel::sqlite::SqliteConnection; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex}; + +use crate::database::models::{ + InsertInsightGenerationJob, InsightGenerationJob, InsightGenerationType, InsightJobStatus, +}; +use crate::database::schema; +use crate::database::{DbError, DbErrorKind, connect}; +use crate::otel::trace_db_call; + +/// Tracks async insight generation jobs. Each call to `create_job` inserts +/// a new row; the application layer prevents concurrent running jobs by +/// cancelling the old one before creating a new one. +pub trait InsightGenerationJobDao: Sync + Send { + /// Insert a new running job. Always creates a new row (no upsert). + /// Cleans up terminal-state rows for the same key first. + fn create_job( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + generation_type: InsightGenerationType, + ) -> Result; + + /// Mark a job as completed with the resulting insight id. Only updates + /// if the job is still in "running" status (prevents overwriting a + /// cancelled job with a late-completing task). + fn complete_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + insight_id: i32, + ) -> Result<(), DbError>; + + /// Mark a job as failed with an error message. Only updates if the job + /// is still in "running" status. + fn fail_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + error_message: &str, + ) -> Result<(), DbError>; + + /// Cancel a specific job by id. Only updates if the job is still + /// in "running" status. Returns true if a row was updated. + fn cancel_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + ) -> Result; + + /// Cancel all running jobs for a given file. Returns the number of + /// jobs cancelled. + fn cancel_active_jobs( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + ) -> Result; + + /// Find the latest running job for a given file. Returns None if no + /// running job exists. + fn get_active_job( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + ) -> Result, DbError>; + + /// Find any job by id regardless of status. + fn get_job_by_id( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + ) -> Result, DbError>; + + /// Mark all jobs still in "running" status as "failed" with a recovery + /// error message. Returns the number of jobs recovered. + fn recover_orphaned_jobs(&mut self, context: &opentelemetry::Context) + -> Result; +} + +pub struct SqliteInsightGenerationJobDao { + connection: Arc>, +} + +impl Default for SqliteInsightGenerationJobDao { + fn default() -> Self { + Self::new() + } +} + +impl SqliteInsightGenerationJobDao { + pub fn new() -> Self { + Self { + connection: Arc::new(Mutex::new(connect())), + } + } + + #[cfg(test)] + pub fn from_connection(conn: Arc>) -> Self { + Self { connection: conn } + } +} + +impl InsightGenerationJobDao for SqliteInsightGenerationJobDao { + fn create_job( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + generation_type: InsightGenerationType, + ) -> Result { + trace_db_call(context, "insert", "create_job", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + let new_job = InsertInsightGenerationJob { + library_id, + path: file_path.to_string(), + gen_type: generation_type.to_string(), + status: InsightJobStatus::Running.to_string(), + started_at: now, + }; + + diesel::insert_into(dsl::insight_generation_jobs) + .values(&new_job) + .execute(connection.deref_mut()) + .map_err(|e| anyhow::anyhow!("Failed to insert job: {}", e))?; + + dsl::insight_generation_jobs + .filter( + dsl::library_id + .eq(library_id) + .and(dsl::file_path.eq(file_path)) + .and(dsl::generation_type.eq(generation_type.as_str())) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ) + .select(dsl::id) + .order(dsl::id.desc()) + .first::(connection.deref_mut()) + .map_err(|e| anyhow::anyhow!("Failed to get job id: {}", e)) + }) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) + } + + fn complete_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + insight_id: i32, + ) -> Result<(), DbError> { + trace_db_call(context, "update", "complete_job", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + // Only update if still running — prevents cancelled job from + // being overwritten by a late-completing task. + diesel::update( + dsl::insight_generation_jobs.filter( + dsl::id + .eq(job_id) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ), + ) + .set(( + dsl::status.eq(InsightJobStatus::Completed.as_str()), + dsl::completed_at.eq(Some(now)), + dsl::result_insight_id.eq(Some(insight_id)), + )) + .execute(connection.deref_mut()) + .map(|_| ()) + .map_err(|e| anyhow::anyhow!("Failed to complete job: {}", e)) + }) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) + } + + fn fail_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + error_message: &str, + ) -> Result<(), DbError> { + trace_db_call(context, "update", "fail_job", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + // Only update if still running. + diesel::update( + dsl::insight_generation_jobs.filter( + dsl::id + .eq(job_id) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ), + ) + .set(( + dsl::status.eq(InsightJobStatus::Failed.as_str()), + dsl::completed_at.eq(Some(now)), + dsl::error_message.eq(Some(error_message.to_string())), + )) + .execute(connection.deref_mut()) + .map(|_| ()) + .map_err(|e| anyhow::anyhow!("Failed to fail job: {}", e)) + }) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) + } + + fn cancel_job( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + ) -> Result { + trace_db_call(context, "update", "cancel_job", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + let rows = diesel::update( + dsl::insight_generation_jobs.filter( + dsl::id + .eq(job_id) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ), + ) + .set(( + dsl::status.eq(InsightJobStatus::Cancelled.as_str()), + dsl::completed_at.eq(Some(now)), + dsl::error_message.eq(Some("cancelled by user".to_string())), + )) + .execute(connection.deref_mut()) + .map_err(|e| anyhow::anyhow!("Failed to cancel job: {}", e))?; + + Ok(rows > 0) + }) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) + } + + fn cancel_active_jobs( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + ) -> Result { + trace_db_call(context, "update", "cancel_active_jobs", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + let rows = diesel::update( + dsl::insight_generation_jobs.filter( + dsl::library_id + .eq(library_id) + .and(dsl::file_path.eq(file_path)) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ), + ) + .set(( + dsl::status.eq(InsightJobStatus::Cancelled.as_str()), + dsl::completed_at.eq(Some(now)), + dsl::error_message.eq(Some("cancelled by newer request".to_string())), + )) + .execute(connection.deref_mut()) + .map_err(|e| anyhow::anyhow!("Failed to cancel active jobs: {}", e))?; + + Ok(rows) + }) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) + } + + fn get_active_job( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + ) -> Result, DbError> { + trace_db_call(context, "query", "get_active_job", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + dsl::insight_generation_jobs + .filter( + dsl::library_id + .eq(library_id) + .and(dsl::file_path.eq(file_path)) + .and(dsl::status.eq(InsightJobStatus::Running.as_str())), + ) + .order(dsl::id.desc()) + .first::(connection.deref_mut()) + .optional() + .map_err(|e| anyhow::anyhow!("Failed to get active job: {}", e)) + }) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) + } + + fn get_job_by_id( + &mut self, + context: &opentelemetry::Context, + job_id: i32, + ) -> Result, DbError> { + trace_db_call(context, "query", "get_job_by_id", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + dsl::insight_generation_jobs + .filter(dsl::id.eq(job_id)) + .first::(connection.deref_mut()) + .optional() + .map_err(|e| anyhow::anyhow!("Failed to get job: {}", e)) + }) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) + } + + fn recover_orphaned_jobs( + &mut self, + context: &opentelemetry::Context, + ) -> Result { + trace_db_call(context, "update", "recover_orphaned_jobs", |_span| { + use schema::insight_generation_jobs::dsl; + + let mut connection = self + .connection + .lock() + .expect("Unable to lock InsightGenerationJobDao"); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + let rows = diesel::update( + dsl::insight_generation_jobs + .filter(dsl::status.eq(InsightJobStatus::Running.as_str())), + ) + .set(( + dsl::status.eq(InsightJobStatus::Failed.as_str()), + dsl::completed_at.eq(Some(now)), + dsl::error_message.eq(Some("server crashed while running".to_string())), + )) + .execute(connection.deref_mut()) + .map_err(|e| anyhow::anyhow!("Failed to recover orphaned jobs: {}", e))?; + + Ok(rows) + }) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use diesel::Connection; + use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; + + const DB_MIGRATIONS: EmbeddedMigrations = embed_migrations!(); + + fn setup_dao() -> SqliteInsightGenerationJobDao { + let mut conn = SqliteConnection::establish(":memory:") + .expect("Unable to create in-memory db connection"); + conn.run_pending_migrations(DB_MIGRATIONS) + .expect("Failure running DB migrations"); + SqliteInsightGenerationJobDao::from_connection(Arc::new(Mutex::new(conn))) + } + + fn ctx() -> opentelemetry::Context { + opentelemetry::Context::new() + } + + #[test] + fn create_job_inserts_new_row() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id_1 = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + let job_id_2 = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + assert_ne!(job_id_1, job_id_2, "each create_job call inserts a new row"); + } + + #[test] + fn complete_job_sets_result() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + dao.complete_job(&ctx, job_id, 42).unwrap(); + + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Completed.as_str()); + assert_eq!(job.result_insight_id, Some(42)); + assert!(job.completed_at.is_some()); + } + + #[test] + fn fail_job_sets_error() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Agentic) + .unwrap(); + + dao.fail_job(&ctx, job_id, "model timeout").unwrap(); + + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Failed.as_str()); + assert_eq!(job.error_message.as_deref(), Some("model timeout")); + assert!(job.completed_at.is_some()); + } + + #[test] + fn get_active_job_returns_none_when_completed() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + // Job is running + let active = dao.get_active_job(&ctx, 1, "photos/test.jpg").unwrap(); + assert!(active.is_some()); + assert_eq!(active.unwrap().id, job_id); + + // Complete it + dao.complete_job(&ctx, job_id, 1).unwrap(); + + // No longer active + let active = dao.get_active_job(&ctx, 1, "photos/test.jpg").unwrap(); + assert!(active.is_none()); + } + + #[test] + fn cancel_active_jobs() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + let cancelled = dao.cancel_active_jobs(&ctx, 1, "photos/test.jpg").unwrap(); + assert_eq!(cancelled, 1, "should cancel 1 running job"); + + // Job is no longer active + let active = dao.get_active_job(&ctx, 1, "photos/test.jpg").unwrap(); + assert!(active.is_none()); + + // Job exists with cancelled status + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Cancelled.as_str()); + + // Cancelling again returns 0 (nothing to cancel) + let cancelled2 = dao.cancel_active_jobs(&ctx, 1, "photos/test.jpg").unwrap(); + assert_eq!(cancelled2, 0, "should return 0 when no running job"); + } + + #[test] + fn get_active_job_scoped_by_library() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id_1 = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + let job_id_2 = dao + .create_job(&ctx, 2, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + assert_ne!( + job_id_1, job_id_2, + "different libraries should have separate jobs" + ); + + // Complete lib1's job + dao.complete_job(&ctx, job_id_1, 1).unwrap(); + + // lib1 has no active job + let active1 = dao.get_active_job(&ctx, 1, "photos/test.jpg").unwrap(); + assert!(active1.is_none()); + + // lib2 still has active job + let active2 = dao.get_active_job(&ctx, 2, "photos/test.jpg").unwrap(); + assert!(active2.is_some()); + assert_eq!(active2.unwrap().id, job_id_2); + } + + #[test] + fn get_job_by_id_finds_any_status() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + // Find while running + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Running.as_str()); + + // Complete it + dao.complete_job(&ctx, job_id, 99).unwrap(); + + // Still findable + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Completed.as_str()); + assert_eq!(job.result_insight_id, Some(99)); + } + + #[test] + fn recover_orphaned_jobs() { + let mut dao = setup_dao(); + let ctx = ctx(); + + // Create two running jobs + let job_id_1 = dao + .create_job(&ctx, 1, "photos/a.jpg", InsightGenerationType::Standard) + .unwrap(); + let job_id_2 = dao + .create_job(&ctx, 1, "photos/b.jpg", InsightGenerationType::Agentic) + .unwrap(); + + // Complete one + dao.complete_job(&ctx, job_id_1, 1).unwrap(); + + // Recover should only affect the running job + let recovered = dao.recover_orphaned_jobs(&ctx).unwrap(); + assert_eq!(recovered, 1, "should recover exactly 1 running job"); + + // job_id_1 is still completed + let job1 = dao.get_job_by_id(&ctx, job_id_1).unwrap().unwrap(); + assert_eq!(job1.status, InsightJobStatus::Completed.as_str()); + + // job_id_2 is now failed with recovery message + let job2 = dao.get_job_by_id(&ctx, job_id_2).unwrap().unwrap(); + assert_eq!(job2.status, InsightJobStatus::Failed.as_str()); + assert_eq!( + job2.error_message.as_deref(), + Some("server crashed while running") + ); + + // Second recovery is a no-op + let recovered2 = dao.recover_orphaned_jobs(&ctx).unwrap(); + assert_eq!(recovered2, 0, "no running jobs remain"); + } + + #[test] + fn complete_job_noop_when_cancelled() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + dao.cancel_job(&ctx, job_id).unwrap(); + + // Late-completing task tries to mark as completed — should be a no-op + dao.complete_job(&ctx, job_id, 42).unwrap(); + + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!( + job.status, + InsightJobStatus::Cancelled.as_str(), + "cancelled status must not be overwritten by late complete" + ); + assert_eq!( + job.result_insight_id, None, + "insight_id must stay None when complete is a no-op" + ); + } + + #[test] + fn fail_job_noop_when_cancelled() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Agentic) + .unwrap(); + + dao.cancel_job(&ctx, job_id).unwrap(); + + // Late-failing task tries to mark as failed — should be a no-op + dao.fail_job(&ctx, job_id, "timeout after 120s").unwrap(); + + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!( + job.status, + InsightJobStatus::Cancelled.as_str(), + "cancelled status must not be overwritten by late fail" + ); + assert_eq!( + job.error_message.as_deref(), + Some("cancelled by user"), + "error_message must reflect the cancel, not the late fail" + ); + } + + #[test] + fn cancel_job_by_id() { + let mut dao = setup_dao(); + let ctx = ctx(); + + let job_id = dao + .create_job(&ctx, 1, "photos/test.jpg", InsightGenerationType::Standard) + .unwrap(); + + let cancelled = dao.cancel_job(&ctx, job_id).unwrap(); + assert!(cancelled, "should cancel running job"); + + let job = dao.get_job_by_id(&ctx, job_id).unwrap().unwrap(); + assert_eq!(job.status, InsightJobStatus::Cancelled.as_str()); + assert!(job.completed_at.is_some()); + + // Cancelling again is a no-op + let cancelled2 = dao.cancel_job(&ctx, job_id).unwrap(); + assert!(!cancelled2, "already cancelled job should return false"); + } +} diff --git a/src/database/insights_dao.rs b/src/database/insights_dao.rs index 86c51aa..6b467ea 100644 --- a/src/database/insights_dao.rs +++ b/src/database/insights_dao.rs @@ -90,13 +90,15 @@ pub trait InsightDao: Sync + Send { /// Replace the `training_messages` JSON blob on the current row for /// `(library_id, rel_path)`. Used by chat-turn append mode to persist /// the extended conversation without inserting a new insight version. + /// Returns the number of rows affected (0 if no current row matched, + /// indicating a concurrent regenerate/reconcile flipped `is_current`). fn update_training_messages( &mut self, context: &opentelemetry::Context, library_id: i32, file_path: &str, training_messages_json: &str, - ) -> Result<(), DbError>; + ) -> Result; } pub struct SqliteInsightDao { @@ -159,13 +161,13 @@ impl InsightDao for SqliteInsightDao { ) .set(is_current.eq(false)) .execute(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Update is_current error"))?; + .map_err(|e| anyhow::anyhow!("Failed to flip is_current: {}", e))?; // Insert the new insight as current diesel::insert_into(photo_insights) .values(&insight) .execute(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Insert error"))?; + .map_err(|e| anyhow::anyhow!("Failed to insert insight: {}", e))?; // Retrieve the inserted record (is_current = true) photo_insights @@ -173,9 +175,12 @@ impl InsightDao for SqliteInsightDao { .filter(rel_path.eq(&insight.file_path)) .filter(is_current.eq(true)) .first::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Failed to retrieve inserted insight: {}", e)) + }) + .map_err(|e| { + log::error!("store_insight failed: {}", e); + DbError::new(DbErrorKind::InsertError) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) } fn get_insight( @@ -193,9 +198,9 @@ impl InsightDao for SqliteInsightDao { .filter(is_current.eq(true)) .first::(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_current_insight_for_library( @@ -219,10 +224,10 @@ impl InsightDao for SqliteInsightDao { .filter(is_current.eq(true)) .first::(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_insight_for_paths( @@ -244,9 +249,9 @@ impl InsightDao for SqliteInsightDao { .order(generated_at.desc()) .first::(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_insight_history( @@ -263,9 +268,9 @@ impl InsightDao for SqliteInsightDao { .filter(rel_path.eq(path)) .order(generated_at.desc()) .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_insight_by_id( @@ -282,9 +287,9 @@ impl InsightDao for SqliteInsightDao { .find(insight_id) .first::(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn delete_insight( @@ -300,9 +305,9 @@ impl InsightDao for SqliteInsightDao { diesel::delete(photo_insights.filter(rel_path.eq(path))) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Delete error")) + .map_err(|e| anyhow::anyhow!("Delete error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_all_insights( @@ -318,9 +323,9 @@ impl InsightDao for SqliteInsightDao { .filter(is_current.eq(true)) .order(generated_at.desc()) .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn rate_insight( @@ -342,9 +347,9 @@ impl InsightDao for SqliteInsightDao { .set(approved.eq(Some(is_approved))) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn get_approved_insights( @@ -361,9 +366,9 @@ impl InsightDao for SqliteInsightDao { .filter(training_messages.is_not_null()) .order(generated_at.desc()) .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_training_messages( @@ -372,7 +377,7 @@ impl InsightDao for SqliteInsightDao { lib_id: i32, path: &str, training_messages_json: &str, - ) -> Result<(), DbError> { + ) -> Result { trace_db_call(context, "update", "update_training_messages", |_span| { use schema::photo_insights::dsl::*; @@ -386,9 +391,8 @@ impl InsightDao for SqliteInsightDao { ) .set(training_messages.eq(Some(training_messages_json.to_string()))) .execute(connection.deref_mut()) - .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } } diff --git a/src/database/knowledge_dao.rs b/src/database/knowledge_dao.rs index 06b2b2d..9eddb26 100644 --- a/src/database/knowledge_dao.rs +++ b/src/database/knowledge_dao.rs @@ -582,7 +582,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map_err(|e| anyhow::anyhow!("Query error: {}", e)) } }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn get_entity_by_id( @@ -599,7 +599,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .optional() .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_entity_by_name( @@ -624,7 +624,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_entities_with_embeddings( @@ -649,7 +649,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_entities( @@ -706,7 +706,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { Ok((results, total)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_entities_with_fact_counts( @@ -894,7 +894,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { Ok((pairs, total)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_predicate_stats( @@ -957,7 +957,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map_err(|e| anyhow::anyhow!("Query error: {}", e))?; Ok(rows.into_iter().map(|r| (r.predicate, r.cnt)).collect()) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn bulk_reject_facts_by_predicate( @@ -1016,7 +1016,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { }; Ok(touched) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn build_entity_graph( @@ -1194,7 +1194,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { Ok(EntityGraph { nodes, edges }) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_consolidation_proposals( @@ -1349,7 +1349,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { result.truncate(max_groups); Ok(result) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_persona_breakdowns_for_entities( @@ -1411,7 +1411,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { } Ok(out) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_entity_status( @@ -1429,7 +1429,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn update_entity( @@ -1475,7 +1475,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .optional() .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn delete_entity( @@ -1565,7 +1565,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { }) .map_err(|e| anyhow::anyhow!("Merge transaction error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } // ----------------------------------------------------------------------- @@ -1636,7 +1636,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { Ok((inserted, true)) // true = newly created } }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn get_facts_for_entity( @@ -1662,7 +1662,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { q.load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_facts( @@ -1719,7 +1719,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { Ok((results, total)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_fact( @@ -1801,7 +1801,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .optional() .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn update_facts_insight_id( @@ -1823,7 +1823,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn delete_fact(&mut self, cx: &opentelemetry::Context, fact_id: i32) -> Result<(), DbError> { @@ -2015,7 +2015,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Insert error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn delete_photo_links_for_file( @@ -2031,7 +2031,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Delete error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_links_for_photo( @@ -2047,7 +2047,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_links_for_entity( @@ -2063,7 +2063,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { .load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } // ----------------------------------------------------------------------- @@ -2111,7 +2111,7 @@ impl KnowledgeDao for SqliteKnowledgeDao { facts: recent_facts, }) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/database/location_dao.rs b/src/database/location_dao.rs index 95f5d8f..8bb0ac4 100644 --- a/src/database/location_dao.rs +++ b/src/database/location_dao.rs @@ -273,7 +273,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { source_file: location.source_file, }) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn store_locations_batch( @@ -350,7 +350,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { Ok(inserted) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn find_nearest_location( @@ -385,7 +385,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { Ok(results.into_iter().next().map(|r| r.to_location_record())) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_locations_in_range( @@ -413,7 +413,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { .map(|rows| rows.into_iter().map(|r| r.to_location_record()).collect()) .map_err(|e| anyhow::anyhow!("Query error: {:?}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_locations_near_point( @@ -468,7 +468,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { Ok(filtered) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn location_exists( @@ -502,7 +502,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { Ok(result.count > 0) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_location_count(&mut self, context: &opentelemetry::Context) -> Result { @@ -525,6 +525,6 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { Ok(result.count) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 4488a00..d063bd0 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -45,6 +45,7 @@ pub struct DuplicateRow { pub mod calendar_dao; pub mod daily_summary_dao; +pub mod insight_generation_job_dao; pub mod insights_dao; pub mod knowledge_dao; pub mod location_dao; @@ -57,6 +58,7 @@ pub mod search_dao; pub use calendar_dao::{CalendarEventDao, SqliteCalendarEventDao}; pub use daily_summary_dao::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao}; +pub use insight_generation_job_dao::{InsightGenerationJobDao, SqliteInsightGenerationJobDao}; pub use insights_dao::{InsightDao, SqliteInsightDao}; pub use knowledge_dao::{ ConsolidationGroup, EntityFilter, EntityGraph, EntityPatch, EntitySort, FactFilter, FactPatch, @@ -191,14 +193,26 @@ pub fn connect() -> SqliteConnection { conn } -#[derive(Debug)] pub struct DbError { pub kind: DbErrorKind, + pub source: Option, } impl DbError { fn new(kind: DbErrorKind) -> Self { - DbError { kind } + DbError { kind, source: None } + } + + /// Capture the source error message AND log it. Callers should use + /// this from `map_err` closures so the underlying Diesel/SQLite + /// error survives the conversion to `DbError`. + fn log(kind: DbErrorKind, source: impl std::fmt::Display) -> Self { + let msg = source.to_string(); + log::error!("DB {:?}: {}", kind, msg); + DbError { + kind, + source: Some(msg), + } } fn exists() -> Self { @@ -206,6 +220,26 @@ impl DbError { } } +impl std::fmt::Debug for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.source { + Some(s) => write!(f, "DbError {{ kind: {:?}, source: {} }}", self.kind, s), + None => write!(f, "DbError {{ kind: {:?} }}", self.kind), + } + } +} + +impl std::fmt::Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.source { + Some(s) => write!(f, "{:?}: {}", self.kind, s), + None => write!(f, "{:?}", self.kind), + } + } +} + +impl std::error::Error for DbError {} + #[derive(Debug, PartialEq)] pub enum DbErrorKind { AlreadyExists, @@ -260,7 +294,7 @@ impl FavoriteDao for SqliteFavoriteDao { path: favorite_path, }) .execute(connection.deref_mut()) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } else { Err(DbError::exists()) } @@ -281,7 +315,7 @@ impl FavoriteDao for SqliteFavoriteDao { favorites .filter(userid.eq(user_id)) .load::(self.connection.lock().unwrap().deref_mut()) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_path(&mut self, old_path: &str, new_path: &str) -> Result<(), DbError> { @@ -290,7 +324,7 @@ impl FavoriteDao for SqliteFavoriteDao { diesel::update(favorites.filter(rel_path.eq(old_path))) .set(rel_path.eq(new_path)) .execute(self.connection.lock().unwrap().deref_mut()) - .map_err(|_| DbError::new(DbErrorKind::UpdateError))?; + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e))?; Ok(()) } @@ -301,7 +335,7 @@ impl FavoriteDao for SqliteFavoriteDao { .select(rel_path) .distinct() .load(self.connection.lock().unwrap().deref_mut()) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } @@ -921,7 +955,7 @@ impl ExifDao for SqliteExifDao { .first::(connection.deref_mut()) .map_err(|e| anyhow::anyhow!("Post-insert lookup failed: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn get_exif( @@ -948,7 +982,7 @@ impl ExifDao for SqliteExifDao { Err(_) => Err(anyhow::anyhow!("Query error")), } }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_exif( @@ -985,15 +1019,15 @@ impl ExifDao for SqliteExifDao { last_modified.eq(&exif_data.last_modified), )) .execute(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Update error"))?; + .map_err(|e| anyhow::anyhow!("Update error: {}", e))?; image_exif .filter(library_id.eq(exif_data.library_id)) .filter(rel_path.eq(&exif_data.file_path)) .first::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn delete_exif(&mut self, context: &opentelemetry::Context, path: &str) -> Result<(), DbError> { @@ -1003,9 +1037,9 @@ impl ExifDao for SqliteExifDao { diesel::delete(image_exif.filter(rel_path.eq(path))) .execute(self.connection.lock().unwrap().deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Delete error")) + .map_err(|e| anyhow::anyhow!("Delete error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_all_with_date_taken( @@ -1036,9 +1070,9 @@ impl ExifDao for SqliteExifDao { .filter_map(|(path, dt)| dt.map(|ts| (path, ts))) .collect() }) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_exif_batch( @@ -1062,9 +1096,9 @@ impl ExifDao for SqliteExifDao { query .filter(rel_path.eq_any(file_paths)) .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn query_by_exif( @@ -1123,9 +1157,9 @@ impl ExifDao for SqliteExifDao { query .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_camera_makes( @@ -1150,9 +1184,9 @@ impl ExifDao for SqliteExifDao { .filter_map(|(make, cnt)| make.map(|m| (m, cnt))) .collect() }) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn update_file_path( @@ -1169,10 +1203,10 @@ impl ExifDao for SqliteExifDao { diesel::update(image_exif.filter(rel_path.eq(old_path))) .set(rel_path.eq(new_path)) .execute(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Update error"))?; + .map_err(|e| anyhow::anyhow!("Update error: {}", e))?; Ok(()) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn get_all_file_paths( @@ -1187,9 +1221,9 @@ impl ExifDao for SqliteExifDao { image_exif .select(rel_path) .load(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_all_with_gps( @@ -1257,7 +1291,7 @@ impl ExifDao for SqliteExifDao { Ok(filtered) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rows_missing_hash( @@ -1276,9 +1310,9 @@ impl ExifDao for SqliteExifDao { .order(id.asc()) .limit(limit) .load::<(i32, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn backfill_content_hash( @@ -1302,9 +1336,9 @@ impl ExifDao for SqliteExifDao { .set((content_hash.eq(hash), size_bytes.eq(size_val))) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn list_distinct_content_hashes( @@ -1322,9 +1356,9 @@ impl ExifDao for SqliteExifDao { .distinct() .load::>(connection.deref_mut()) .map(|rows| rows.into_iter().flatten().collect()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_paths_and_hashes_for_library( @@ -1345,10 +1379,10 @@ impl ExifDao for SqliteExifDao { .filter(library_id.eq(lib_id)) .select((rel_path, content_hash)) .load::<(String, Option)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rows_needing_date_backfill( @@ -1375,10 +1409,10 @@ impl ExifDao for SqliteExifDao { .order(id.asc()) .limit(limit) .load::<(i32, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn backfill_date_taken( @@ -1469,10 +1503,10 @@ impl ExifDao for SqliteExifDao { .order(id.asc()) .limit(limit) .load::<(String, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn backfill_clip_embedding( @@ -1565,7 +1599,7 @@ impl ExifDao for SqliteExifDao { )) .order(id.asc()) .load::<(String, Vec)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))?; + .map_err(|e| anyhow::anyhow!("Query error: {}", e))?; // Dedupe by hash, keeping the first occurrence. Cheap; sized // to ~14k entries on this library. @@ -1579,7 +1613,7 @@ impl ExifDao for SqliteExifDao { } Ok(out) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn set_manual_date_taken( @@ -1739,7 +1773,7 @@ impl ExifDao for SqliteExifDao { }) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_by_content_hash( @@ -1756,9 +1790,9 @@ impl ExifDao for SqliteExifDao { .filter(content_hash.eq(hash)) .first::(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rel_paths_sharing_content( @@ -1781,7 +1815,7 @@ impl ExifDao for SqliteExifDao { .select(content_hash) .first::>(connection.deref_mut()) .optional() - .map_err(|_| anyhow::anyhow!("Query error"))? + .map_err(|e| anyhow::anyhow!("Query error: {}", e))? .flatten(); let paths = match hash { @@ -1790,13 +1824,13 @@ impl ExifDao for SqliteExifDao { .select(rel_path) .distinct() .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))?, + .map_err(|e| anyhow::anyhow!("Query error: {}", e))?, None => vec![rel_path_val.to_string()], }; Ok(paths) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rel_paths_for_library( @@ -1813,9 +1847,9 @@ impl ExifDao for SqliteExifDao { .filter(library_id.eq(library_id_val)) .select(rel_path) .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_content_hash_anywhere( @@ -1835,9 +1869,9 @@ impl ExifDao for SqliteExifDao { .first::>(connection.deref_mut()) .optional() .map(|opt| opt.flatten()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rel_paths_by_hash( @@ -1855,9 +1889,9 @@ impl ExifDao for SqliteExifDao { .select(rel_path) .distinct() .load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rel_paths_for_hashes( @@ -1884,14 +1918,14 @@ impl ExifDao for SqliteExifDao { .select((content_hash.assume_not_null(), rel_path)) .distinct() .load::<(String, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))?; + .map_err(|e| anyhow::anyhow!("Query error: {}", e))?; for (hash, path) in rows { out.entry(hash).or_default().push(path); } } Ok(out) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_rel_paths_for_libraries( @@ -1957,9 +1991,9 @@ impl ExifDao for SqliteExifDao { query .load::<(i32, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn delete_exif_by_library( @@ -1978,9 +2012,9 @@ impl ExifDao for SqliteExifDao { ) .execute(self.connection.lock().unwrap().deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Delete error")) + .map_err(|e| anyhow::anyhow!("Delete error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn count_for_library( @@ -1995,9 +2029,9 @@ impl ExifDao for SqliteExifDao { .filter(library_id.eq(library_id_val)) .count() .get_result::(self.connection.lock().unwrap().deref_mut()) - .map_err(|_| anyhow::anyhow!("Count error")) + .map_err(|e| anyhow::anyhow!("Count error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_rel_paths_for_library_page( @@ -2021,10 +2055,10 @@ impl ExifDao for SqliteExifDao { .limit(limit) .offset(offset) .load::<(i32, String)>(self.connection.lock().unwrap().deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_rows_missing_perceptual_hash( @@ -2069,10 +2103,10 @@ impl ExifDao for SqliteExifDao { .order(id.asc()) .limit(limit) .load::<(i32, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }, ) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn backfill_perceptual_hash( @@ -2096,11 +2130,12 @@ impl ExifDao for SqliteExifDao { .set((phash_64.eq(phash_val), dhash_64.eq(dhash_val))) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } + #[allow(clippy::type_complexity)] fn list_duplicates_exact( &mut self, context: &opentelemetry::Context, @@ -2127,7 +2162,7 @@ impl ExifDao for SqliteExifDao { q = q.filter(library_id.eq(lib)); } q.load::(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))? + .map_err(|e| anyhow::anyhow!("Query error: {}", e))? }; if dup_hashes.is_empty() { @@ -2174,7 +2209,7 @@ impl ExifDao for SqliteExifDao { Option, )> = q .load(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))?; + .map_err(|e| anyhow::anyhow!("Query error: {}", e))?; Ok(rows .into_iter() @@ -2193,9 +2228,10 @@ impl ExifDao for SqliteExifDao { }) .collect()) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } + #[allow(clippy::type_complexity)] fn list_perceptual_candidates( &mut self, context: &opentelemetry::Context, @@ -2255,7 +2291,7 @@ impl ExifDao for SqliteExifDao { Option, )> = q .load(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error"))?; + .map_err(|e| anyhow::anyhow!("Query error: {}", e))?; // Dedup keyed on content_hash, keeping the first occurrence // (deterministic by the SQL ORDER BY: lowest library_id, @@ -2281,7 +2317,7 @@ impl ExifDao for SqliteExifDao { } Ok(out) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn list_image_paths( @@ -2306,9 +2342,9 @@ impl ExifDao for SqliteExifDao { q = q.filter(duplicate_of_hash.is_null()); } q.load::<(i32, String)>(connection.deref_mut()) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn lookup_duplicate_row( @@ -2368,9 +2404,9 @@ impl ExifDao for SqliteExifDao { duplicate_decided_at: r.10, }) }) - .map_err(|_| anyhow::anyhow!("Query error")) + .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn set_duplicate_of( @@ -2397,9 +2433,9 @@ impl ExifDao for SqliteExifDao { )) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn clear_duplicate_of( @@ -2424,9 +2460,9 @@ impl ExifDao for SqliteExifDao { )) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Update error")) + .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn union_perceptual_tags( @@ -2464,9 +2500,9 @@ impl ExifDao for SqliteExifDao { .bind::(survivor_hash) .execute(connection.deref_mut()) .map(|_| ()) - .map_err(|_| anyhow::anyhow!("Tag union error")) + .map_err(|e| anyhow::anyhow!("Tag union error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } } diff --git a/src/database/models.rs b/src/database/models.rs index 9d005f5..62274e2 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -1,9 +1,75 @@ use crate::database::schema::{ - entities, entity_facts, entity_photo_links, favorites, image_exif, libraries, personas, - photo_insights, users, video_preview_clips, + entities, entity_facts, entity_photo_links, favorites, image_exif, insight_generation_jobs, + libraries, personas, photo_insights, users, video_preview_clips, }; use serde::Serialize; +/// Possible statuses for an insight generation job. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, FromSqlRow)] +#[serde(rename_all = "snake_case")] +pub enum InsightJobStatus { + Running, + Completed, + Failed, + Cancelled, +} + +impl InsightJobStatus { + pub fn as_str(&self) -> &'static str { + match self { + Self::Running => "running", + Self::Completed => "completed", + Self::Failed => "failed", + Self::Cancelled => "cancelled", + } + } + + pub fn parse(s: &str) -> Self { + match s { + "running" => Self::Running, + "completed" => Self::Completed, + "failed" => Self::Failed, + "cancelled" => Self::Cancelled, + other => { + log::warn!( + "Unknown InsightJobStatus value: {:?}, treating as failed", + other + ); + Self::Failed + } + } + } +} + +impl std::fmt::Display for InsightJobStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Type of insight generation (standard vs agentic). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum InsightGenerationType { + Standard, + Agentic, +} + +impl InsightGenerationType { + pub fn as_str(&self) -> &'static str { + match self { + Self::Standard => "standard", + Self::Agentic => "agentic", + } + } +} + +impl std::fmt::Display for InsightGenerationType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + #[derive(Insertable)] #[diesel(table_name = users)] pub struct InsertUser<'a> { @@ -152,6 +218,15 @@ pub struct InsertPhotoInsight { /// inserted before the hash is available stay null and the /// reconciliation pass backfills them. pub content_hash: Option, + pub num_ctx: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub min_p: Option, + pub system_prompt: Option, + pub persona_id: Option, + pub prompt_eval_count: Option, + pub eval_count: Option, } #[derive(Serialize, Queryable, Clone, Debug)] @@ -171,6 +246,15 @@ pub struct PhotoInsight { pub backend: String, pub fewshot_source_ids: Option, pub content_hash: Option, + pub num_ctx: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub min_p: Option, + pub system_prompt: Option, + pub persona_id: Option, + pub prompt_eval_count: Option, + pub eval_count: Option, } // --- Libraries --- @@ -394,3 +478,30 @@ pub struct VideoPreviewClip { pub created_at: String, pub updated_at: String, } + +#[derive(Insertable)] +#[diesel(table_name = insight_generation_jobs)] +pub struct InsertInsightGenerationJob { + pub library_id: i32, + #[diesel(column_name = file_path)] + pub path: String, + #[diesel(column_name = generation_type)] + pub gen_type: String, + pub status: String, + pub started_at: i64, +} + +#[derive(Queryable, Serialize, Clone, Debug)] +pub struct InsightGenerationJob { + pub id: i32, + pub library_id: i32, + #[diesel(column_name = file_path)] + pub path: String, + #[diesel(column_name = generation_type)] + pub gen_type: String, + pub status: String, + pub started_at: i64, + pub completed_at: Option, + pub result_insight_id: Option, + pub error_message: Option, +} diff --git a/src/database/persona_dao.rs b/src/database/persona_dao.rs index 4924244..6ceb2af 100644 --- a/src/database/persona_dao.rs +++ b/src/database/persona_dao.rs @@ -119,7 +119,7 @@ impl PersonaDao for SqlitePersonaDao { .load::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_persona( @@ -138,7 +138,7 @@ impl PersonaDao for SqlitePersonaDao { .optional() .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn create_persona( @@ -178,7 +178,7 @@ impl PersonaDao for SqlitePersonaDao { .first::(conn.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn update_persona( @@ -241,7 +241,7 @@ impl PersonaDao for SqlitePersonaDao { .optional() .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn delete_persona( @@ -258,7 +258,7 @@ impl PersonaDao for SqlitePersonaDao { .map_err(|e| anyhow::anyhow!("Delete error: {}", e))?; Ok(n > 0) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn bulk_import( @@ -294,7 +294,7 @@ impl PersonaDao for SqlitePersonaDao { } Ok(inserted) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } } diff --git a/src/database/preview_dao.rs b/src/database/preview_dao.rs index c528327..f94ad1d 100644 --- a/src/database/preview_dao.rs +++ b/src/database/preview_dao.rs @@ -96,7 +96,7 @@ impl PreviewDao for SqlitePreviewDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Insert error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn update_status( @@ -126,7 +126,7 @@ impl PreviewDao for SqlitePreviewDao { .map(|_| ()) .map_err(|e| anyhow::anyhow!("Update error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + .map_err(|e| DbError::log(DbErrorKind::UpdateError, e)) } fn get_preview( @@ -148,7 +148,7 @@ impl PreviewDao for SqlitePreviewDao { Err(e) => Err(anyhow::anyhow!("Query error: {}", e)), } }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_previews_batch( @@ -170,7 +170,7 @@ impl PreviewDao for SqlitePreviewDao { .load::(connection.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_by_status( @@ -188,7 +188,7 @@ impl PreviewDao for SqlitePreviewDao { .load::(connection.deref_mut()) .map_err(|e| anyhow::anyhow!("Query error: {}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/database/schema.rs b/src/database/schema.rs index 28b5d26..bf5791b 100644 --- a/src/database/schema.rs +++ b/src/database/schema.rs @@ -216,6 +216,15 @@ diesel::table! { backend -> Text, fewshot_source_ids -> Nullable, content_hash -> Nullable, + num_ctx -> Nullable, + temperature -> Nullable, + top_p -> Nullable, + top_k -> Nullable, + min_p -> Nullable, + system_prompt -> Nullable, + persona_id -> Nullable, + prompt_eval_count -> Nullable, + eval_count -> Nullable, } } @@ -271,12 +280,27 @@ diesel::table! { } } +diesel::table! { + insight_generation_jobs (id) { + id -> Integer, + library_id -> Integer, + file_path -> Text, + generation_type -> Text, + status -> Text, + started_at -> BigInt, + completed_at -> Nullable, + result_insight_id -> Nullable, + error_message -> Nullable, + } +} + diesel::joinable!(entity_facts -> photo_insights (source_insight_id)); diesel::joinable!(entity_photo_links -> entities (entity_id)); diesel::joinable!(entity_photo_links -> libraries (library_id)); diesel::joinable!(face_detections -> libraries (library_id)); diesel::joinable!(face_detections -> persons (person_id)); diesel::joinable!(image_exif -> libraries (library_id)); +diesel::joinable!(insight_generation_jobs -> libraries (library_id)); diesel::joinable!(personas -> users (user_id)); diesel::joinable!(persons -> entities (entity_id)); diesel::joinable!(photo_insights -> libraries (library_id)); @@ -292,6 +316,7 @@ diesel::allow_tables_to_appear_in_same_query!( face_detections, favorites, image_exif, + insight_generation_jobs, libraries, location_history, personas, diff --git a/src/database/search_dao.rs b/src/database/search_dao.rs index a74fd92..ee7d0ad 100644 --- a/src/database/search_dao.rs +++ b/src/database/search_dao.rs @@ -227,7 +227,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { source_file: search.source_file, }) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn store_searches_batch( @@ -283,7 +283,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { Ok(inserted) }) - .map_err(|_| DbError::new(DbErrorKind::InsertError)) + .map_err(|e| DbError::log(DbErrorKind::InsertError, e)) } fn find_searches_in_range( @@ -310,7 +310,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { .map(|rows| rows.into_iter().map(|r| r.to_search_record()).collect()) .map_err(|e| anyhow::anyhow!("Query error: {:?}", e)) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_similar_searches( @@ -372,7 +372,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { .map(|(_, search)| search) .collect()) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn find_relevant_searches_hybrid( @@ -459,7 +459,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { .collect()) } }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn search_exists( @@ -490,7 +490,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { Ok(result.count > 0) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } fn get_search_count(&mut self, context: &opentelemetry::Context) -> Result { @@ -513,6 +513,6 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { Ok(result.count) }) - .map_err(|_| DbError::new(DbErrorKind::QueryError)) + .map_err(|e| DbError::log(DbErrorKind::QueryError, e)) } } diff --git a/src/faces.rs b/src/faces.rs index ba47508..3288aa3 100644 --- a/src/faces.rs +++ b/src/faces.rs @@ -1024,9 +1024,14 @@ impl FaceDao for SqliteFaceDao { if let Some(lib) = library_id { q = q.filter(face_detections::library_id.eq(lib)); } - q.select(diesel::dsl::count_distinct(face_detections::content_hash)) - .first(conn.deref_mut()) - .with_context(|| "stats: scanned")? + q.select( + #[allow(deprecated)] + { + diesel::dsl::count_distinct(face_detections::content_hash) + }, + ) + .first(conn.deref_mut()) + .with_context(|| "stats: scanned")? }; let with_faces: i64 = { let mut q = face_detections::table @@ -1035,9 +1040,14 @@ impl FaceDao for SqliteFaceDao { if let Some(lib) = library_id { q = q.filter(face_detections::library_id.eq(lib)); } - q.select(diesel::dsl::count_distinct(face_detections::content_hash)) - .first(conn.deref_mut()) - .with_context(|| "stats: with_faces")? + q.select( + #[allow(deprecated)] + { + diesel::dsl::count_distinct(face_detections::content_hash) + }, + ) + .first(conn.deref_mut()) + .with_context(|| "stats: with_faces")? }; let no_faces: i64 = { let mut q = face_detections::table @@ -1046,9 +1056,14 @@ impl FaceDao for SqliteFaceDao { if let Some(lib) = library_id { q = q.filter(face_detections::library_id.eq(lib)); } - q.select(diesel::dsl::count_distinct(face_detections::content_hash)) - .first(conn.deref_mut()) - .with_context(|| "stats: no_faces")? + q.select( + #[allow(deprecated)] + { + diesel::dsl::count_distinct(face_detections::content_hash) + }, + ) + .first(conn.deref_mut()) + .with_context(|| "stats: no_faces")? }; let failed: i64 = { let mut q = face_detections::table @@ -1057,9 +1072,14 @@ impl FaceDao for SqliteFaceDao { if let Some(lib) = library_id { q = q.filter(face_detections::library_id.eq(lib)); } - q.select(diesel::dsl::count_distinct(face_detections::content_hash)) - .first(conn.deref_mut()) - .with_context(|| "stats: failed")? + q.select( + #[allow(deprecated)] + { + diesel::dsl::count_distinct(face_detections::content_hash) + }, + ) + .first(conn.deref_mut()) + .with_context(|| "stats: failed")? }; // Image-extension filter mirrors `list_unscanned_candidates` so // SCANNED can actually reach 100%: videos sit in `image_exif` but diff --git a/src/file_scan.rs b/src/file_scan.rs index 9e6e79d..51318f8 100644 --- a/src/file_scan.rs +++ b/src/file_scan.rs @@ -53,6 +53,7 @@ pub fn walk_library_files(base_path: &Path, excluded_dirs: &[String]) -> Vec bool { s.len() == 64 && s.bytes().all(|b| b.is_ascii_hexdigit()) } +/// Compute the forward-slash `rel_path` used to look up a video's +/// `image_exif` row, from its absolute path string and the library root. +/// +/// Normalizing to forward slashes is essential on Windows: `file_scan` +/// stores rel_paths forward-slash regardless of OS, but a raw strip of a +/// backslash Windows path (`Z:\...\pic\Melissa\clip.mp4`) yields +/// `Melissa\clip.mp4`. `get_exif_batch` does an exact match with no +/// normalization, so the backslash form misses and the handler falls back +/// to re-hashing the entire file on every request. +fn rel_path_for_lookup(full_path_str: &str, resolved_root: &str) -> String { + full_path_str + .strip_prefix(resolved_root) + .unwrap_or(full_path_str) + .trim_start_matches(['/', '\\']) + .replace('\\', "/") +} + /// Allowed file names inside a hash dir. `playlist.m3u8` plus segment /// files matching the `segment_NNN.ts` template that `PlaylistGenerator` /// writes via `hls_paths::SEGMENT_TEMPLATE`. Anything else (including @@ -570,6 +584,63 @@ mod tests { assert!(!is_allowed_hls_filename("")); } + #[test] + fn rel_path_for_lookup_normalizes_windows_separators() { + // Windows: backslash root + backslash full path. The stored row is + // forward-slash (`Melissa/clip.mp4`), so without normalization the + // lookup misses and the handler re-hashes the whole file. + assert_eq!( + rel_path_for_lookup(r"Z:\Media\pic\Melissa\clip.mp4", r"Z:\Media\pic"), + "Melissa/clip.mp4" + ); + } + + #[test] + fn rel_path_for_lookup_handles_unix_separators() { + assert_eq!( + rel_path_for_lookup("/media/pic/Melissa/clip.mp4", "/media/pic"), + "Melissa/clip.mp4" + ); + } + + #[test] + fn rel_path_for_lookup_file_at_root_has_no_separator() { + // A file directly in the library root has no internal separator, so + // the bug never manifested here — guard against a regression anyway. + assert_eq!( + rel_path_for_lookup(r"Z:\Media\pic\clip.mp4", r"Z:\Media\pic"), + "clip.mp4" + ); + assert_eq!( + rel_path_for_lookup("/media/pic/clip.mp4", "/media/pic"), + "clip.mp4" + ); + } + + #[test] + fn rel_path_for_lookup_strips_leading_separators() { + // Both separator styles are trimmed from the front after the root + // is stripped, regardless of which form the join produced. + assert_eq!( + rel_path_for_lookup(r"Z:\Media\pic\sub\a.mp4", r"Z:\Media\pic"), + "sub/a.mp4" + ); + assert_eq!( + rel_path_for_lookup("/media/pic//sub/a.mp4", "/media/pic"), + "sub/a.mp4" + ); + } + + #[test] + fn rel_path_for_lookup_falls_back_when_root_does_not_match() { + // If the root doesn't prefix the path (e.g. a stale mount), we keep + // the whole path but still normalize separators rather than panic. + assert_eq!( + rel_path_for_lookup(r"D:\other\Melissa\clip.mp4", r"Z:\Media\pic"), + "D:/other/Melissa/clip.mp4" + ); + } + fn make_token() -> String { let claims = Claims::valid_user("1".to_string()); jsonwebtoken::encode( diff --git a/src/knowledge.rs b/src/knowledge.rs index 66815b2..2c6cf6f 100644 --- a/src/knowledge.rs +++ b/src/knowledge.rs @@ -803,38 +803,36 @@ async fn synthesize_merge( .json(serde_json::json!({"error": "source_id and target_id must differ"})); } - let cx = opentelemetry::Context::current(); - let mut dao = dao.lock().expect("Unable to lock KnowledgeDao"); + let (source, target) = { + let cx = opentelemetry::Context::current(); + let mut dao = dao.lock().expect("Unable to lock KnowledgeDao"); - let source = match dao.get_entity_by_id(&cx, body.source_id) { - Ok(Some(e)) => e, - Ok(None) => { - return HttpResponse::BadRequest() - .json(serde_json::json!({"error": "source entity not found"})); - } - Err(e) => { - log::error!("synthesize_merge source lookup: {:?}", e); - return HttpResponse::InternalServerError() - .json(serde_json::json!({"error": "Database error"})); - } + let source = match dao.get_entity_by_id(&cx, body.source_id) { + Ok(Some(e)) => e, + Ok(None) => { + return HttpResponse::BadRequest() + .json(serde_json::json!({"error": "source entity not found"})); + } + Err(e) => { + log::error!("synthesize_merge source lookup: {:?}", e); + return HttpResponse::InternalServerError() + .json(serde_json::json!({"error": "Database error"})); + } + }; + let target = match dao.get_entity_by_id(&cx, body.target_id) { + Ok(Some(e)) => e, + Ok(None) => { + return HttpResponse::BadRequest() + .json(serde_json::json!({"error": "target entity not found"})); + } + Err(e) => { + log::error!("synthesize_merge target lookup: {:?}", e); + return HttpResponse::InternalServerError() + .json(serde_json::json!({"error": "Database error"})); + } + }; + (source, target) }; - let target = match dao.get_entity_by_id(&cx, body.target_id) { - Ok(Some(e)) => e, - Ok(None) => { - return HttpResponse::BadRequest() - .json(serde_json::json!({"error": "target entity not found"})); - } - Err(e) => { - log::error!("synthesize_merge target lookup: {:?}", e); - return HttpResponse::InternalServerError() - .json(serde_json::json!({"error": "Database error"})); - } - }; - - // Drop the DAO lock before the LLM call — the generate request - // is the slow part (seconds) and we don't want to block other - // knowledge reads while it runs. - drop(dao); let source_desc = if source.description.trim().is_empty() { "(none)".to_string() diff --git a/src/library_maintenance.rs b/src/library_maintenance.rs index 67f4ba2..ffa87bb 100644 --- a/src/library_maintenance.rs +++ b/src/library_maintenance.rs @@ -296,6 +296,7 @@ impl GcStats { || self.revived > 0 } + #[allow(dead_code)] pub fn total_deleted(&self) -> usize { self.deleted_face_detections + self.deleted_tagged_photo + self.deleted_photo_insights } diff --git a/src/main.rs b/src/main.rs index 63013ce..4099a5d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,6 +75,22 @@ fn main() -> std::io::Result<()> { run_migrations(&mut connect()).expect("Failed to run migrations"); + // Recover orphaned insight generation jobs from a previous crash. + { + use crate::database::{InsightGenerationJobDao, SqliteInsightGenerationJobDao}; + let mut dao = SqliteInsightGenerationJobDao::new(); + let ctx = opentelemetry::Context::new(); + match dao.recover_orphaned_jobs(&ctx) { + Ok(n) if n > 0 => { + info!("Recovered {} orphaned insight generation jobs", n); + } + Ok(_) => {} + Err(e) => { + log::warn!("Failed to recover orphaned insight jobs: {:?}", e); + } + } + } + // One-shot retirement of the pre-content-hash HLS layout. Idempotent // — a second boot finds nothing and reports zero deletions, so it's // safe to leave wired in until the module is removed in a later @@ -181,6 +197,28 @@ fn main() -> std::io::Result<()> { app_state.library_health.clone(), ); + // Periodically clean up stale turn entries from the in-memory + // registry. Runs at the same interval as the configured timeout, + // drops entries older than that timeout. + { + let registry = app_state.turn_registry.clone(); + let timeout_secs = registry.timeout_secs(); + tokio::spawn(async move { + // Sweep at most every 5 minutes, and never less often than the + // timeout itself — otherwise entries could linger up to ~2× the + // configured timeout before being reclaimed. + let interval_secs = timeout_secs.clamp(1, 300); + let interval = tokio::time::Duration::from_secs(interval_secs); + loop { + tokio::time::sleep(interval).await; + let cleaned = registry.cleanup_stale().await; + if cleaned > 0 { + log::info!("TurnRegistry: cleaned up {cleaned} stale entries"); + } + } + }); + } + // Spawn background job to generate daily conversation summaries { use crate::ai::generate_daily_summaries; @@ -308,6 +346,8 @@ fn main() -> std::io::Result<()> { .service(memories::list_memories) .service(ai::generate_insight_handler) .service(ai::generate_agentic_insight_handler) + .service(ai::generation_status_handler) + .service(ai::cancel_generation_handler) .service(ai::get_insight_handler) .service(ai::delete_insight_handler) .service(ai::get_all_insights_handler) @@ -317,6 +357,9 @@ fn main() -> std::io::Result<()> { .service(ai::chat_stream_handler) .service(ai::chat_history_handler) .service(ai::chat_rewind_handler) + .service(ai::turn_async_handler) + .service(ai::turn_replay_handler) + .service(ai::cancel_turn_handler) .service(ai::rate_insight_handler) .service(ai::export_training_data_handler) .service(libraries::list_libraries) diff --git a/src/state.rs b/src/state.rs index 8cfccbb..f9adda7 100644 --- a/src/state.rs +++ b/src/state.rs @@ -4,12 +4,13 @@ use crate::ai::face_client::FaceClient; use crate::ai::insight_chat::{ChatLockMap, InsightChatService}; use crate::ai::llamacpp::LlamaCppClient; use crate::ai::openrouter::OpenRouterClient; +use crate::ai::turn_registry::TurnRegistry; use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient}; use crate::database::{ - CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, KnowledgeDao, LocationHistoryDao, - SearchHistoryDao, SqliteCalendarEventDao, SqliteDailySummaryDao, SqliteExifDao, - SqliteInsightDao, SqliteKnowledgeDao, SqliteLocationHistoryDao, SqliteSearchHistoryDao, - connect, + CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, InsightGenerationJobDao, KnowledgeDao, + LocationHistoryDao, SearchHistoryDao, SqliteCalendarEventDao, SqliteDailySummaryDao, + SqliteExifDao, SqliteInsightDao, SqliteInsightGenerationJobDao, SqliteKnowledgeDao, + SqliteLocationHistoryDao, SqliteSearchHistoryDao, connect, }; use crate::database::{PreviewDao, SqlitePreviewDao}; use crate::faces; @@ -19,6 +20,7 @@ use crate::video::actors::{ PlaylistGenerator, PreviewClipGenerator, StreamActor, VideoPlaylistManager, }; use actix::{Actor, Addr}; +use std::collections::HashMap; use std::env; use std::sync::{Arc, Mutex, RwLock}; @@ -77,15 +79,11 @@ pub struct AppState { pub insight_generator: InsightGenerator, /// Chat continuation service. Hold an Arc so handlers can clone cheaply. pub insight_chat: Arc, - /// Face inference client (calls Apollo's `/api/internal/faces/*`). - /// Disabled (`is_enabled() == false`) when neither `APOLLO_FACE_API_BASE_URL` - /// nor `APOLLO_API_BASE_URL` is set; the file-watch hook (Phase 3) and - /// manual-face-create handler short-circuit in that case. + pub turn_registry: Arc, pub face_client: FaceClient, - /// CLIP inference client (calls Apollo's `/api/internal/clip/*`). - /// Same disabled semantics as `face_client`: unset env → no-op - /// backlog drain, /photos/search returns an empty result. pub clip_client: ClipClient, + pub insight_job_dao: Arc>>, + pub insight_job_handles: Arc>>, } impl AppState { @@ -121,9 +119,12 @@ impl AppState { sms_client: SmsApiClient, insight_generator: InsightGenerator, insight_chat: Arc, + turn_registry: Arc, preview_dao: Arc>>, face_client: FaceClient, clip_client: ClipClient, + insight_job_dao: Arc>>, + insight_job_handles: Arc>>, ) -> Self { assert!( !libraries_vec.is_empty(), @@ -163,8 +164,11 @@ impl AppState { sms_client, insight_generator, insight_chat, + turn_registry, face_client, clip_client, + insight_job_dao, + insight_job_handles, } } @@ -253,6 +257,12 @@ impl Default for AppState { let face_dao: Arc>> = Arc::new(Mutex::new(Box::new(faces::SqliteFaceDao::new()))); + // Initialize insight generation job DAO (async generation tracking) + let insight_job_dao: Arc>> = + Arc::new(Mutex::new(Box::new(SqliteInsightGenerationJobDao::new()))); + let insight_job_handles: Arc>> = + Arc::new(Mutex::new(HashMap::new())); + // Load base path and ensure the primary library row reflects it. let base_path = env::var("BASE_PATH").expect("BASE_PATH was not set in the env"); let mut seed_conn = connect(); @@ -294,6 +304,14 @@ impl Default for AppState { chat_locks, )); + // Turn registry for reconnectable chat turns. 5-minute timeout for + // stale turns (background cleaner drops entries older than this). + let timeout_secs: u64 = env::var("INSIGHT_CHAT_TURN_TIMEOUT_SECS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(300); + let turn_registry = Arc::new(TurnRegistry::new(timeout_secs)); + // Ensure preview clips directory exists let preview_clips_path = env::var("PREVIEW_CLIPS_DIRECTORY").unwrap_or_else(|_| "preview_clips".to_string()); @@ -316,9 +334,12 @@ impl Default for AppState { sms_client, insight_generator, insight_chat, + turn_registry, preview_dao, face_client, clip_client, + insight_job_dao, + insight_job_handles, ) } } @@ -389,6 +410,7 @@ fn parse_llamacpp_allowed_models() -> Vec { impl AppState { /// Creates an AppState instance for testing with temporary directories pub fn test_state() -> Self { + use crate::database::insight_generation_job_dao::SqliteInsightGenerationJobDao; use actix::Actor; // Create a base temporary directory let temp_dir = tempfile::tempdir().expect("Failed to create temp directory"); @@ -471,6 +493,9 @@ impl AppState { chat_locks, )); + // Turn registry for test state. + let turn_registry = Arc::new(TurnRegistry::new(300)); + // Initialize test preview DAO let preview_dao: Arc>> = Arc::new(Mutex::new(Box::new(SqlitePreviewDao::new()))); @@ -499,9 +524,12 @@ impl AppState { sms_client, insight_generator, insight_chat, + turn_registry, preview_dao, FaceClient::new(None), // disabled in test ClipClient::new(None), // disabled in test + Arc::new(Mutex::new(Box::new(SqliteInsightGenerationJobDao::new()))), // placeholder for test + Arc::new(Mutex::new(HashMap::new())), // placeholder for test ) } } diff --git a/src/testhelpers.rs b/src/testhelpers.rs index 1536dbb..8c686a1 100644 --- a/src/testhelpers.rs +++ b/src/testhelpers.rs @@ -144,6 +144,7 @@ impl PreviewDao for TestPreviewDao { } else { Err(DbError { kind: DbErrorKind::UpdateError, + source: None, }) } } diff --git a/src/video/actors.rs b/src/video/actors.rs index faad727..22ec1ac 100644 --- a/src/video/actors.rs +++ b/src/video/actors.rs @@ -159,8 +159,16 @@ pub async fn probe_video_stream_meta(video_path: &str) -> VideoStreamMeta { .arg("v:0") .arg("-print_format") .arg("json") + // NOTE: request `stream_side_data_list` (stream-level side data, read + // from the moov atom), NOT the bare `side_data_list` section. On modern + // ffprobe the latter is the *frame* side-data section, which forces + // ffprobe to enumerate every frame — reading the entire mdat over the + // network. For non-faststart phone clips on an SMB mount that turned a + // metadata probe into a full-file read (tens of seconds per open). The + // Display Matrix rotation we need is present at stream level, so this + // keeps codec/fps/rotation while reading only the header. .arg("-show_entries") - .arg("stream=codec_name,r_frame_rate,avg_frame_rate:stream_tags=rotate:side_data_list") + .arg("stream=codec_name,r_frame_rate,avg_frame_rate:stream_tags=rotate:stream_side_data_list") .arg(video_path) .output() .await;