From 962f7bf05c47e5cccf9fea498dbbb1153606c92d Mon Sep 17 00:00:00 2001 From: Cameron Cordes Date: Fri, 29 May 2026 19:50:25 -0400 Subject: [PATCH] Add reconnectable async chat-turn flow with in-memory TurnRegistry Replace the one-shot SSE chat stream with an async dispatch + reconnectable replay flow so the mobile client survives backgrounding, network blips, and OS-killed sockets without losing an in-flight agentic turn. - TurnRegistry/TurnEntry: in-memory per-turn event buffer (cap 500, front eviction) shared by the agentic loop (writer) and SSE replay readers. ReplayOutcome + replay_from/next_batch distinguish Events/CaughtUp/Gone; next_batch registers the Notify before reading state (no lost wakeup) and drains every buffered event before signaling terminal, so the final Done/Error is never dropped and the stream closes cleanly. - Endpoints: POST /insights/chat/turn (202 + turn_id), GET /insights/chat/turn/{id} (SSE replay, ?skip_before= resume, per-event seq, 410 on eviction), DELETE /insights/chat/turn/{id} (real task abort + cooperative is_running() check at each loop boundary). - Cancellation actually stops the task (AbortHandle stored on the entry) and emits a Done{cancelled:true}; callers skip persistence on cancel. - Background sweeper drops stale turns; interval clamped to <=300s. - OpenTelemetry spans: ai.chat.turn.execute/replay/cancel. - Legacy POST /insights/chat/stream path preserved unchanged. Tests: registry coverage for terminal delivery (race guard), waiting, Gone, abort, eviction; handler integration tests for 404/410, skip_before, seq stamping, completed replay, and cancel. Co-Authored-By: Claude Opus 4.8 (1M context) --- Cargo.lock | 3 + Cargo.toml | 1 + src/ai/handlers.rs | 514 ++++++++++++++++++++++++++- src/ai/insight_chat.rs | 638 ++++++++++++++++++++++++++++++++++ src/ai/mod.rs | 7 +- src/ai/turn_registry.rs | 748 ++++++++++++++++++++++++++++++++++++++++ src/main.rs | 25 ++ src/state.rs | 27 +- 8 files changed, 1946 insertions(+), 17 deletions(-) create mode 100644 src/ai/turn_registry.rs diff --git a/Cargo.lock b/Cargo.lock index 5d3e4ce..d35048c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2104,6 +2104,7 @@ dependencies = [ "tokio", "tokio-util", "urlencoding", + "uuid", "walkdir", "zerocopy", ] @@ -4391,7 +4392,9 @@ version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ + "getrandom 0.4.2", "js-sys", + "serde_core", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 7324001..6807778 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ image_hasher = "3.0" bk-tree = "0.5" async-trait = "0.1" indicatif = "0.17" +uuid = { version = "1.10", features = ["v4", "serde"] } # Windows lacks system sqlite3, so re-enable the bundled C build there. # Linux/macOS use the system library (faster builds, smaller binary). diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs index bb599c2..9fbe6b7 100644 --- a/src/ai/handlers.rs +++ b/src/ai/handlers.rs @@ -1,4 +1,5 @@ use actix_web::{HttpRequest, HttpResponse, Responder, delete, get, post, web}; +use futures::StreamExt; use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, Tracer}; use serde::{Deserialize, Serialize}; @@ -1433,7 +1434,26 @@ pub async fn chat_stream_handler( } fn render_sse_frame(ev: &ChatStreamEvent) -> String { - let (event_name, payload) = match ev { + let (event_name, payload) = sse_event_payload(ev); + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", event_name, data) +} + +/// Like `render_sse_frame`, but stamps the event's absolute sequence number +/// (`seq`) into the payload so reconnecting replay clients can compute +/// `skip_before` precisely. `seq` is distinct from the tool-pairing `index` +/// already carried by `tool_call`/`tool_result`. +fn render_indexed_frame(ev: &ChatStreamEvent, seq: u32) -> String { + let (event_name, mut payload) = sse_event_payload(ev); + if let serde_json::Value::Object(map) = &mut payload { + map.insert("seq".to_string(), serde_json::json!(seq)); + } + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", event_name, data) +} + +fn sse_event_payload(ev: &ChatStreamEvent) -> (&'static str, serde_json::Value) { + match ev { ChatStreamEvent::IterationStart { n, max } => { ("iteration_start", serde_json::json!({ "n": n, "max": max })) } @@ -1471,6 +1491,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { amended_insight_id, backend_used, model_used, + cancelled, } => ( "done", serde_json::json!({ @@ -1483,6 +1504,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { "amended_insight_id": amended_insight_id, "backend": backend_used, "model": model_used, + "cancelled": cancelled, }), ), // Apollo's frontend SSE consumer (and its free-chat backend, which @@ -1491,7 +1513,491 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String { // "no insight found for path") was silently dropped, leaving an // empty assistant bubble with no clue why the turn died. ChatStreamEvent::Error(msg) => ("error_message", serde_json::json!({ "message": msg })), - }; - let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); - format!("event: {}\ndata: {}\n\n", event_name, data) + } +} + +/// POST /insights/chat/turn — async turn dispatch. Returns turn_id immediately, +/// client then polls GET /insights/chat/turn/{turn_id} for SSE replay. +#[post("/insights/chat/turn")] +pub async fn turn_async_handler( + http_request: HttpRequest, + claims: Claims, + request: web::Json, + app_state: web::Data, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("http.insights.chat_turn_async", &parent_context); + span.set_attribute(KeyValue::new("file_path", request.file_path.clone())); + + let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) { + Ok(Some(lib)) => lib, + Ok(None) => app_state.primary_library(), + Err(e) => { + return HttpResponse::BadRequest().json(serde_json::json!({ + "error": format!("invalid library: {}", e) + })); + } + }; + + let user_id = claims.sub.parse::().unwrap_or(1); + + let chat_req = ChatTurnRequest { + library_id: library.id, + user_id, + file_path: request.file_path.clone(), + user_message: request.user_message.clone(), + model: request.model.clone(), + backend: request.backend.clone(), + num_ctx: request.num_ctx, + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + min_p: request.min_p, + max_iterations: request.max_iterations, + system_prompt: request.system_prompt.clone(), + persona_id: request.persona_id.clone(), + amend: request.amend, + regenerate: request.regenerate, + }; + + let service = app_state.insight_chat.clone(); + let registry = app_state.turn_registry.clone(); + + let turn_id = service.chat_turn_async(registry, chat_req).await; + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + span.set_status(Status::Ok); + HttpResponse::Accepted().json(serde_json::json!({ + "turn_id": turn_id, + "status": "running" + })) +} + +/// Query params for the SSE replay stream. +#[derive(Debug, Deserialize)] +pub struct ReplayQuery { + /// Replay events from this absolute sequence number (`seq`) onward. + /// Absent or 0 replays from the beginning. On reconnect the client sends + /// the `seq` of the last event it applied, plus one. + pub skip_before: Option, +} + +/// GET /insights/chat/turn/{turn_id} — SSE replay stream. +#[get("/insights/chat/turn/{turn_id}")] +pub async fn turn_replay_handler( + http_request: HttpRequest, + path: web::Path, + query: web::Query, + app_state: web::Data, +) -> HttpResponse { + use crate::ai::turn_registry::ReplayOutcome; + + let turn_id = path.into_inner(); + let skip_before = query.skip_before.unwrap_or(0); + + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("ai.chat.turn.replay", &parent_context); + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + span.set_attribute(KeyValue::new("skip_before", skip_before as i64)); + + let registry = app_state.turn_registry.clone(); + let entry = match registry.get(&turn_id).await { + Some(e) => e, + None => { + span.set_status(Status::error("turn not found")); + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("turn {} not found", turn_id) + })); + } + }; + + let info = entry.info().await; + span.set_attribute(KeyValue::new("status", info.status.as_str())); + span.set_attribute(KeyValue::new( + "event_count", + info.total_events_pushed as i64, + )); + let turn_info_frame = render_turn_info_frame(&info); + + // Initial buffered batch: events produced before this connection attached. + // Stamp each frame with its absolute `seq` so the client can track + // `skip_before` precisely across reconnects. + let (initial_frames, start_skip) = match entry.replay_from(skip_before).await { + ReplayOutcome::Gone => { + span.set_status(Status::error("buffer evicted")); + return HttpResponse::Gone().json(serde_json::json!({ + "error": "turn history has expired (buffer evicted)" + })); + } + ReplayOutcome::CaughtUp { next_skip } => (Vec::new(), next_skip), + ReplayOutcome::Events { events, next_skip } => { + let frames: Vec = events + .into_iter() + .enumerate() + .map(|(i, ev)| { + actix_web::web::Bytes::from(render_indexed_frame(&ev, skip_before + i as u32)) + }) + .collect(); + (frames, next_skip) + } + }; + + span.set_status(Status::Ok); + let running = entry.is_running(); + + // Head: the `turn_info` event followed by any already-buffered events. + let head = futures::stream::once(async move { + Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(turn_info_frame)) + }) + .chain(futures::stream::iter( + initial_frames.into_iter().map(Ok::<_, actix_web::Error>), + )); + + if !running { + // Completed turn: every event — including the terminal Done/Error — is + // already in the buffered batch above. Emit it and close. + return HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(head); + } + + // In-progress turn: after the head, wait for new events. `next_batch` + // drains every buffered event (including the terminal one) before it + // reports the turn finished, so the final Done/Error is never dropped; + // CaughtUp then closes the stream by returning None. + let tail = futures::stream::unfold( + ( + entry, + start_skip, + Vec::::new(), + false, + ), + |(entry, skip, pending, finished)| async move { + // Flush queued frames from a previous multi-event batch first. + if let Some((first, rest)) = pending.split_first() { + return Some((Ok(first.clone()), (entry, skip, rest.to_vec(), finished))); + } + if finished { + return None; + } + + match entry.next_batch(skip).await { + ReplayOutcome::Events { events, next_skip } => { + let frames: Vec = events + .into_iter() + .enumerate() + .map(|(i, ev)| { + actix_web::web::Bytes::from(render_indexed_frame(&ev, skip + i as u32)) + }) + .collect(); + // next_batch only returns Events for a non-empty batch. + let (first, rest) = frames.split_first().expect("non-empty batch"); + Some((Ok(first.clone()), (entry, next_skip, rest.to_vec(), false))) + } + // Terminal reached and fully drained — close the connection. + ReplayOutcome::CaughtUp { .. } => None, + ReplayOutcome::Gone => { + // Evicted mid-stream: emit one error frame, then close. + let gone = + actix_web::web::Bytes::from(render_sse_frame(&ChatStreamEvent::Error( + "turn history has expired (buffer evicted)".to_string(), + ))); + Some((Ok(gone), (entry, skip, Vec::new(), true))) + } + } + }, + ); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(head.chain(tail)) +} + +fn render_turn_info_frame(info: &crate::ai::turn_registry::TurnInfo) -> String { + let payload = serde_json::json!({ + "turn_id": info.turn_id, + "file_path": info.file_path, + "library_id": info.library_id, + "status": info.status.as_str(), + "total_events_pushed": info.total_events_pushed, + "buffered_count": info.buffered_count, + }); + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: turn_info\ndata: {}\n\n", data) +} + +/// DELETE /insights/chat/turn/{turn_id} — cancel a running turn. +#[delete("/insights/chat/turn/{turn_id}")] +pub async fn cancel_turn_handler( + http_request: HttpRequest, + path: web::Path, + app_state: web::Data, +) -> impl Responder { + let turn_id = path.into_inner(); + + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("ai.chat.turn.cancel", &parent_context); + span.set_attribute(KeyValue::new("turn_id", turn_id.clone())); + + let registry = app_state.turn_registry.clone(); + let entry = match registry.get(&turn_id).await { + Some(e) => e, + None => { + span.set_status(Status::error("turn not found")); + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("turn {} not found", turn_id) + })); + } + }; + + // Abort the spawned task so it stops producing events promptly. The loop + // also checks `is_running()` at each iteration boundary as a graceful + // backstop in case the abort lands between await points. + let aborted = entry.abort(); + span.set_attribute(KeyValue::new("aborted", aborted)); + + // Push the terminal event BEFORE flipping status: a replay reader treats a + // terminal status with no buffered tail as "closed", so the Done must be + // buffered first for in-progress connections to receive it. + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 0, + truncated: false, + prompt_tokens: None, + eval_tokens: None, + num_ctx: None, + amended_insight_id: None, + backend_used: "cancelled".to_string(), + model_used: "cancelled".to_string(), + cancelled: true, + }) + .await; + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Cancelled); + span.set_status(Status::Ok); + + HttpResponse::Ok().json(serde_json::json!({ + "cancelled": true + })) +} + +#[cfg(test)] +mod turn_replay_tests { + use super::{cancel_turn_handler, render_indexed_frame, turn_replay_handler}; + use crate::ai::insight_chat::ChatStreamEvent; + use crate::ai::turn_registry::{TurnEntry, TurnStatus}; + use crate::state::AppState; + use actix_web::test as actix_test; + use actix_web::{App, web::Data}; + use std::sync::Arc; + + /// Serialize `AppState::test_state()` construction across the parallel + /// tests in this module: each build opens ~10 DAO connections to the one + /// shared `DATABASE_URL` file, and doing several at once races the WAL + /// `journal_mode` switch into a spurious "database is locked". The test + /// bodies themselves still run in parallel; only the open is gated. + static DB_INIT: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + fn build_state() -> Data { + let _guard = DB_INIT.lock().unwrap_or_else(|p| p.into_inner()); + Data::new(AppState::test_state()) + } + + fn done(cancelled: bool) -> ChatStreamEvent { + ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 1, + truncated: false, + prompt_tokens: Some(10), + eval_tokens: Some(20), + num_ctx: None, + amended_insight_id: None, + backend_used: "local".into(), + model_used: "m".into(), + cancelled, + } + } + + /// Seed a completed turn (events + terminal Done) directly in the registry. + async fn seed_completed(state: &AppState, id: &str, text_events: usize) { + let entry = Arc::new(TurnEntry::new(id.into(), "/p.jpg".into(), 1)); + for i in 0..text_events { + entry + .push_event(ChatStreamEvent::TextDelta(format!("d{i}"))) + .await; + } + entry.push_event(done(false)).await; + entry.set_terminal_status(TurnStatus::Done); + state.turn_registry.insert(entry).await; + } + + #[test] + fn indexed_frame_stamps_seq_without_clobbering_tool_index() { + // tool_call carries its own pairing `index`; `seq` must be additive. + let frame = render_indexed_frame( + &ChatStreamEvent::ToolCall { + index: 3, + name: "geo".into(), + arguments: serde_json::json!({}), + }, + 42, + ); + assert!(frame.contains("event: tool_call")); + assert!(frame.contains("\"index\":3")); + assert!(frame.contains("\"seq\":42")); + } + + #[actix_rt::test] + async fn replay_unknown_turn_is_404() { + let state = build_state(); + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/nope") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 404); + } + + #[actix_rt::test] + async fn replay_completed_turn_emits_turn_info_and_done_with_seq() { + let state = build_state(); + seed_completed(&state, "t1", 2).await; + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t1") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + let body = String::from_utf8(actix_test::read_body(resp).await.to_vec()).unwrap(); + assert!(body.contains("event: turn_info")); + assert!(body.contains("event: text")); + assert!(body.contains("event: done")); + // Events are seq-stamped 0,1 (text) and 2 (done). + assert!(body.contains("\"seq\":0")); + assert!(body.contains("\"seq\":2")); + // Done payload carries the renamed token fields the client reads. + assert!(body.contains("\"prompt_tokens\":10")); + } + + #[actix_rt::test] + async fn replay_skip_before_query_skips_applied_events() { + let state = build_state(); + seed_completed(&state, "t2", 3).await; // seqs 0,1,2 text; 3 done + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t2?skip_before=2") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + let body = String::from_utf8(actix_test::read_body(resp).await.to_vec()).unwrap(); + // Only seq 2 (last text) and seq 3 (done) should be present. + assert!(body.contains("\"seq\":2")); + assert!(body.contains("\"seq\":3")); + assert!(!body.contains("\"seq\":0")); + assert!(!body.contains("\"seq\":1")); + } + + #[actix_rt::test] + async fn replay_evicted_index_is_410() { + let state = build_state(); + let entry = Arc::new(TurnEntry::new("t3".into(), "/p.jpg".into(), 1)); + // Push past the cap so the front is evicted and base advances. + for i in 0..600 { + entry + .push_event(ChatStreamEvent::TextDelta(format!("d{i}"))) + .await; + } + entry.set_terminal_status(TurnStatus::Done); + state.turn_registry.insert(entry).await; + + let app = actix_test::init_service( + App::new() + .service(turn_replay_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/insights/chat/turn/t3?skip_before=0") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 410); + } + + #[actix_rt::test] + async fn cancel_unknown_turn_is_404() { + let state = build_state(); + let app = actix_test::init_service( + App::new() + .service(cancel_turn_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::delete() + .uri("/insights/chat/turn/nope") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 404); + } + + #[actix_rt::test] + async fn cancel_running_turn_marks_cancelled_and_buffers_terminal() { + let state = build_state(); + let entry = Arc::new(TurnEntry::new("t4".into(), "/p.jpg".into(), 1)); + entry + .push_event(ChatStreamEvent::TextDelta("partial".into())) + .await; + state.turn_registry.insert(entry.clone()).await; + + let app = actix_test::init_service( + App::new() + .service(cancel_turn_handler) + .app_data(state.clone()), + ) + .await; + + let req = actix_test::TestRequest::delete() + .uri("/insights/chat/turn/t4") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + // Status flipped to Cancelled and a terminal Done(cancelled) buffered + // after the existing event, so a late replay reader still completes. + assert_eq!( + TurnStatus::from(entry.status.load(std::sync::atomic::Ordering::Relaxed)), + TurnStatus::Cancelled + ); + let info = entry.info().await; + assert_eq!(info.total_events_pushed, 2); + } } diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs index d2b153b..3622972 100644 --- a/src/ai/insight_chat.rs +++ b/src/ai/insight_chat.rs @@ -9,11 +9,14 @@ use tokio::sync::Mutex as TokioMutex; use crate::ai::backend::{BackendKind, ResolvedBackend, SamplingOverrides}; use crate::ai::insight_generator::InsightGenerator; use crate::ai::llm_client::{ChatMessage, LlmStreamEvent, Tool}; +use crate::ai::turn_registry::TurnEntry; +use crate::ai::turn_registry::TurnRegistry; use crate::database::InsightDao; use crate::database::models::InsertPhotoInsight; use crate::otel::global_tracer; use crate::utils::normalize_path; use futures::stream::{BoxStream, StreamExt}; +use uuid::Uuid; const DEFAULT_MAX_ITERATIONS: usize = 6; const DEFAULT_NUM_CTX: i32 = 8192; @@ -678,6 +681,626 @@ impl InsightChatService { Ok(rx) } + /// Async turn dispatch: creates a TurnEntry in the registry, spawns the + /// agentic loop on a Tokio task, and returns the turn_id immediately. + /// Events are buffered in the TurnEntry for SSE replay. + pub async fn chat_turn_async( + self: Arc, + registry: Arc, + req: ChatTurnRequest, + ) -> String { + let turn_id = Uuid::new_v4().to_string(); + let entry = Arc::new(TurnEntry::new( + turn_id.clone(), + req.file_path.clone(), + req.library_id, + )); + registry.insert(entry.clone()).await; + + let svc = self.clone(); + let entry_clone = entry.clone(); + let turn_id_for_span = turn_id.clone(); + let library_id = req.library_id; + let handle = tokio::spawn(async move { + // Span covering the whole spawned turn execution. Created here (not + // in the HTTP handler) because the dispatch span ends at the 202 + // response, long before this work runs. + let tracer = global_tracer(); + let mut span = tracer.start("ai.chat.turn.execute"); + span.set_attribute(KeyValue::new("turn_id", turn_id_for_span)); + span.set_attribute(KeyValue::new("library_id", library_id as i64)); + + let result = svc + .run_streaming_turn_with_entry(req, entry_clone.clone()) + .await; + if let Err(ref e) = result { + span.set_attribute(KeyValue::new("status", "error")); + span.set_status(Status::error(format!("{e}"))); + // Push the terminal event BEFORE flipping status: a replay + // reader treats a terminal status with no buffered tail as + // "closed", so the Error must be in the buffer first. + let _ = entry_clone + .push_event(ChatStreamEvent::Error(format!("{}", e))) + .await; + entry_clone.set_terminal_status(crate::ai::turn_registry::TurnStatus::Error); + } else { + span.set_attribute(KeyValue::new("status", "done")); + span.set_status(Status::Ok); + } + }); + + // Install the abort handle so DELETE can actually stop the task. + entry.set_abort_handle(handle.abort_handle()); + + turn_id + } + + /// Variant of `run_streaming_turn` that pushes events to a `TurnEntry` + /// buffer instead of an `mpsc::Sender`. + async fn run_streaming_turn_with_entry( + self: Arc, + req: ChatTurnRequest, + entry: Arc, + ) -> Result<()> { + if req.user_message.trim().is_empty() { + bail!("user_message must not be empty"); + } + if req.user_message.len() > 8192 { + bail!("user_message exceeds 8192 chars"); + } + let normalized = normalize_path(&req.file_path); + + let lock_key = (req.library_id, normalized.clone()); + let entry_lock = { + let mut locks = self.chat_locks.lock().await; + locks + .entry(lock_key.clone()) + .or_insert_with(|| Arc::new(TokioMutex::new(()))) + .clone() + }; + let _guard = entry_lock.lock().await; + + // Look up existing insight scoped to this turn's library_id. + let existing_insight = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.get_current_insight_for_library(&cx, req.library_id, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + }; + + if req.regenerate || existing_insight.is_none() { + return self + .run_bootstrap_streaming_with_entry(req, normalized, entry) + .await; + } + let insight = existing_insight.expect("just checked Some above"); + self.run_continuation_streaming_with_entry(req, normalized, insight, entry) + .await + } + + /// Continuation path with TurnEntry buffer. + async fn run_continuation_streaming_with_entry( + &self, + req: ChatTurnRequest, + normalized: String, + insight: crate::database::models::PhotoInsight, + entry: Arc, + ) -> Result<()> { + let active_persona = req + .persona_id + .clone() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| "default".to_string()); + let raw_history = insight.training_messages.as_ref().ok_or_else(|| { + anyhow!("insight has no chat history; regenerate this insight in agentic mode") + })?; + let mut messages: Vec = serde_json::from_str(raw_history) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + let stored_backend = insight.backend.clone(); + let effective_backend = req + .backend + .as_deref() + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| stored_backend.clone()); + let kind = BackendKind::parse(&effective_backend)?; + validate_cross_replay(&stored_backend, kind.as_str())?; + + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + + let stored_model = insight.model_version.clone(); + let overrides = SamplingOverrides { + model: req + .model + .clone() + .or_else(|| Some(stored_model.clone())) + .filter(|m| !m.is_empty()), + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + }; + let backend = self.generator.resolve_backend(kind, &overrides).await?; + let model_used = backend.model().to_string(); + + let local_first_user_has_image = messages + .iter() + .find(|m| m.role == "user") + .and_then(|m| m.images.as_ref()) + .map(|imgs| !imgs.is_empty()) + .unwrap_or(false); + let offer_describe_tool = backend.images_inline && local_first_user_has_image; + let gate_opts = self.generator.current_gate_opts_for_persona( + offer_describe_tool, + Some((req.user_id, &active_persona)), + ); + let tools = InsightGenerator::build_tool_definitions(gate_opts); + + let image_base64: Option = if offer_describe_tool { + self.generator.load_image_as_base64(&normalized).ok() + } else { + None + }; + + let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize) + .saturating_sub(RESPONSE_HEADROOM_TOKENS); + let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN); + let truncated = apply_context_budget(&mut messages, budget_bytes); + if truncated { + let _ = entry.push_event(ChatStreamEvent::Truncated).await; + } + + messages.push(ChatMessage::user(req.user_message.clone())); + + let override_stash = + apply_system_prompt_override(&mut messages, req.system_prompt.as_deref()); + let original_system_content = annotate_system_with_budget(&mut messages, max_iterations); + + let outcome = self + .run_streaming_agentic_loop_with_entry( + &backend, + &mut messages, + tools, + &image_base64, + &normalized, + req.user_id, + &active_persona, + max_iterations, + &entry, + ) + .await?; + let AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled, + } = outcome; + + // Turn was cancelled mid-flight: the DELETE handler already pushed the + // terminal event and flipped status. Don't persist a partial turn or + // push a second terminal event. + if cancelled { + return Ok(()); + } + + restore_system_content(&mut messages, original_system_content); + + if !req.amend { + restore_system_prompt_override(&mut messages, override_stash); + } + + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + + let mut amended_insight_id: Option = None; + if req.amend { + let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); + let final_content = body; + + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: final_content.clone(), + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: kind.as_str().to_string(), + fewshot_source_ids: None, + content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, + }; + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let stored = dao + .store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?; + amended_insight_id = Some(stored.id); + } else { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let rows = dao + .update_training_messages(&cx, req.library_id, &normalized, &json) + .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + if rows == 0 { + log::warn!( + "update_training_messages (stream) updated 0 rows for {} (lib {}), \ + concurrent regenerate likely flipped is_current", + normalized, + req.library_id + ); + } + } + + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated, + prompt_tokens: last_prompt_eval_count, + eval_tokens: last_eval_count, + num_ctx: req.num_ctx, + amended_insight_id, + backend_used: kind.as_str().to_string(), + model_used, + cancelled: false, + }) + .await; + + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done); + Ok(()) + } + + /// Bootstrap path with TurnEntry buffer. + async fn run_bootstrap_streaming_with_entry( + &self, + req: ChatTurnRequest, + normalized: String, + entry: Arc, + ) -> Result<()> { + let active_persona = req + .persona_id + .clone() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| "default".to_string()); + let effective_backend = resolve_bootstrap_backend(req.backend.as_deref())?; + let kind = BackendKind::parse(&effective_backend)?; + + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + + let overrides = SamplingOverrides { + model: req.model.clone().filter(|m| !m.is_empty()), + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + }; + let backend = self.generator.resolve_backend(kind, &overrides).await?; + let model_used = backend.model().to_string(); + + let image_base64: Option = self.generator.load_image_as_base64(&normalized).ok(); + + let exif = self.generator.fetch_exif(&normalized); + let date_taken_str = resolve_date_taken_for_context(&exif, &normalized); + let gps = exif + .as_ref() + .and_then(|e| match (e.gps_latitude, e.gps_longitude) { + (Some(lat), Some(lon)) => Some((lat as f64, lon as f64)), + _ => None, + }); + + let visual_block = if !backend.images_inline { + match image_base64.as_deref() { + Some(b64) => match backend.local().describe_image(b64).await { + Ok(desc) => { + format!("Visual description (from local vision model):\n{}\n", desc) + } + Err(e) => { + log::warn!("{} bootstrap: describe_image failed: {}", kind.as_str(), e); + String::new() + } + }, + None => String::new(), + } + } else { + String::new() + }; + + let offer_describe_tool = backend.images_inline && image_base64.is_some(); + let gate_opts = self.generator.current_gate_opts_for_persona( + offer_describe_tool, + Some((req.user_id, &active_persona)), + ); + let tools = InsightGenerator::build_tool_definitions(gate_opts); + + let persona = resolve_bootstrap_system_prompt(req.system_prompt.as_deref()); + let system_content = build_bootstrap_system_message( + &persona, + &normalized, + date_taken_str.as_deref(), + gps, + &visual_block, + ); + let system_msg = ChatMessage::system(system_content); + let mut user_msg = ChatMessage::user(req.user_message.clone()); + if backend.images_inline + && let Some(ref img) = image_base64 + { + user_msg.images = Some(vec![img.clone()]); + } + let mut messages = vec![system_msg, user_msg]; + + let outcome = self + .run_streaming_agentic_loop_with_entry( + &backend, + &mut messages, + tools, + &image_base64, + &normalized, + req.user_id, + &active_persona, + max_iterations, + &entry, + ) + .await?; + let AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled, + } = outcome; + + // Turn was cancelled mid-flight: the DELETE handler already pushed the + // terminal event and flipped status. Don't persist a partial turn or + // push a second terminal event. + if cancelled { + return Ok(()); + } + + let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); + + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: body, + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: kind.as_str().to_string(), + fewshot_source_ids: None, + content_hash: None, + num_ctx: req.num_ctx, + temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + min_p: req.min_p, + system_prompt: req.system_prompt.clone(), + persona_id: req.persona_id.clone(), + prompt_eval_count: None, + eval_count: None, + }; + let stored = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store bootstrap insight: {:?}", e))? + }; + + let _ = entry + .push_event(ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated: false, + prompt_tokens: last_prompt_eval_count, + eval_tokens: last_eval_count, + num_ctx: req.num_ctx, + amended_insight_id: Some(stored.id), + backend_used: kind.as_str().to_string(), + model_used, + cancelled: false, + }) + .await; + + entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done); + Ok(()) + } + + /// Agentic loop variant that pushes events to a `TurnEntry` buffer. + async fn run_streaming_agentic_loop_with_entry( + &self, + backend: &ResolvedBackend, + messages: &mut Vec, + tools: Vec, + image_base64: &Option, + normalized: &str, + user_id: i32, + active_persona: &str, + max_iterations: usize, + entry: &Arc, + ) -> Result { + let mut tool_calls_made = 0usize; + let mut iterations_used = 0usize; + let mut last_prompt_eval_count: Option = None; + let mut last_eval_count: Option = None; + let mut final_content = String::new(); + + for iteration in 0..max_iterations { + // Cooperative cancellation: a DELETE flips status out of Running + // (and aborts this task). Check at the iteration boundary so an + // in-flight tool round finishes cleanly rather than mid-write. + if !entry.is_running() { + return Ok(AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled: true, + }); + } + + iterations_used = iteration + 1; + let _ = entry + .push_event(ChatStreamEvent::IterationStart { + n: iterations_used, + max: max_iterations, + }) + .await; + + let mut stream = backend + .chat() + .chat_with_tools_stream(messages.clone(), tools.clone()) + .await?; + + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let mut response = + final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?; + + if let Some(ref mut tcs) = response.tool_calls { + for tc in tcs.iter_mut() { + if !tc.function.arguments.is_object() { + tc.function.arguments = serde_json::Value::Object(Default::default()); + } + } + } + + messages.push(response.clone()); + + if let Some(ref tool_calls) = response.tool_calls + && !tool_calls.is_empty() + { + for tool_call in tool_calls { + tool_calls_made += 1; + let call_index = tool_calls_made - 1; + let _ = entry + .push_event(ChatStreamEvent::ToolCall { + index: call_index, + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }) + .await; + let cx = opentelemetry::Context::new(); + let result = self + .generator + .execute_tool( + &tool_call.function.name, + &tool_call.function.arguments, + backend, + image_base64, + normalized, + user_id, + active_persona, + &cx, + ) + .await; + let (result_preview, result_truncated) = truncate_tool_result(&result); + let _ = entry + .push_event(ChatStreamEvent::ToolResult { + index: call_index, + name: tool_call.function.name.clone(), + result: result_preview, + result_truncated, + }) + .await; + messages.push(ChatMessage::tool_result(result)); + } + continue; + } + + final_content = response.content; + break; + } + + // No-tools fallback + if final_content.is_empty() { + let synthetic_idx = messages.len(); + messages.push(ChatMessage::user( + "Please write your final answer now without calling any more tools.", + )); + let mut stream = backend + .chat() + .chat_with_tools_stream(messages.clone(), vec![]) + .await?; + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let final_response = + final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?; + final_content = final_response.content.clone(); + messages.push(final_response); + messages.remove(synthetic_idx); + } + + Ok(AgenticLoopOutcome { + tool_calls_made, + iterations_used, + last_prompt_eval_count, + last_eval_count, + final_content, + cancelled: false, + }) + } + async fn run_streaming_turn( self: Arc, req: ChatTurnRequest, @@ -836,6 +1459,8 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + // The mpsc (legacy) path has no cancellation channel. + cancelled: _, } = outcome; // Drop the per-turn iteration-budget note before persisting so it @@ -916,6 +1541,7 @@ impl InsightChatService { amended_insight_id, backend_used: kind.as_str().to_string(), model_used, + cancelled: false, }) .await; @@ -1052,6 +1678,8 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + // The mpsc (legacy) path has no cancellation channel. + cancelled: _, } = outcome; let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content); @@ -1101,6 +1729,7 @@ impl InsightChatService { amended_insight_id: Some(stored.id), backend_used: kind.as_str().to_string(), model_used, + cancelled: false, }) .await; @@ -1274,6 +1903,7 @@ impl InsightChatService { last_prompt_eval_count, last_eval_count, final_content, + cancelled: false, }) } } @@ -1402,6 +2032,10 @@ struct AgenticLoopOutcome { last_prompt_eval_count: Option, last_eval_count: Option, final_content: String, + /// True when the loop exited early because the turn was cancelled + /// (status flipped out of `Running`). Callers skip persistence and the + /// terminal `Done` push — the cancel handler owns the terminal event. + cancelled: bool, } /// Events emitted by `chat_turn_stream`. One stream per turn; ends after @@ -1456,6 +2090,10 @@ pub enum ChatStreamEvent { amended_insight_id: Option, backend_used: String, model_used: String, + /// True only for the synthetic terminal event emitted by the cancel + /// handler, so clients can distinguish a user-cancelled turn from a + /// natural completion. Always false on the normal success path. + cancelled: bool, }, /// Terminal failure event. No further events follow. Error(String), diff --git a/src/ai/mod.rs b/src/ai/mod.rs index c54b113..e9bec09 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -11,6 +11,7 @@ pub mod llm_client; pub mod ollama; pub mod openrouter; pub mod sms_client; +pub mod turn_registry; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library #[allow(unused_imports)] @@ -19,11 +20,11 @@ pub use daily_summary_job::{ generate_daily_summaries, strip_summary_boilerplate, }; pub use handlers::{ - cancel_generation_handler, chat_history_handler, chat_rewind_handler, chat_stream_handler, - chat_turn_handler, delete_insight_handler, export_training_data_handler, + cancel_generation_handler, cancel_turn_handler, chat_history_handler, chat_rewind_handler, + chat_stream_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler, generation_status_handler, get_all_insights_handler, get_available_models_handler, get_insight_handler, - get_openrouter_models_handler, rate_insight_handler, + get_openrouter_models_handler, rate_insight_handler, turn_async_handler, turn_replay_handler, }; pub use insight_generator::InsightGenerator; pub use llamacpp::LlamaCppClient; diff --git a/src/ai/turn_registry.rs b/src/ai/turn_registry.rs new file mode 100644 index 0000000..2a5d432 --- /dev/null +++ b/src/ai/turn_registry.rs @@ -0,0 +1,748 @@ +use crate::ai::insight_chat::ChatStreamEvent; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Instant; +use tokio::sync::{Mutex, Notify}; +use tokio::task::AbortHandle; + +/// Maximum number of events buffered per turn. Agentic turns typically +/// produce ~120 events; 500 provides 4× headroom. When exceeded, oldest +/// events are evicted from the front. +const MAX_BUFFERED_EVENTS: usize = 500; + +/// Turn status codes used by `TurnEntry::status`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TurnStatus { + Running = 0, + Done = 1, + Error = 2, + Cancelled = 3, +} + +impl From for TurnStatus { + fn from(v: u32) -> Self { + match v { + 0 => TurnStatus::Running, + 1 => TurnStatus::Done, + 2 => TurnStatus::Error, + 3 => TurnStatus::Cancelled, + _ => TurnStatus::Running, + } + } +} + +impl TurnStatus { + pub fn as_str(&self) -> &'static str { + match self { + TurnStatus::Running => "running", + TurnStatus::Done => "done", + TurnStatus::Error => "error", + TurnStatus::Cancelled => "cancelled", + } + } +} + +/// Shared metadata about a turn, read by the SSE replay handler to emit +/// the initial `turn_info` event and to decide whether to wait for new +/// events or close immediately. +#[derive(Debug, Clone)] +pub struct TurnInfo { + pub turn_id: String, + pub file_path: String, + pub library_id: i32, + pub status: TurnStatus, + pub total_events_pushed: u32, + pub buffered_count: u32, +} + +/// Result of reading events at or after an absolute `skip_before` index. +#[derive(Debug)] +pub enum ReplayOutcome { + /// New events are available. `next_skip` is the absolute index to pass + /// on the next read (i.e. one past the last event returned). + Events { + events: Vec, + next_skip: u32, + }, + /// The reader is caught up to the live edge — no events past `skip_before` + /// yet. `next_skip` is the current high-water mark. + CaughtUp { next_skip: u32 }, + /// `skip_before` points below the buffer's base index: the requested + /// events were evicted. Maps to HTTP 410 Gone. + Gone, +} + +/// Per-turn state shared between the agentic loop (writer) and all SSE +/// replay connections (readers). +pub struct TurnEntry { + pub turn_id: String, + pub file_path: String, + pub library_id: i32, + /// Shared event buffer — multiple SSE connections can read independently. + /// Each connection tracks its own `skip_before` offset. + events: Mutex>, + /// Monotonic counter: total events pushed (may exceed events.len() + /// due to eviction). Used for skip_before indexing. + total_events_pushed: AtomicU32, + /// The event index that this entry started with. Adjusts on eviction + /// so that `skip_before` stays absolute across connections. + base_index: AtomicU32, + pub status: AtomicU32, + /// Abort handle for the spawned agentic task, set once after spawn. + /// Behind a std `Mutex` because the entry is shared via `Arc` and the + /// handle is installed after the entry is already in the registry. + abort_handle: StdMutex>, + pub created_at: Instant, + notify: Arc, +} + +impl TurnEntry { + pub fn new(turn_id: String, file_path: String, library_id: i32) -> Self { + Self { + turn_id, + file_path, + library_id, + events: Mutex::new(Vec::new()), + total_events_pushed: AtomicU32::new(0), + base_index: AtomicU32::new(0), + status: AtomicU32::new(TurnStatus::Running as u32), + abort_handle: StdMutex::new(None), + created_at: Instant::now(), + notify: Arc::new(Notify::new()), + } + } + + /// Install the abort handle for the spawned agentic task. Called once, + /// right after the task is spawned. + pub fn set_abort_handle(&self, handle: AbortHandle) { + *self.abort_handle.lock().expect("abort_handle poisoned") = Some(handle); + } + + /// Abort the spawned agentic task, if a handle was installed. Returns + /// `true` if a task was aborted. + pub fn abort(&self) -> bool { + if let Some(handle) = self + .abort_handle + .lock() + .expect("abort_handle poisoned") + .take() + { + handle.abort(); + true + } else { + false + } + } + + /// Push an event into the buffer. Evicts oldest events if the buffer + /// exceeds `MAX_BUFFERED_EVENTS`. Notifies all waiting SSE connections. + pub async fn push_event(&self, event: ChatStreamEvent) { + { + let mut events = self.events.lock().await; + + // Evict oldest events if we've hit the cap. + if events.len() >= MAX_BUFFERED_EVENTS { + // Drop the oldest event to make room and advance the base + // index so skip_before stays absolute across connections. + events.remove(0); + self.base_index.fetch_add(1, Ordering::Relaxed); + } + + events.push(event); + // Increment while holding the buffer lock so the counter stays in + // lock-step with the buffer even if multiple writers ever exist. + self.total_events_pushed.fetch_add(1, Ordering::Relaxed); + } + + self.notify.notify_waiters(); + } + + /// Get a snapshot of turn metadata for the `turn_info` SSE event. + pub async fn info(&self) -> TurnInfo { + let events = self.events.lock().await; + let buffered = events.len() as u32; + let total = self.total_events_pushed.load(Ordering::Relaxed); + drop(events); + + TurnInfo { + turn_id: self.turn_id.clone(), + file_path: self.file_path.clone(), + library_id: self.library_id, + status: self.status.load(Ordering::Relaxed).into(), + total_events_pushed: total, + buffered_count: buffered, + } + } + + /// Set the terminal status and notify all waiters. + pub fn set_terminal_status(&self, status: TurnStatus) { + self.status.store(status as u32, Ordering::Relaxed); + self.notify.notify_waiters(); + } + + /// Read buffered events at or after absolute index `skip_before` without + /// waiting. Distinguishes "evicted" (Gone) from "caught up" (no new + /// events yet) — the previous boolean/`Option` API conflated the two. + pub async fn replay_from(&self, skip_before: u32) -> ReplayOutcome { + let events = self.events.lock().await; + let base = self.base_index.load(Ordering::Relaxed); + + // The buffer holds absolute indices [base, base + len). A request + // below `base` asked for events that have been evicted. + if skip_before < base { + return ReplayOutcome::Gone; + } + + let offset = (skip_before - base) as usize; + let next_skip = base + events.len() as u32; + if offset >= events.len() { + // Caught up to (or past) the live edge — nothing new yet. + return ReplayOutcome::CaughtUp { next_skip }; + } + + ReplayOutcome::Events { + events: events[offset..].to_vec(), + next_skip, + } + } + + /// Wait for the next batch of events past `skip_before`, the turn to + /// finish, or eviction. Returns: + /// - `Events` when new events are available (drained before any terminal + /// signal so the final `Done`/`Error` is never dropped), + /// - `CaughtUp` only when the turn has reached a terminal status and the + /// reader is fully drained (the caller should close the stream), + /// - `Gone` when `skip_before` points into evicted territory. + pub async fn next_batch(&self, skip_before: u32) -> ReplayOutcome { + loop { + // Register interest BEFORE inspecting state so a push/terminal that + // races between our read and our await can't be lost (Notify's + // `notify_waiters` does not store a permit). + let notified = self.notify.notified(); + tokio::pin!(notified); + notified.as_mut().enable(); + + match self.replay_from(skip_before).await { + ReplayOutcome::CaughtUp { next_skip } => { + // No new events. If the turn is finished, every event + // (including the terminal one) has already been drained + // above on a prior call, so signal the caller to close. + if !self.is_running() { + return ReplayOutcome::CaughtUp { next_skip }; + } + // Still running — wait for the next push or terminal. + } + other => return other, // Events or Gone + } + + notified.await; + } + } + + /// Check if this turn is still running. + pub fn is_running(&self) -> bool { + self.status.load(Ordering::Relaxed) == TurnStatus::Running as u32 + } +} + +/// In-memory registry of all active chat turns. Injected into `AppState` +/// and shared across all handlers. +pub struct TurnRegistry { + entries: Mutex>>, + timeout_secs: u64, +} + +impl TurnRegistry { + pub fn new(timeout_secs: u64) -> Self { + Self { + entries: Mutex::new(HashMap::new()), + timeout_secs, + } + } + + /// Returns the cleanup timeout in seconds. + pub fn timeout_secs(&self) -> u64 { + self.timeout_secs + } + + /// Insert a new turn entry. Returns the turn_id. + pub async fn insert(&self, entry: Arc) -> String { + let turn_id = entry.turn_id.clone(); + let mut entries = self.entries.lock().await; + entries.insert(turn_id.clone(), entry); + turn_id + } + + /// Look up a turn by id. Returns None if not found or expired. + pub async fn get(&self, turn_id: &str) -> Option> { + let entries = self.entries.lock().await; + entries.get(turn_id).cloned() + } + + /// Clean up stale entries older than the timeout. Returns the count of + /// entries removed. + pub async fn cleanup_stale(&self) -> usize { + let mut entries = self.entries.lock().await; + let _now = Instant::now(); + let stale: Vec = entries + .iter() + .filter(|(_, entry)| entry.created_at.elapsed().as_secs() > self.timeout_secs) + .map(|(id, _)| id.clone()) + .collect(); + + for id in &stale { + entries.remove(id); + } + + if !stale.is_empty() { + log::info!( + "TurnRegistry: cleaned up {} stale entries (timeout={}s)", + stale.len(), + self.timeout_secs + ); + } + + stale.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::insight_chat::ChatStreamEvent; + use std::time::Duration; + + /// Unwrap the events from a `ReplayOutcome::Events`, panicking otherwise. + fn events_of(outcome: ReplayOutcome) -> Vec { + match outcome { + ReplayOutcome::Events { events, .. } => events, + other => panic!("expected Events, got {other:?}"), + } + } + + // ── TurnStatus ────────────────────────────────────────────────── + + #[test] + fn turn_status_from_u32_valid_values() { + assert_eq!(TurnStatus::from(0), TurnStatus::Running); + assert_eq!(TurnStatus::from(1), TurnStatus::Done); + assert_eq!(TurnStatus::from(2), TurnStatus::Error); + assert_eq!(TurnStatus::from(3), TurnStatus::Cancelled); + } + + #[test] + fn turn_status_from_u32_unknown_defaults_to_running() { + assert_eq!(TurnStatus::from(4), TurnStatus::Running); + assert_eq!(TurnStatus::from(u32::MAX), TurnStatus::Running); + } + + #[test] + fn turn_status_as_str() { + assert_eq!(TurnStatus::Running.as_str(), "running"); + assert_eq!(TurnStatus::Done.as_str(), "done"); + assert_eq!(TurnStatus::Error.as_str(), "error"); + assert_eq!(TurnStatus::Cancelled.as_str(), "cancelled"); + } + + // ── TurnEntry ─────────────────────────────────────────────────── + + #[tokio::test] + async fn turn_entry_push_and_replay() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + entry + .push_event(ChatStreamEvent::TextDelta("hello".to_string())) + .await; + entry + .push_event(ChatStreamEvent::TextDelta(" world".to_string())) + .await; + + let events = events_of(entry.replay_from(0).await); + assert_eq!(events.len(), 2); + } + + #[tokio::test] + async fn turn_entry_replay_with_skip() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + for i in 0..5 { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // skip_before=0 → all 5 events + let all = events_of(entry.replay_from(0).await); + assert_eq!(all.len(), 5); + + // skip_before=2 → events 2,3,4 (3 events) + let skipped = events_of(entry.replay_from(2).await); + assert_eq!(skipped.len(), 3); + + // skip_before=5 → caught up to the live edge (not Gone). + assert!(matches!( + entry.replay_from(5).await, + ReplayOutcome::CaughtUp { next_skip: 5 } + )); + } + + #[tokio::test] + async fn turn_entry_replay_empty_by_default() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + // Empty buffer with skip_before=0 → caught up (nothing to replay yet). + assert!(matches!( + entry.replay_from(0).await, + ReplayOutcome::CaughtUp { next_skip: 0 } + )); + } + + #[tokio::test] + async fn turn_entry_is_running_initially() { + let entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1); + assert!(entry.is_running()); + } + + #[tokio::test] + async fn turn_entry_set_terminal_status() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + assert!(entry.is_running()); + entry.set_terminal_status(TurnStatus::Done); + assert!(!entry.is_running()); + } + + #[tokio::test] + async fn turn_entry_info() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 42, + )); + + entry + .push_event(ChatStreamEvent::TextDelta("x".to_string())) + .await; + entry.set_terminal_status(TurnStatus::Done); + + let info = entry.info().await; + assert_eq!(info.turn_id, "t1"); + assert_eq!(info.file_path, "/photo.jpg"); + assert_eq!(info.library_id, 42); + assert_eq!(info.status, TurnStatus::Done); + assert_eq!(info.total_events_pushed, 1); + assert_eq!(info.buffered_count, 1); + } + + #[tokio::test] + async fn turn_entry_eviction_caps_buffer() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + // Push MAX_BUFFERED_EVENTS + 10 events. + for i in 0..(MAX_BUFFERED_EVENTS + 10) { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // Asking from absolute 0 after eviction is Gone (0-9 were dropped). + assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone)); + + // Reading from the new base (10) returns the full capped buffer. + let events = events_of(entry.replay_from(10).await); + assert_eq!(events.len(), MAX_BUFFERED_EVENTS); + + // First event should be at index 10 (0-9 were evicted). + if let ChatStreamEvent::TextDelta(s) = &events[0] { + assert_eq!(s, "e10"); + } else { + panic!("expected TextDelta"); + } + + // Last event should be at index MAX_BUFFERED_EVENTS + 9. + if let ChatStreamEvent::TextDelta(s) = &events[events.len() - 1] { + assert_eq!(s, &format!("e{}", MAX_BUFFERED_EVENTS + 9)); + } else { + panic!("expected TextDelta"); + } + } + + #[tokio::test] + async fn turn_entry_replay_evicted_index_is_gone() { + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + + // Push one past the cap so exactly one event (index 0) is evicted. + for i in 0..=MAX_BUFFERED_EVENTS { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + + // Base is now 1; asking from absolute 0 is evicted territory → Gone. + assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone)); + + // skip_before = MAX_BUFFERED_EVENTS → last event only (index valid). + let last = events_of(entry.replay_from(MAX_BUFFERED_EVENTS as u32).await); + assert_eq!(last.len(), 1); + + // skip_before = MAX_BUFFERED_EVENTS + 1 → caught up to the live edge. + assert!(matches!( + entry.replay_from((MAX_BUFFERED_EVENTS + 1) as u32).await, + ReplayOutcome::CaughtUp { .. } + )); + } + + // ── TurnRegistry ──────────────────────────────────────────────── + + #[tokio::test] + async fn turn_registry_insert_and_get() { + let registry = TurnRegistry::new(300); + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + let id = registry.insert(entry).await; + assert_eq!(id, "t1"); + + let retrieved = registry.get("t1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().turn_id, "t1"); + } + + #[tokio::test] + async fn turn_registry_get_nonexistent_returns_none() { + let registry = TurnRegistry::new(300); + assert!(registry.get("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_removes_old_entries() { + let registry = TurnRegistry::new(0); + let mut entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1); + entry.created_at = Instant::now() - Duration::from_secs(1); + registry.insert(Arc::new(entry)).await; + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 1); + assert!(registry.get("t1").await.is_none()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_preserves_recent() { + let registry = TurnRegistry::new(3600); // 1 hour + let entry = Arc::new(TurnEntry::new( + "t1".to_string(), + "/photo.jpg".to_string(), + 1, + )); + registry.insert(entry).await; + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 0); + assert!(registry.get("t1").await.is_some()); + } + + #[tokio::test] + async fn turn_registry_cleanup_stale_multiple() { + let registry = TurnRegistry::new(0); + + for i in 0..5 { + let mut entry = TurnEntry::new(format!("t{i}"), "/photo.jpg".to_string(), 1); + entry.created_at = Instant::now() - Duration::from_secs(1); + registry.insert(Arc::new(entry)).await; + } + + let cleaned = registry.cleanup_stale().await; + assert_eq!(cleaned, 5); + } + + #[tokio::test] + async fn turn_registry_timeout_secs() { + let registry = TurnRegistry::new(600); + assert_eq!(registry.timeout_secs(), 600); + } + + // ── next_batch / live replay ──────────────────────────────────── + + /// Drain a turn the way the SSE replay handler does: pull batches via + /// `next_batch` until the turn is finished and fully drained. + async fn drain_to_end(entry: Arc) -> Vec { + let mut out = Vec::new(); + let mut skip = 0u32; + while let ReplayOutcome::Events { events, next_skip } = entry.next_batch(skip).await { + out.extend(events); + skip = next_skip; + } + out + } + + fn is_terminal(ev: &ChatStreamEvent) -> bool { + matches!(ev, ChatStreamEvent::Done { .. } | ChatStreamEvent::Error(_)) + } + + /// The core guarantee behind the replay rewrite: a reader waiting on + /// `next_batch` always receives the terminal event, even though the + /// writer flips status to terminal immediately after pushing it. + #[tokio::test] + async fn next_batch_always_delivers_terminal_event() { + for _ in 0..50 { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + let writer = entry.clone(); + let w = tokio::spawn(async move { + writer + .push_event(ChatStreamEvent::IterationStart { n: 1, max: 6 }) + .await; + writer + .push_event(ChatStreamEvent::TextDelta("hi".into())) + .await; + // Push terminal then flip status with no await between — the + // race that previously dropped the Done on the reader side. + writer + .push_event(ChatStreamEvent::Done { + tool_calls_made: 0, + iterations_used: 1, + truncated: false, + prompt_tokens: None, + eval_tokens: None, + num_ctx: None, + amended_insight_id: None, + backend_used: "local".into(), + model_used: "m".into(), + cancelled: false, + }) + .await; + writer.set_terminal_status(TurnStatus::Done); + }); + + let events = drain_to_end(entry).await; + w.await.unwrap(); + + assert!( + events.last().is_some_and(is_terminal), + "terminal event missing; got {} events", + events.len() + ); + assert_eq!(events.len(), 3, "expected IterationStart, TextDelta, Done"); + } + } + + /// A reader that connects before any event is pushed blocks in + /// `next_batch` and then receives events as the writer produces them. + #[tokio::test] + async fn next_batch_waits_for_late_events() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + let writer = entry.clone(); + tokio::spawn(async move { + tokio::task::yield_now().await; + writer + .push_event(ChatStreamEvent::TextDelta("late".into())) + .await; + writer.set_terminal_status(TurnStatus::Done); + }); + + // First call blocks until the writer pushes, rather than returning + // CaughtUp on the empty buffer of a running turn. + match entry.next_batch(0).await { + ReplayOutcome::Events { events, next_skip } => { + assert_eq!(events.len(), 1); + assert_eq!(next_skip, 1); + } + other => panic!("expected Events, got {other:?}"), + } + } + + #[tokio::test] + async fn next_batch_closes_on_terminal_when_caught_up() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + entry + .push_event(ChatStreamEvent::TextDelta("x".into())) + .await; + entry.set_terminal_status(TurnStatus::Done); + + // Caught up (skip past the one buffered event) on a finished turn → + // CaughtUp so the handler closes the stream rather than hanging. + assert!(matches!( + entry.next_batch(1).await, + ReplayOutcome::CaughtUp { .. } + )); + } + + #[tokio::test] + async fn next_batch_reports_gone_for_evicted_index() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + for i in 0..=MAX_BUFFERED_EVENTS { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + // Index 0 was evicted (base advanced to 1). + assert!(matches!(entry.next_batch(0).await, ReplayOutcome::Gone)); + } + + // ── abort handle (#1 cancellation) ────────────────────────────── + + #[tokio::test] + async fn abort_handle_aborts_task_once() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + + // No handle installed yet → abort is a no-op. + assert!(!entry.abort()); + + let handle = tokio::spawn(async { + // Long-lived task that only ends via abort. + futures::future::pending::<()>().await; + }); + entry.set_abort_handle(handle.abort_handle()); + + assert!(entry.abort(), "first abort should fire"); + assert!(!entry.abort(), "handle is taken; second abort is a no-op"); + + // The aborted task resolves to a cancellation JoinError. + let join = handle.await; + assert!(join.unwrap_err().is_cancelled()); + } + + #[tokio::test] + async fn base_index_tracks_eviction() { + let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1)); + for i in 0..(MAX_BUFFERED_EVENTS + 5) { + entry + .push_event(ChatStreamEvent::TextDelta(format!("e{i}"))) + .await; + } + let info = entry.info().await; + // 5 events evicted; total keeps climbing, buffer stays capped. + assert_eq!(info.total_events_pushed, (MAX_BUFFERED_EVENTS + 5) as u32); + assert_eq!(info.buffered_count, MAX_BUFFERED_EVENTS as u32); + // First live index is 5: reading from there yields the full buffer. + let from_base = events_of(entry.replay_from(5).await); + assert_eq!(from_base.len(), MAX_BUFFERED_EVENTS); + } +} diff --git a/src/main.rs b/src/main.rs index a3af554..4099a5d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -197,6 +197,28 @@ fn main() -> std::io::Result<()> { app_state.library_health.clone(), ); + // Periodically clean up stale turn entries from the in-memory + // registry. Runs at the same interval as the configured timeout, + // drops entries older than that timeout. + { + let registry = app_state.turn_registry.clone(); + let timeout_secs = registry.timeout_secs(); + tokio::spawn(async move { + // Sweep at most every 5 minutes, and never less often than the + // timeout itself — otherwise entries could linger up to ~2× the + // configured timeout before being reclaimed. + let interval_secs = timeout_secs.clamp(1, 300); + let interval = tokio::time::Duration::from_secs(interval_secs); + loop { + tokio::time::sleep(interval).await; + let cleaned = registry.cleanup_stale().await; + if cleaned > 0 { + log::info!("TurnRegistry: cleaned up {cleaned} stale entries"); + } + } + }); + } + // Spawn background job to generate daily conversation summaries { use crate::ai::generate_daily_summaries; @@ -335,6 +357,9 @@ fn main() -> std::io::Result<()> { .service(ai::chat_stream_handler) .service(ai::chat_history_handler) .service(ai::chat_rewind_handler) + .service(ai::turn_async_handler) + .service(ai::turn_replay_handler) + .service(ai::cancel_turn_handler) .service(ai::rate_insight_handler) .service(ai::export_training_data_handler) .service(libraries::list_libraries) diff --git a/src/state.rs b/src/state.rs index 4ba63bf..f9adda7 100644 --- a/src/state.rs +++ b/src/state.rs @@ -4,6 +4,7 @@ use crate::ai::face_client::FaceClient; use crate::ai::insight_chat::{ChatLockMap, InsightChatService}; use crate::ai::llamacpp::LlamaCppClient; use crate::ai::openrouter::OpenRouterClient; +use crate::ai::turn_registry::TurnRegistry; use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient}; use crate::database::{ CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, InsightGenerationJobDao, KnowledgeDao, @@ -78,19 +79,10 @@ pub struct AppState { pub insight_generator: InsightGenerator, /// Chat continuation service. Hold an Arc so handlers can clone cheaply. pub insight_chat: Arc, - /// Face inference client (calls Apollo's `/api/internal/faces/*`). - /// Disabled (`is_enabled() == false`) when neither `APOLLO_FACE_API_BASE_URL` - /// nor `APOLLO_API_BASE_URL` is set; the file-watch hook (Phase 3) and - /// manual-face-create handler short-circuit in that case. + pub turn_registry: Arc, pub face_client: FaceClient, - /// CLIP inference client (calls Apollo's `/api/internal/clip/*`). - /// Same disabled semantics as `face_client`: unset env → no-op - /// backlog drain, /photos/search returns an empty result. pub clip_client: ClipClient, - /// Tracks async insight generation jobs (spawned by generate endpoints). pub insight_job_dao: Arc>>, - /// In-memory map from job_id → tokio AbortHandle for running tasks. - /// Used to abort server-side tasks on cancel or regenerate. pub insight_job_handles: Arc>>, } @@ -127,6 +119,7 @@ impl AppState { sms_client: SmsApiClient, insight_generator: InsightGenerator, insight_chat: Arc, + turn_registry: Arc, preview_dao: Arc>>, face_client: FaceClient, clip_client: ClipClient, @@ -171,6 +164,7 @@ impl AppState { sms_client, insight_generator, insight_chat, + turn_registry, face_client, clip_client, insight_job_dao, @@ -310,6 +304,14 @@ impl Default for AppState { chat_locks, )); + // Turn registry for reconnectable chat turns. 5-minute timeout for + // stale turns (background cleaner drops entries older than this). + let timeout_secs: u64 = env::var("INSIGHT_CHAT_TURN_TIMEOUT_SECS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(300); + let turn_registry = Arc::new(TurnRegistry::new(timeout_secs)); + // Ensure preview clips directory exists let preview_clips_path = env::var("PREVIEW_CLIPS_DIRECTORY").unwrap_or_else(|_| "preview_clips".to_string()); @@ -332,6 +334,7 @@ impl Default for AppState { sms_client, insight_generator, insight_chat, + turn_registry, preview_dao, face_client, clip_client, @@ -490,6 +493,9 @@ impl AppState { chat_locks, )); + // Turn registry for test state. + let turn_registry = Arc::new(TurnRegistry::new(300)); + // Initialize test preview DAO let preview_dao: Arc>> = Arc::new(Mutex::new(Box::new(SqlitePreviewDao::new()))); @@ -518,6 +524,7 @@ impl AppState { sms_client, insight_generator, insight_chat, + turn_registry, preview_dao, FaceClient::new(None), // disabled in test ClipClient::new(None), // disabled in test