diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs
index 5c4036e..b07b157 100644
--- a/src/ai/handlers.rs
+++ b/src/ai/handlers.rs
@@ -3,6 +3,7 @@
 use opentelemetry::KeyValue;
 use opentelemetry::trace::{Span, Status, Tracer};
 use serde::{Deserialize, Serialize};
+use crate::ai::insight_chat::ChatTurnRequest;
 use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient};
 use crate::data::Claims;
 use crate::database::{ExifDao, InsightDao};
@@ -70,6 +71,9 @@ pub struct PhotoInsightResponse {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub approved: Option<bool>,
     pub backend: String,
+    /// True when the insight was generated agentically and a chat
+    /// continuation can be started against it. Drives the mobile chat button.
+    pub has_training_messages: bool,
 }
 
 #[derive(Debug, Serialize)]
@@ -192,6 +196,7 @@ pub async fn get_insight_handler(
             prompt_eval_count: None,
             eval_count: None,
             approved: insight.approved,
+            has_training_messages: insight.training_messages.is_some(),
             backend: insight.backend,
         };
         HttpResponse::Ok().json(response)
@@ -260,6 +265,7 @@ pub async fn get_all_insights_handler(
             prompt_eval_count: None,
             eval_count: None,
             approved: insight.approved,
+            has_training_messages: insight.training_messages.is_some(),
             backend: insight.backend,
         })
         .collect();
@@ -353,6 +359,7 @@ pub async fn generate_agentic_insight_handler(
             prompt_eval_count,
             eval_count,
             approved: insight.approved,
+            has_training_messages: insight.training_messages.is_some(),
             backend: insight.backend,
         };
         HttpResponse::Ok().json(response)
@@ -558,3 +565,186 @@ pub async fn export_training_data_handler(
         }
     }
 }
+
+#[derive(Debug, Deserialize)]
+pub struct ChatTurnHttpRequest {
+    pub file_path: String,
+    #[serde(default)]
+    pub library: Option<String>,
+    pub user_message: String,
+    #[serde(default)]
+    pub model: Option<String>,
+    #[serde(default)]
+    pub backend: Option<String>,
+    #[serde(default)]
+    pub num_ctx: Option<i32>,
+    #[serde(default)]
+    pub temperature: Option<f32>,
+    #[serde(default)]
+    pub top_p: Option<f32>,
+    #[serde(default)]
+    pub top_k: Option<i32>,
+    #[serde(default)]
+    pub min_p: Option<f32>,
+    #[serde(default)]
+    pub max_iterations: Option<usize>,
+    #[serde(default)]
+    pub amend: bool,
+}
+
+#[derive(Debug, Serialize)]
+pub struct ChatTurnHttpResponse {
+    pub assistant_message: String,
+    pub tool_calls_made: usize,
+    pub iterations_used: usize,
+    pub truncated: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_eval_count: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub eval_count: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub amended_insight_id: Option<i32>,
+    pub backend: String,
+    pub model: String,
+}
+
+/// POST /insights/chat — submit a follow-up turn against an existing insight.
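+///
+/// Illustrative request body (field names follow `ChatTurnHttpRequest`; the
+/// path and message values below are made-up examples):
+///
+/// ```json
+/// {
+///   "file_path": "2024/05/beach.jpg",
+///   "user_message": "What else can you tell me about this evening?",
+///   "amend": false
+/// }
+/// ```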
+#[post("/insights/chat")]
+pub async fn chat_turn_handler(
+    http_request: HttpRequest,
+    _claims: Claims,
+    request: web::Json<ChatTurnHttpRequest>,
+    app_state: web::Data<AppState>,
+) -> impl Responder {
+    let parent_context = extract_context_from_request(&http_request);
+    let tracer = global_tracer();
+    let mut span = tracer.start_with_context("http.insights.chat", &parent_context);
+    span.set_attribute(KeyValue::new("file_path", request.file_path.clone()));
+
+    let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
+        Ok(Some(lib)) => lib,
+        Ok(None) => app_state.primary_library(),
+        Err(e) => {
+            return HttpResponse::BadRequest().json(serde_json::json!({
+                "error": format!("invalid library: {}", e)
+            }));
+        }
+    };
+
+    let chat_req = ChatTurnRequest {
+        library_id: library.id,
+        file_path: request.file_path.clone(),
+        user_message: request.user_message.clone(),
+        model: request.model.clone(),
+        backend: request.backend.clone(),
+        num_ctx: request.num_ctx,
+        temperature: request.temperature,
+        top_p: request.top_p,
+        top_k: request.top_k,
+        min_p: request.min_p,
+        max_iterations: request.max_iterations,
+        amend: request.amend,
+    };
+
+    match app_state.insight_chat.chat_turn(chat_req).await {
+        Ok(result) => {
+            span.set_status(Status::Ok);
+            HttpResponse::Ok().json(ChatTurnHttpResponse {
+                assistant_message: result.assistant_message,
+                tool_calls_made: result.tool_calls_made,
+                iterations_used: result.iterations_used,
+                truncated: result.truncated,
+                prompt_eval_count: result.prompt_eval_count,
+                eval_count: result.eval_count,
+                amended_insight_id: result.amended_insight_id,
+                backend: result.backend_used,
+                model: result.model_used,
+            })
+        }
+        Err(e) => {
+            let msg = format!("{}", e);
+            log::error!("Chat turn failed: {}", msg);
+            span.set_status(Status::error(msg.clone()));
+
+            // Map well-known errors to client-facing 4xx codes.
+            if msg.contains("no insight found") {
+                HttpResponse::NotFound().json(serde_json::json!({ "error": msg }))
+            } else if msg.contains("no chat history") {
+                HttpResponse::Conflict().json(serde_json::json!({ "error": msg }))
+            } else if msg.contains("user_message")
+                || msg.contains("unknown backend")
+                || msg.contains("switching from local to hybrid")
+                || msg.contains("hybrid backend unavailable")
+            {
+                HttpResponse::BadRequest().json(serde_json::json!({ "error": msg }))
+            } else {
+                HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg }))
+            }
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ChatHistoryQuery {
+    pub path: String,
+    #[serde(default)]
+    pub library: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct ChatHistoryHttpResponse {
+    pub messages: Vec<RenderedHistoryMessage>,
+    pub turn_count: usize,
+    pub model_version: String,
+    pub backend: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct RenderedHistoryMessage {
+    pub role: String,
+    pub content: String,
+    pub is_initial: bool,
+}
+
+/// GET /insights/chat/history — return the rendered transcript for a photo.
+#[get("/insights/chat/history")]
+pub async fn chat_history_handler(
+    _claims: Claims,
+    query: web::Query<ChatHistoryQuery>,
+    app_state: web::Data<AppState>,
+) -> impl Responder {
+    // The library param is parsed for parity with other insight endpoints,
+    // even though load_history currently keys on file_path alone (matches
+    // the existing get_insight DAO contract).
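+    // Illustrative call shape (path/library values are examples only):
+    //   GET /insights/chat/history?path=2024/05/beach.jpg&library=main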
+    let _library = libraries::resolve_library_param(&app_state, query.library.as_deref())
+        .ok()
+        .flatten()
+        .unwrap_or_else(|| app_state.primary_library());
+
+    match app_state.insight_chat.load_history(&query.path) {
+        Ok(view) => HttpResponse::Ok().json(ChatHistoryHttpResponse {
+            messages: view
+                .messages
+                .into_iter()
+                .map(|m| RenderedHistoryMessage {
+                    role: m.role,
+                    content: m.content,
+                    is_initial: m.is_initial,
+                })
+                .collect(),
+            turn_count: view.turn_count,
+            model_version: view.model_version,
+            backend: view.backend,
+        }),
+        Err(e) => {
+            let msg = format!("{}", e);
+            if msg.contains("no insight found") {
+                HttpResponse::NotFound().json(serde_json::json!({ "error": msg }))
+            } else if msg.contains("no chat history") {
+                HttpResponse::Conflict().json(serde_json::json!({ "error": msg }))
+            } else {
+                HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg }))
+            }
+        }
+    }
+}
diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs
new file mode 100644
index 0000000..7e5422c
--- /dev/null
+++ b/src/ai/insight_chat.rs
@@ -0,0 +1,640 @@
+use anyhow::{Result, anyhow, bail};
+use chrono::Utc;
+use opentelemetry::KeyValue;
+use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use tokio::sync::Mutex as TokioMutex;
+
+use crate::ai::insight_generator::InsightGenerator;
+use crate::ai::llm_client::{ChatMessage, LlmClient};
+use crate::ai::ollama::OllamaClient;
+use crate::ai::openrouter::OpenRouterClient;
+use crate::database::InsightDao;
+use crate::database::models::InsertPhotoInsight;
+use crate::otel::global_tracer;
+use crate::utils::normalize_path;
+
+const DEFAULT_MAX_ITERATIONS: usize = 6;
+const DEFAULT_NUM_CTX: i32 = 8192;
+/// Headroom reserved for the model's response, deducted from the context
+/// budget when deciding whether to truncate the replayed history.
+const RESPONSE_HEADROOM_TOKENS: usize = 2048;
+/// Cheap byte-to-token approximation used by the truncation pass. The real
+/// tokenization is model-specific; this avoids carrying tiktoken just for a
+/// soft bound.
+const BYTES_PER_TOKEN: usize = 4;
+
+pub type ChatLockMap = Arc<TokioMutex<HashMap<(i32, String), Arc<TokioMutex<()>>>>>;
+
+#[derive(Debug)]
+pub struct ChatTurnRequest {
+    pub library_id: i32,
+    pub file_path: String,
+    pub user_message: String,
+    /// Override the model id. Local mode: an Ollama model name. Hybrid:
+    /// an OpenRouter id. None defers to the stored insight's `model_version`.
+    pub model: Option<String>,
+    /// Override the backend used for this turn. None defers to the stored
+    /// insight's `backend`. Switching `local -> hybrid` is rejected in v1.
+    pub backend: Option<String>,
+    pub num_ctx: Option<i32>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub top_k: Option<i32>,
+    pub min_p: Option<f32>,
+    pub max_iterations: Option<usize>,
+    /// When true, write a new insight row (regenerating the title) instead
+    /// of updating training_messages on the existing row.
+    pub amend: bool,
+}
+
+#[derive(Debug)]
+pub struct ChatTurnResult {
+    pub assistant_message: String,
+    pub tool_calls_made: usize,
+    pub iterations_used: usize,
+    pub truncated: bool,
+    pub prompt_eval_count: Option<i32>,
+    pub eval_count: Option<i32>,
+    /// Set when `amend=true` and the new insight row was inserted.
+    pub amended_insight_id: Option<i32>,
+    /// Backend used for this turn — useful when the client overrode the
+    /// stored value.
+    pub backend_used: String,
+    /// Model identifier the chat backend ran with.
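+    /// Example values (illustrative only): an Ollama tag such as
+    /// `qwen2.5:7b` in local mode, or an OpenRouter id such as
+    /// `openai/gpt-4o` in hybrid mode.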
+    pub model_used: String,
+}
+
+#[derive(Clone)]
+pub struct InsightChatService {
+    generator: Arc<InsightGenerator>,
+    ollama: OllamaClient,
+    openrouter: Option<Arc<OpenRouterClient>>,
+    insight_dao: Arc<Mutex<Box<dyn InsightDao>>>,
+    chat_locks: ChatLockMap,
+}
+
+impl InsightChatService {
+    pub fn new(
+        generator: Arc<InsightGenerator>,
+        ollama: OllamaClient,
+        openrouter: Option<Arc<OpenRouterClient>>,
+        insight_dao: Arc<Mutex<Box<dyn InsightDao>>>,
+        chat_locks: ChatLockMap,
+    ) -> Self {
+        Self {
+            generator,
+            ollama,
+            openrouter,
+            insight_dao,
+            chat_locks,
+        }
+    }
+
+    /// Load the rendered transcript for chat-UI display. Filters internal
+    /// scaffolding (system message, tool turns, tool-dispatch-only assistant
+    /// messages) and drops base64 images from user turns to keep payloads
+    /// small. The first remaining user message is flagged `is_initial`.
+    pub fn load_history(&self, file_path: &str) -> Result<HistoryView> {
+        let normalized = normalize_path(file_path);
+        let cx = opentelemetry::Context::new();
+        let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+        let insight = dao
+            .get_insight(&cx, &normalized)
+            .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
+            .ok_or_else(|| anyhow!("no insight found for path"))?;
+
+        let raw = insight
+            .training_messages
+            .as_ref()
+            .ok_or_else(|| anyhow!("insight has no chat history (pre-agentic insight)"))?;
+        let messages: Vec<ChatMessage> = serde_json::from_str(raw)
+            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
+
+        let mut rendered = Vec::new();
+        let mut user_turns_seen = 0usize;
+        let mut assistant_turns_seen = 0usize;
+        for msg in &messages {
+            match msg.role.as_str() {
+                "system" => continue,
+                "tool" => continue,
+                "assistant" => {
+                    let has_tool_calls = msg
+                        .tool_calls
+                        .as_ref()
+                        .map(|c| !c.is_empty())
+                        .unwrap_or(false);
+                    if has_tool_calls && msg.content.trim().is_empty() {
+                        continue;
+                    }
+                    assistant_turns_seen += 1;
+                    rendered.push(RenderedMessage {
+                        role: "assistant".to_string(),
+                        content: msg.content.clone(),
+                        is_initial: false,
+                    });
+                }
+                "user" => {
+                    let is_initial = user_turns_seen == 0;
+                    user_turns_seen += 1;
+                    rendered.push(RenderedMessage {
+                        role: "user".to_string(),
+                        content: msg.content.clone(),
+                        is_initial,
+                    });
+                }
+                _ => continue,
+            }
+        }
+
+        Ok(HistoryView {
+            messages: rendered,
+            turn_count: assistant_turns_seen,
+            model_version: insight.model_version,
+            backend: insight.backend,
+        })
+    }
+
+    pub async fn chat_turn(&self, req: ChatTurnRequest) -> Result<ChatTurnResult> {
+        let tracer = global_tracer();
+        let parent_cx = opentelemetry::Context::new();
+        let mut span = tracer.start_with_context("ai.insight.chat_turn", &parent_cx);
+        span.set_attribute(KeyValue::new("file_path", req.file_path.clone()));
+        span.set_attribute(KeyValue::new("library_id", req.library_id as i64));
+        span.set_attribute(KeyValue::new("amend", req.amend));
+
+        if req.user_message.trim().is_empty() {
+            bail!("user_message must not be empty");
+        }
+        if req.user_message.len() > 8192 {
+            bail!("user_message exceeds 8192 chars");
+        }
+
+        let normalized = normalize_path(&req.file_path);
+
+        // 1. Acquire the per-(library, file) async mutex. Two concurrent
+        //    chat turns on the same insight would race on the JSON blob —
+        //    the lock serialises them.
+        let lock_key = (req.library_id, normalized.clone());
+        let entry_lock = {
+            let mut locks = self.chat_locks.lock().await;
+            locks
+                .entry(lock_key.clone())
+                .or_insert_with(|| Arc::new(TokioMutex::new(())))
+                .clone()
+        };
+        let _guard = entry_lock.lock().await;
+
+        // 2. Load the current insight + history.
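+        //    Both failure modes below surface as the "no insight found" /
+        //    "no chat history" messages that the HTTP handler maps to 404/409.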
+        let insight = {
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            dao.get_insight(&cx, &normalized)
+                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
+                .ok_or_else(|| anyhow!("no insight found for path"))?
+        };
+        let raw_history = insight
+            .training_messages
+            .as_ref()
+            .ok_or_else(|| {
+                anyhow!("insight has no chat history; regenerate this insight in agentic mode")
+            })?
+            .clone();
+        let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
+            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
+
+        // 3. Resolve effective backend. Reject the unsupported switch.
+        let stored_backend = insight.backend.clone();
+        let effective_backend = req
+            .backend
+            .as_deref()
+            .map(|s| s.trim().to_lowercase())
+            .filter(|s| !s.is_empty())
+            .unwrap_or_else(|| stored_backend.clone());
+        if !matches!(effective_backend.as_str(), "local" | "hybrid") {
+            bail!(
+                "unknown backend '{}'; expected 'local' or 'hybrid'",
+                effective_backend
+            );
+        }
+        if stored_backend == "local" && effective_backend == "hybrid" {
+            bail!(
+                "switching from local to hybrid mid-chat isn't supported yet; \
+                 regenerate the insight in hybrid mode if you want OpenRouter chat"
+            );
+        }
+        let is_hybrid = effective_backend == "hybrid";
+        span.set_attribute(KeyValue::new("backend", effective_backend.clone()));
+
+        // 4. Build the chat backend client. Ollama in local mode, a freshly
+        //    cloned OpenRouter client in hybrid mode (clone so per-request
+        //    sampling/model overrides don't leak into shared state).
+        let max_iterations = req
+            .max_iterations
+            .unwrap_or(DEFAULT_MAX_ITERATIONS)
+            .clamp(1, env_max_iterations());
+        span.set_attribute(KeyValue::new("max_iterations", max_iterations as i64));
+
+        let stored_model = insight.model_version.clone();
+        let custom_model = req
+            .model
+            .clone()
+            .or_else(|| Some(stored_model.clone()))
+            .filter(|m| !m.is_empty());
+
+        let mut ollama_client = self.ollama.clone();
+        let mut openrouter_client: Option<OpenRouterClient> = None;
+
+        if is_hybrid {
+            let arc = self.openrouter.as_ref().ok_or_else(|| {
+                anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
+            })?;
+            let mut c: OpenRouterClient = (**arc).clone();
+            if let Some(ref m) = custom_model {
+                c.primary_model = m.clone();
+            }
+            if req.temperature.is_some()
+                || req.top_p.is_some()
+                || req.top_k.is_some()
+                || req.min_p.is_some()
+            {
+                c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
+            }
+            if let Some(ctx) = req.num_ctx {
+                c.set_num_ctx(Some(ctx));
+            }
+            openrouter_client = Some(c);
+        } else {
+            // Local-mode model swap. Build a new client when the chat model
+            // differs from the configured one (mirrors the agentic pattern).
+            if let Some(ref m) = custom_model
+                && m != &self.ollama.primary_model
+            {
+                ollama_client = OllamaClient::new(
+                    self.ollama.primary_url.clone(),
+                    self.ollama.fallback_url.clone(),
+                    m.clone(),
+                    Some(m.clone()),
+                );
+            }
+            if req.temperature.is_some()
+                || req.top_p.is_some()
+                || req.top_k.is_some()
+                || req.min_p.is_some()
+            {
+                ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
+            }
+            if let Some(ctx) = req.num_ctx {
+                ollama_client.set_num_ctx(Some(ctx));
+            }
+        }
+
+        let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
+            c
+        } else {
+            &ollama_client
+        };
+        let model_used = chat_backend.primary_model().to_string();
+        span.set_attribute(KeyValue::new("model", model_used.clone()));
+
+        // 5. Decide vision + tool set. In hybrid we always omit
+        //    `describe_photo` (matches the original generation flow). In
+        //    local we trust the stored history's first-user shape: if it
+        //    carries `images`, the original model was vision-capable, and
+        //    we keep `describe_photo` available.
+        let local_first_user_has_image = messages
+            .iter()
+            .find(|m| m.role == "user")
+            .and_then(|m| m.images.as_ref())
+            .map(|imgs| !imgs.is_empty())
+            .unwrap_or(false);
+        let offer_describe_tool = !is_hybrid && local_first_user_has_image;
+        let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);
+
+        // Image base64 only needed when describe_photo is on the menu. Load
+        // lazily to avoid disk IO when the loop never invokes it.
+        let image_base64: Option<String> = if offer_describe_tool {
+            self.generator.load_image_as_base64(&normalized).ok()
+        } else {
+            None
+        };
+
+        // 6. Apply truncation budget. Drops oldest tool_call+tool pairs
+        //    (preserves system + first user including any images).
+        let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
+            .saturating_sub(RESPONSE_HEADROOM_TOKENS);
+        let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
+        let truncated = apply_context_budget(&mut messages, budget_bytes);
+        if truncated {
+            span.set_attribute(KeyValue::new("history_truncated", true));
+        }
+
+        // 7. Append the new user turn.
+        messages.push(ChatMessage::user(req.user_message.clone()));
+
+        let insight_cx = parent_cx.with_span(span);
+
+        // 8. Agentic loop — same shape as insight_generator's, but capped
+        //    tighter and dispatching tools through the shared executor.
+        let loop_span = tracer.start_with_context("ai.chat.loop", &insight_cx);
+        let loop_cx = insight_cx.with_span(loop_span);
+        let mut tool_calls_made = 0usize;
+        let mut iterations_used = 0usize;
+        let mut last_prompt_eval_count: Option<i32> = None;
+        let mut last_eval_count: Option<i32> = None;
+        let mut final_content = String::new();
+
+        for iteration in 0..max_iterations {
+            iterations_used = iteration + 1;
+            log::info!("Chat iteration {}/{}", iterations_used, max_iterations);
+
+            let (response, prompt_tokens, eval_tokens) = chat_backend
+                .chat_with_tools(messages.clone(), tools.clone())
+                .await?;
+            last_prompt_eval_count = prompt_tokens;
+            last_eval_count = eval_tokens;
+
+            // Ollama rejects non-object tool-call arguments on replay.
+            let mut response = response;
+            if let Some(ref mut tcs) = response.tool_calls {
+                for tc in tcs.iter_mut() {
+                    if !tc.function.arguments.is_object() {
+                        tc.function.arguments = serde_json::Value::Object(Default::default());
+                    }
+                }
+            }
+
+            messages.push(response.clone());
+
+            if let Some(ref tool_calls) = response.tool_calls
+                && !tool_calls.is_empty()
+            {
+                for tool_call in tool_calls {
+                    tool_calls_made += 1;
+                    log::info!(
+                        "Chat tool call [{}]: {} {:?}",
+                        iteration,
+                        tool_call.function.name,
+                        tool_call.function.arguments
+                    );
+                    let result = self
+                        .generator
+                        .execute_tool(
+                            &tool_call.function.name,
+                            &tool_call.function.arguments,
+                            &ollama_client,
+                            &image_base64,
+                            &normalized,
+                            &loop_cx,
+                        )
+                        .await;
+                    messages.push(ChatMessage::tool_result(result));
+                }
+                continue;
+            }
+
+            final_content = response.content;
+            break;
+        }
+
+        if final_content.is_empty() {
+            // The model never produced a final answer; ask once more without
+            // tools to force a textual reply.
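+            // Passing an empty tool list below removes the option to call
+            // tools, so the model can only reply with plain text.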
+            log::info!(
+                "Chat loop exhausted after {} iterations, requesting final answer",
+                iterations_used
+            );
+            messages.push(ChatMessage::user(
+                "Please write your final answer now without calling any more tools.",
+            ));
+            let (final_response, prompt_tokens, eval_tokens) = chat_backend
+                .chat_with_tools(messages.clone(), vec![])
+                .await?;
+            last_prompt_eval_count = prompt_tokens;
+            last_eval_count = eval_tokens;
+            final_content = final_response.content.clone();
+            messages.push(final_response);
+        }
+
+        loop_cx.span().set_status(Status::Ok);
+
+        // 9. Persist. Append mode rewrites the JSON blob in place; amend
+        //    mode regenerates the title and inserts a new insight row,
+        //    relying on store_insight to flip prior rows' is_current=false.
+        let json = serde_json::to_string(&messages)
+            .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
+
+        let mut amended_insight_id: Option<i32> = None;
+        if req.amend {
+            let title_prompt = format!(
+                "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
+                 Capture the key moment or theme. Return ONLY the title, nothing else.",
+                final_content
+            );
+            let title_raw = chat_backend
+                .generate(
+                    &title_prompt,
+                    Some(
+                        "You are my long term memory assistant. Use only the information provided. Do not invent details.",
+                    ),
+                    None,
+                )
+                .await?;
+            let title = title_raw.trim().trim_matches('"').to_string();
+
+            let new_row = InsertPhotoInsight {
+                library_id: req.library_id,
+                file_path: normalized.clone(),
+                title,
+                summary: final_content.clone(),
+                generated_at: Utc::now().timestamp(),
+                model_version: model_used.clone(),
+                is_current: true,
+                training_messages: Some(json),
+                backend: effective_backend.clone(),
+            };
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            let stored = dao
+                .store_insight(&cx, new_row)
+                .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
+            amended_insight_id = Some(stored.id);
+        } else {
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            dao.update_training_messages(&cx, req.library_id, &normalized, &json)
+                .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
+        }
+
+        Ok(ChatTurnResult {
+            assistant_message: final_content,
+            tool_calls_made,
+            iterations_used,
+            truncated,
+            prompt_eval_count: last_prompt_eval_count,
+            eval_count: last_eval_count,
+            amended_insight_id,
+            backend_used: effective_backend,
+            model_used,
+        })
+    }
+}
+
+/// Read AGENTIC_CHAT_MAX_ITERATIONS once per call. Cheap; keeps the code
+/// free of static globals and lets the operator change the cap by env without
+/// a restart in test harnesses (the running server still caches via Default).
+fn env_max_iterations() -> usize {
+    std::env::var("AGENTIC_CHAT_MAX_ITERATIONS")
+        .ok()
+        .and_then(|s| s.parse::<usize>().ok())
+        .unwrap_or(DEFAULT_MAX_ITERATIONS)
+        .max(1)
+}
+
+/// View returned to clients for chat-UI rendering.
+#[derive(Debug)]
+pub struct HistoryView {
+    pub messages: Vec<RenderedMessage>,
+    pub turn_count: usize,
+    pub model_version: String,
+    pub backend: String,
+}
+
+#[derive(Debug)]
+pub struct RenderedMessage {
+    pub role: String,
+    pub content: String,
+    pub is_initial: bool,
+}
+
+/// Trim history to fit within `budget_bytes` of serialized JSON. Preserves
+/// the system message and the first user message (with its base64 images
+/// intact, since dropping those would invalidate the model's prior visual
+/// reasoning). Drops the oldest assistant-tool_call + corresponding
+/// tool-result pair on each pass until the budget is met or only the
+/// preserved prefix remains.
+///
+/// Returns true when at least one message was dropped.
+pub(crate) fn apply_context_budget(messages: &mut Vec<ChatMessage>, budget_bytes: usize) -> bool {
+    if budget_bytes == 0 {
+        return false;
+    }
+    if estimate_bytes(messages) <= budget_bytes {
+        return false;
+    }
+
+    // Find the index past the protected prefix: system messages + the first
+    // user message. Everything after is droppable in pairs.
+    let first_user_idx = messages.iter().position(|m| m.role == "user");
+    let preserve_through = match first_user_idx {
+        Some(i) => i, // keep [0..=i]
+        None => return false,
+    };
+
+    let mut dropped_any = false;
+    loop {
+        if estimate_bytes(messages) <= budget_bytes {
+            break;
+        }
+        // Find the oldest assistant-with-tool_calls strictly after the
+        // preserved prefix. Drop it together with the following tool turn(s)
+        // until we hit the next assistant or user turn.
+        let drop_start = (preserve_through + 1..messages.len()).find(|&i| {
+            let m = &messages[i];
+            m.role == "assistant"
+                && m.tool_calls
+                    .as_ref()
+                    .map(|c| !c.is_empty())
+                    .unwrap_or(false)
+        });
+        let Some(start) = drop_start else { break };
+        // Determine end: drop the assistant turn plus any contiguous tool
+        // result turns that follow. The while bound keeps `end` within
+        // range, so the drain below is always safe.
+        let mut end = start + 1;
+        while end < messages.len() && messages[end].role == "tool" {
+            end += 1;
+        }
+        messages.drain(start..end);
+        dropped_any = true;
+    }
+
+    dropped_any
+}
+
+fn estimate_bytes(messages: &[ChatMessage]) -> usize {
+    serde_json::to_string(messages)
+        .map(|s| s.len())
+        .unwrap_or(0)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ai::llm_client::{ToolCall, ToolCallFunction};
+
+    fn assistant_with_tool_call(name: &str) -> ChatMessage {
+        ChatMessage {
+            role: "assistant".to_string(),
+            content: String::new(),
+            tool_calls: Some(vec![ToolCall {
+                id: None,
+                function: ToolCallFunction {
+                    name: name.to_string(),
+                    arguments: serde_json::Value::Object(Default::default()),
+                },
+            }]),
+            images: None,
+        }
+    }
+
+    fn assistant_text(text: &str) -> ChatMessage {
+        ChatMessage {
+            role: "assistant".to_string(),
+            content: text.to_string(),
+            tool_calls: None,
+            images: None,
+        }
+    }
+
+    #[test]
+    fn truncation_preserves_system_and_first_user() {
+        let mut msgs = vec![
+            ChatMessage::system("sys"),
+            ChatMessage::user("first user with lots of context".repeat(50)),
+            assistant_with_tool_call("get_x"),
+            ChatMessage::tool_result("x result ".repeat(200)),
+            assistant_with_tool_call("get_y"),
+            ChatMessage::tool_result("y result ".repeat(200)),
+            assistant_text("final answer"),
+        ];
+        let original_len = msgs.len();
+        let dropped = apply_context_budget(&mut msgs, 500);
+        assert!(dropped, "should drop something at this small budget");
+        assert!(msgs.len() < original_len);
+        // First two messages preserved.
+        assert_eq!(msgs[0].role, "system");
+        assert_eq!(msgs[1].role, "user");
+    }
+
+    #[test]
+    fn truncation_no_op_when_under_budget() {
+        let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")];
+        let dropped = apply_context_budget(&mut msgs, 1_000_000);
+        assert!(!dropped);
+        assert_eq!(msgs.len(), 2);
+    }
+
+    #[test]
+    fn truncation_returns_false_with_no_droppable_pairs() {
+        // Only system + user, no tool-call turns to drop.
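+        // A 1-byte budget is far below the serialized size, yet nothing
+        // here is droppable, so the call must report false.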
+        let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")];
+        let dropped = apply_context_budget(&mut msgs, 1);
+        assert!(!dropped);
+    }
+}
diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs
index fcf89ee..09fd585 100644
--- a/src/ai/insight_generator.rs
+++ b/src/ai/insight_generator.rs
@@ -96,7 +96,7 @@ impl InsightGenerator {
     /// first root under which the file exists. Insights may be generated
     /// for any library — the generator itself doesn't know which — so we
     /// probe each root rather than trust a single `base_path`.
-    fn resolve_full_path(&self, rel_path: &str) -> Option<std::path::PathBuf> {
+    pub(crate) fn resolve_full_path(&self, rel_path: &str) -> Option<std::path::PathBuf> {
         use std::path::Path;
         for lib in &self.libraries {
             let candidate = Path::new(&lib.root_path).join(rel_path);
@@ -129,7 +129,7 @@ impl InsightGenerator {
     /// Load image file, resize it, and encode as base64 for vision models
     /// Resizes to max 1024px on longest edge to reduce context usage
-    fn load_image_as_base64(&self, file_path: &str) -> Result<String> {
+    pub(crate) fn load_image_as_base64(&self, file_path: &str) -> Result<String> {
         use image::imageops::FilterType;
 
         let full_path = self.resolve_full_path(file_path).ok_or_else(|| {
@@ -1411,7 +1411,7 @@ Return ONLY the summary, nothing else."#,
     // ── Tool executors for agentic loop ────────────────────────────────
 
     /// Dispatch a tool call to the appropriate executor
-    async fn execute_tool(
+    pub(crate) async fn execute_tool(
        &self,
         tool_name: &str,
         arguments: &serde_json::Value,
@@ -2136,7 +2136,7 @@ Return ONLY the summary, nothing else."#,
     // ── Agentic insight generation ──────────────────────────────────────
 
     /// Build the list of tool definitions for the agentic loop
-    fn build_tool_definitions(has_vision: bool) -> Vec<Tool> {
+    pub(crate) fn build_tool_definitions(has_vision: bool) -> Vec<Tool> {
         let mut tools = vec![
             Tool::function(
                 "search_rag",
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index 60a3f43..93735c0 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -1,5 +1,6 @@
 pub mod daily_summary_job;
 pub mod handlers;
+pub mod insight_chat;
 pub mod insight_generator;
 pub mod llm_client;
 pub mod ollama;
@@ -10,9 +11,10 @@ pub mod sms_client;
 #[allow(unused_imports)]
 pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
 pub use handlers::{
-    delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler,
-    generate_insight_handler, get_all_insights_handler, get_available_models_handler,
-    get_insight_handler, get_openrouter_models_handler, rate_insight_handler,
+    chat_history_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler,
+    generate_agentic_insight_handler, generate_insight_handler, get_all_insights_handler,
+    get_available_models_handler, get_insight_handler, get_openrouter_models_handler,
+    rate_insight_handler,
 };
 pub use insight_generator::InsightGenerator;
 #[allow(unused_imports)]
diff --git a/src/database/insights_dao.rs b/src/database/insights_dao.rs
index 553b579..34c1d12 100644
--- a/src/database/insights_dao.rs
+++ b/src/database/insights_dao.rs
@@ -60,6 +60,17 @@ pub trait InsightDao: Sync + Send {
         &mut self,
         context: &opentelemetry::Context,
     ) -> Result<Vec<PhotoInsight>, DbError>;
+
+    /// Replace the `training_messages` JSON blob on the current row for
+    /// `(library_id, rel_path)`. Used by chat-turn append mode to persist
+    /// the extended conversation without inserting a new insight version.
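+    ///
+    /// Roughly the following SQL (illustrative; the Diesel impl below
+    /// builds the actual query):
+    ///
+    /// ```sql
+    /// UPDATE photo_insights
+    ///    SET training_messages = ?
+    ///  WHERE library_id = ? AND rel_path = ? AND is_current = 1;
+    /// ```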
+    fn update_training_messages(
+        &mut self,
+        context: &opentelemetry::Context,
+        library_id: i32,
+        file_path: &str,
+        training_messages_json: &str,
+    ) -> Result<(), DbError>;
 }
 
 pub struct SqliteInsightDao {
@@ -265,4 +276,30 @@ impl InsightDao for SqliteInsightDao {
         })
         .map_err(|_| DbError::new(DbErrorKind::QueryError))
     }
+
+    fn update_training_messages(
+        &mut self,
+        context: &opentelemetry::Context,
+        lib_id: i32,
+        path: &str,
+        training_messages_json: &str,
+    ) -> Result<(), DbError> {
+        trace_db_call(context, "update", "update_training_messages", |_span| {
+            use schema::photo_insights::dsl::*;
+
+            let mut connection = self.connection.lock().expect("Unable to get InsightDao");
+
+            diesel::update(
+                photo_insights
+                    .filter(library_id.eq(lib_id))
+                    .filter(rel_path.eq(path))
+                    .filter(is_current.eq(true)),
+            )
+            .set(training_messages.eq(Some(training_messages_json.to_string())))
+            .execute(connection.deref_mut())
+            .map(|_| ())
+            .map_err(|_| anyhow::anyhow!("Update error"))
+        })
+        .map_err(|_| DbError::new(DbErrorKind::UpdateError))
+    }
 }
diff --git a/src/main.rs b/src/main.rs
index 2deee7e..53ab607 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1356,6 +1356,8 @@ fn main() -> std::io::Result<()> {
             .service(ai::get_all_insights_handler)
             .service(ai::get_available_models_handler)
             .service(ai::get_openrouter_models_handler)
+            .service(ai::chat_turn_handler)
+            .service(ai::chat_history_handler)
             .service(ai::rate_insight_handler)
             .service(ai::export_training_data_handler)
             .service(libraries::list_libraries)
diff --git a/src/state.rs b/src/state.rs
index dd8628a..8e13d28 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,3 +1,4 @@
+use crate::ai::insight_chat::{ChatLockMap, InsightChatService};
 use crate::ai::openrouter::OpenRouterClient;
 use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient};
 use crate::database::{
@@ -44,6 +45,8 @@ pub struct AppState {
     pub openrouter_allowed_models: Vec<String>,
     pub sms_client: SmsApiClient,
     pub insight_generator: InsightGenerator,
+    /// Chat continuation service. Hold an Arc so handlers can clone cheaply.
+    pub insight_chat: Arc<InsightChatService>,
 }
 
 impl AppState {
@@ -76,6 +79,7 @@ impl AppState {
         openrouter_allowed_models: Vec<String>,
         sms_client: SmsApiClient,
         insight_generator: InsightGenerator,
+        insight_chat: Arc<InsightChatService>,
         preview_dao: Arc<Mutex<Box<dyn PreviewDao>>>,
     ) -> Self {
         assert!(
@@ -109,6 +113,7 @@ impl AppState {
             openrouter_allowed_models,
             sms_client,
             insight_generator,
+            insight_chat,
         }
     }
 
@@ -199,6 +204,18 @@ impl Default for AppState {
             libraries_vec.clone(),
         );
 
+        // Chat continuation reuses the generator for tool dispatch + image
+        // loading. The lock map starts empty and grows lazily per file.
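+        // Keys are (library_id, normalized rel_path); each value is the
+        // per-insight async mutex that chat_turn serialises on.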
+        let chat_locks: ChatLockMap =
+            Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new()));
+        let insight_chat = Arc::new(InsightChatService::new(
+            Arc::new(insight_generator.clone()),
+            ollama.clone(),
+            openrouter.clone(),
+            insight_dao.clone(),
+            chat_locks,
+        ));
+
         // Ensure preview clips directory exists
         let preview_clips_path =
             env::var("PREVIEW_CLIPS_DIRECTORY").unwrap_or_else(|_| "preview_clips".to_string());
@@ -218,6 +235,7 @@ impl Default for AppState {
             openrouter_allowed_models,
             sms_client,
             insight_generator,
+            insight_chat,
             preview_dao,
         )
     }
@@ -320,6 +338,16 @@ impl AppState {
             vec![test_lib],
         );
 
+        let chat_locks: ChatLockMap =
+            Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new()));
+        let insight_chat = Arc::new(InsightChatService::new(
+            Arc::new(insight_generator.clone()),
+            ollama.clone(),
+            None,
+            insight_dao.clone(),
+            chat_locks,
+        ));
+
         // Initialize test preview DAO
         let preview_dao: Arc<Mutex<Box<dyn PreviewDao>>> =
             Arc::new(Mutex::new(Box::new(SqlitePreviewDao::new())));
@@ -343,6 +371,7 @@ impl AppState {
             Vec::new(),
             sms_client,
             insight_generator,
+            insight_chat,
             preview_dao,
         )
     }