diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs
index b07b157..9c664b6 100644
--- a/src/ai/handlers.rs
+++ b/src/ai/handlers.rs
@@ -706,6 +706,63 @@ pub struct RenderedHistoryMessage {
     pub is_initial: bool,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct ChatRewindHttpRequest {
+    pub file_path: String,
+    #[serde(default)]
+    pub library: Option<String>,
+    /// 0-based index into the rendered transcript. The message at this
+    /// index, and everything after it, is discarded. Must be > 0 — the
+    /// initial user message is protected.
+    pub discard_from_rendered_index: usize,
+}
+
+/// POST /insights/chat/rewind — truncate the stored conversation so the
+/// rendered message at `discard_from_rendered_index` (and everything after)
+/// is removed. Use when a user wants to retry a turn with a different
+/// prompt without prior replies poisoning context.
+#[post("/insights/chat/rewind")]
+pub async fn chat_rewind_handler(
+    _claims: Claims,
+    request: web::Json<ChatRewindHttpRequest>,
+    app_state: web::Data<AppState>,
+) -> impl Responder {
+    let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
+        Ok(Some(lib)) => lib,
+        Ok(None) => app_state.primary_library(),
+        Err(e) => {
+            return HttpResponse::BadRequest().json(serde_json::json!({
+                "error": format!("invalid library: {}", e)
+            }));
+        }
+    };
+
+    match app_state
+        .insight_chat
+        .rewind_history(
+            library.id,
+            &request.file_path,
+            request.discard_from_rendered_index,
+        )
+        .await
+    {
+        Ok(()) => HttpResponse::Ok().json(serde_json::json!({ "success": true })),
+        Err(e) => {
+            let msg = format!("{}", e);
+            log::error!("Chat rewind failed: {}", msg);
+            if msg.contains("no insight found") {
+                HttpResponse::NotFound().json(serde_json::json!({ "error": msg }))
+            } else if msg.contains("no chat history") {
+                HttpResponse::Conflict().json(serde_json::json!({ "error": msg }))
+            } else if msg.contains("cannot discard the initial") || msg.contains("out of range") {
+                HttpResponse::BadRequest().json(serde_json::json!({ "error": msg }))
+            } else {
+                HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg }))
+            }
+        }
+    }
+}
+
 /// GET /insights/chat/history — return the rendered transcript for a photo.
 #[get("/insights/chat/history")]
 pub async fn chat_history_handler(
diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs
index 7e5422c..41411d3 100644
--- a/src/ai/insight_chat.rs
+++ b/src/ai/insight_chat.rs
@@ -479,6 +479,109 @@ impl InsightChatService {
             model_used,
         })
     }
+
+    /// Truncate the stored conversation so the rendered message at
+    /// `discard_from_rendered_index` (and everything after it — including
+    /// the tool-call scaffolding that produced a discarded assistant reply)
+    /// is removed. The initial user turn cannot be discarded; attempting to
+    /// do so returns an error.
+    ///
+    /// Holds the per-file chat mutex so it serialises with `chat_turn`.
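+    ///
+    /// Illustrative call (hypothetical path and receiver; assumes the
+    /// photo already has stored chat history):
+    ///
+    /// ```ignore
+    /// // Drop the second assistant reply (rendered index 3) and
+    /// // everything after it, so the user can retry that turn.
+    /// chat.rewind_history(library_id, "2023/beach/img_0042.jpg", 3).await?;
+    /// ```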
+    pub async fn rewind_history(
+        &self,
+        library_id: i32,
+        file_path: &str,
+        discard_from_rendered_index: usize,
+    ) -> Result<()> {
+        if discard_from_rendered_index == 0 {
+            bail!("cannot discard the initial user message");
+        }
+        let normalized = normalize_path(file_path);
+
+        let lock_key = (library_id, normalized.clone());
+        let entry_lock = {
+            let mut locks = self.chat_locks.lock().await;
+            locks
+                .entry(lock_key.clone())
+                .or_insert_with(|| Arc::new(TokioMutex::new(())))
+                .clone()
+        };
+        let _guard = entry_lock.lock().await;
+
+        let insight = {
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            dao.get_insight(&cx, &normalized)
+                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
+                .ok_or_else(|| anyhow!("no insight found for path"))?
+        };
+        let raw_history = insight
+            .training_messages
+            .as_ref()
+            .ok_or_else(|| anyhow!("insight has no chat history"))?;
+        let messages: Vec<ChatMessage> = serde_json::from_str(raw_history)
+            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
+
+        let cut_at = find_raw_cut(&messages, discard_from_rendered_index)
+            .ok_or_else(|| anyhow!("discard_from_rendered_index out of range"))?;
+
+        let truncated = &messages[..cut_at];
+        let json = serde_json::to_string(truncated)
+            .map_err(|e| anyhow!("failed to serialize truncated history: {}", e))?;
+
+        let cx = opentelemetry::Context::new();
+        let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+        dao.update_training_messages(&cx, library_id, &normalized, &json)
+            .map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
+        Ok(())
+    }
+}
+
+/// Is this raw message visible in the rendered transcript? Must match
+/// `load_history`'s filter exactly — `find_raw_cut` depends on it to map
+/// rendered indices back to raw positions.
+fn is_rendered(m: &ChatMessage) -> bool {
+    match m.role.as_str() {
+        "user" => true,
+        "assistant" => {
+            let has_tool_calls = m
+                .tool_calls
+                .as_ref()
+                .map(|c| !c.is_empty())
+                .unwrap_or(false);
+            !(has_tool_calls && m.content.trim().is_empty())
+        }
+        _ => false,
+    }
+}
+
+/// Given a rendered index to start discarding from, find the raw index at
+/// which to truncate. The cut position is the raw length after all prior
+/// rendered messages — which also strips any tool-call scaffolding that
+/// immediately precedes the discarded rendered message. Returns `None` if
+/// `discard_from_rendered_index` is at or past the end of the rendered view.
+pub(crate) fn find_raw_cut(
+    messages: &[ChatMessage],
+    discard_from_rendered_index: usize,
+) -> Option<usize> {
+    let mut rendered_count = 0usize;
+    let mut last_kept_raw_end = 0usize;
+    for (i, m) in messages.iter().enumerate() {
+        if !is_rendered(m) {
+            continue;
+        }
+        if rendered_count == discard_from_rendered_index {
+            return Some(last_kept_raw_end);
+        }
+        rendered_count += 1;
+        last_kept_raw_end = i + 1;
+    }
+    // Discarding at or past the end of the rendered view is a no-op; we
+    // surface it as "nothing to cut" rather than silent success.
+    None
+}
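+
+// Worked example of the rendered-to-raw mapping (hypothetical transcript,
+// roles only; `is_rendered` hides the tool-call scaffolding):
+//
+//   raw:      0:system  1:user  2:assistant(tool_calls)  3:tool  4:assistant
+//   rendered:           0:user                                   1:assistant
+//
+// find_raw_cut(&msgs, 1) == Some(2): keep `system` and `user`, drop the
+// tool-call scaffolding along with the reply it produced.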
 
 /// Read AGENTIC_CHAT_MAX_ITERATIONS once per call. Cheap; keeps the code
@@ -637,4 +740,46 @@ mod tests {
         let dropped = apply_context_budget(&mut msgs, 1);
         assert!(!dropped);
     }
+
+    #[test]
+    fn rewind_strips_assistant_and_tool_scaffolding() {
+        // Rendered: [user1, asst1, user2, asst2] → cut at rendered index 3
+        // (the final asst2) should drop the tool-call scaffolding + asst2,
+        // leaving raw up through user2.
+        let msgs = vec![
+            ChatMessage::system("sys"),
+            ChatMessage::user("q1"),
+            assistant_text("a1"),
+            ChatMessage::user("q2"),
+            assistant_with_tool_call("lookup"),
+            ChatMessage::tool_result("data"),
+            assistant_text("a2 final"),
+        ];
+        let cut = find_raw_cut(&msgs, 3).expect("cut found");
+        // raw[0..cut] should end at user("q2") — indices 0..=3.
+        assert_eq!(cut, 4);
+        assert_eq!(msgs[cut - 1].role, "user");
+        assert_eq!(msgs[cut - 1].content, "q2");
+    }
+
+    #[test]
+    fn rewind_at_second_rendered_cuts_after_first_user() {
+        // Rendered index 1 = the first assistant reply → dropping it should
+        // leave just the initial user message.
+        let msgs = vec![
+            ChatMessage::system("s"),
+            ChatMessage::user("q1"),
+            assistant_with_tool_call("tool"),
+            ChatMessage::tool_result("r"),
+            assistant_text("a1"),
+        ];
+        let cut = find_raw_cut(&msgs, 1).expect("cut found");
+        assert_eq!(cut, 2); // sys + user("q1")
+    }
+
+    #[test]
+    fn rewind_beyond_range_returns_none() {
+        let msgs = vec![ChatMessage::user("q1"), assistant_text("a1")];
+        assert!(find_raw_cut(&msgs, 5).is_none());
+    }
 }
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index 93735c0..3c58b2a 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -11,10 +11,10 @@ pub mod sms_client;
 #[allow(unused_imports)]
 pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
 pub use handlers::{
-    chat_history_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler,
-    generate_agentic_insight_handler, generate_insight_handler, get_all_insights_handler,
-    get_available_models_handler, get_insight_handler, get_openrouter_models_handler,
-    rate_insight_handler,
+    chat_history_handler, chat_rewind_handler, chat_turn_handler, delete_insight_handler,
+    export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler,
+    get_all_insights_handler, get_available_models_handler, get_insight_handler,
+    get_openrouter_models_handler, rate_insight_handler,
 };
 pub use insight_generator::InsightGenerator;
 #[allow(unused_imports)]
diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs
index 2cc2cfa..8c487c5 100644
--- a/src/ai/ollama.rs
+++ b/src/ai/ollama.rs
@@ -691,6 +691,17 @@ Analyze the image and use specific details from both the visual content and the
         .await
         .with_context(|| "Failed to parse Ollama chat response")?;
 
+    // Log performance counters returned by Ollama. Durations are
+    // reported in nanoseconds; we render ms + tokens/sec for skim-ability
+    // in the server log. Missing fields are left off the line rather
+    // than printed as `None`.
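+    // An illustrative line (hypothetical numbers):
+    //   Ollama chat metrics — prompt=512 tok (842 ms, 608.1 tok/s), gen=96 tok (3210 ms, 29.9 tok/s)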
+    log_chat_metrics(
+        chat_response.prompt_eval_count,
+        chat_response.prompt_eval_duration,
+        chat_response.eval_count,
+        chat_response.eval_duration,
+    );
+
     Ok((
         chat_response.message,
         chat_response.prompt_eval_count,
@@ -915,8 +926,14 @@ struct OllamaChatResponse {
     done_reason: String,
     #[serde(default)]
     prompt_eval_count: Option<u64>,
+    /// Nanoseconds spent evaluating the prompt (context ingestion).
+    #[serde(default)]
+    prompt_eval_duration: Option<u64>,
     #[serde(default)]
     eval_count: Option<u64>,
+    /// Nanoseconds spent generating the response tokens.
+    #[serde(default)]
+    eval_duration: Option<u64>,
 }
 
 #[derive(Deserialize)]
@@ -924,6 +941,52 @@ struct OllamaResponse {
     response: String,
 }
 
+fn log_chat_metrics(
+    prompt_eval_count: Option<u64>,
+    prompt_eval_duration_ns: Option<u64>,
+    eval_count: Option<u64>,
+    eval_duration_ns: Option<u64>,
+) {
+    // Compute tokens/sec when both count and duration are present.
+    fn tokens_per_sec(count: Option<u64>, duration_ns: Option<u64>) -> Option<f64> {
+        match (count, duration_ns) {
+            (Some(c), Some(d)) if c > 0 && d > 0 => {
+                Some((c as f64) * 1_000_000_000.0 / (d as f64))
+            }
+            _ => None,
+        }
+    }
+    let prompt_ms = prompt_eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0);
+    let eval_ms = eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0);
+    let prompt_tps = tokens_per_sec(prompt_eval_count, prompt_eval_duration_ns);
+    let eval_tps = tokens_per_sec(eval_count, eval_duration_ns);
+
+    let mut parts: Vec<String> = Vec::new();
+    if let Some(c) = prompt_eval_count {
+        let mut s = format!("prompt={} tok", c);
+        if let Some(ms) = prompt_ms {
+            s.push_str(&format!(" ({:.0} ms", ms));
+            if let Some(tps) = prompt_tps {
+                s.push_str(&format!(", {:.1} tok/s", tps));
+            }
+            s.push(')');
+        }
+        parts.push(s);
+    }
+    if let Some(c) = eval_count {
+        let mut s = format!("gen={} tok", c);
+        if let Some(ms) = eval_ms {
+            s.push_str(&format!(" ({:.0} ms", ms));
+            if let Some(tps) = eval_tps {
+                s.push_str(&format!(", {:.1} tok/s", tps));
+            }
+            s.push(')');
+        }
+        parts.push(s);
+    }
+    if !parts.is_empty() {
+        log::info!("Ollama chat metrics — {}", parts.join(", "));
+    }
+}
+
 #[derive(Deserialize)]
 struct OllamaTagsResponse {
     models: Vec<OllamaModel>,
diff --git a/src/main.rs b/src/main.rs
index 53ab607..8be397f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1358,6 +1358,7 @@ fn main() -> std::io::Result<()> {
             .service(ai::get_openrouter_models_handler)
             .service(ai::chat_turn_handler)
             .service(ai::chat_history_handler)
+            .service(ai::chat_rewind_handler)
             .service(ai::rate_insight_handler)
             .service(ai::export_training_data_handler)
             .service(libraries::list_libraries)