feature/insight-jobs #102
Generated
+3
@@ -2104,6 +2104,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
"zerocopy",
|
||||
]
|
||||
@@ -4391,7 +4392,9 @@ version = "1.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76"
|
||||
dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"js-sys",
|
||||
"serde_core",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ image_hasher = "3.0"
|
||||
bk-tree = "0.5"
|
||||
async-trait = "0.1"
|
||||
indicatif = "0.17"
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
|
||||
# Windows lacks system sqlite3, so re-enable the bundled C build there.
|
||||
# Linux/macOS use the system library (faster builds, smaller binary).
|
||||
|
||||
+510
-4
@@ -1,4 +1,5 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, Responder, delete, get, post, web};
|
||||
use futures::StreamExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use opentelemetry::trace::{Span, Status, Tracer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -1433,7 +1434,26 @@ pub async fn chat_stream_handler(
|
||||
}
|
||||
|
||||
fn render_sse_frame(ev: &ChatStreamEvent) -> String {
|
||||
let (event_name, payload) = match ev {
|
||||
let (event_name, payload) = sse_event_payload(ev);
|
||||
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
|
||||
format!("event: {}\ndata: {}\n\n", event_name, data)
|
||||
}
|
||||
|
||||
/// Like `render_sse_frame`, but stamps the event's absolute sequence number
|
||||
/// (`seq`) into the payload so reconnecting replay clients can compute
|
||||
/// `skip_before` precisely. `seq` is distinct from the tool-pairing `index`
|
||||
/// already carried by `tool_call`/`tool_result`.
|
||||
fn render_indexed_frame(ev: &ChatStreamEvent, seq: u32) -> String {
|
||||
let (event_name, mut payload) = sse_event_payload(ev);
|
||||
if let serde_json::Value::Object(map) = &mut payload {
|
||||
map.insert("seq".to_string(), serde_json::json!(seq));
|
||||
}
|
||||
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
|
||||
format!("event: {}\ndata: {}\n\n", event_name, data)
|
||||
}
|
||||
|
||||
fn sse_event_payload(ev: &ChatStreamEvent) -> (&'static str, serde_json::Value) {
|
||||
match ev {
|
||||
ChatStreamEvent::IterationStart { n, max } => {
|
||||
("iteration_start", serde_json::json!({ "n": n, "max": max }))
|
||||
}
|
||||
@@ -1471,6 +1491,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String {
|
||||
amended_insight_id,
|
||||
backend_used,
|
||||
model_used,
|
||||
cancelled,
|
||||
} => (
|
||||
"done",
|
||||
serde_json::json!({
|
||||
@@ -1483,6 +1504,7 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String {
|
||||
"amended_insight_id": amended_insight_id,
|
||||
"backend": backend_used,
|
||||
"model": model_used,
|
||||
"cancelled": cancelled,
|
||||
}),
|
||||
),
|
||||
// Apollo's frontend SSE consumer (and its free-chat backend, which
|
||||
@@ -1491,7 +1513,491 @@ fn render_sse_frame(ev: &ChatStreamEvent) -> String {
|
||||
// "no insight found for path") was silently dropped, leaving an
|
||||
// empty assistant bubble with no clue why the turn died.
|
||||
ChatStreamEvent::Error(msg) => ("error_message", serde_json::json!({ "message": msg })),
|
||||
};
|
||||
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
|
||||
format!("event: {}\ndata: {}\n\n", event_name, data)
|
||||
}
|
||||
}
|
||||
|
||||
/// POST /insights/chat/turn — async turn dispatch. Returns turn_id immediately,
|
||||
/// client then polls GET /insights/chat/turn/{turn_id} for SSE replay.
|
||||
#[post("/insights/chat/turn")]
|
||||
pub async fn turn_async_handler(
|
||||
http_request: HttpRequest,
|
||||
claims: Claims,
|
||||
request: web::Json<ChatTurnHttpRequest>,
|
||||
app_state: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
let parent_context = extract_context_from_request(&http_request);
|
||||
let tracer = global_tracer();
|
||||
let mut span = tracer.start_with_context("http.insights.chat_turn_async", &parent_context);
|
||||
span.set_attribute(KeyValue::new("file_path", request.file_path.clone()));
|
||||
|
||||
let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
|
||||
Ok(Some(lib)) => lib,
|
||||
Ok(None) => app_state.primary_library(),
|
||||
Err(e) => {
|
||||
return HttpResponse::BadRequest().json(serde_json::json!({
|
||||
"error": format!("invalid library: {}", e)
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
let user_id = claims.sub.parse::<i32>().unwrap_or(1);
|
||||
|
||||
let chat_req = ChatTurnRequest {
|
||||
library_id: library.id,
|
||||
user_id,
|
||||
file_path: request.file_path.clone(),
|
||||
user_message: request.user_message.clone(),
|
||||
model: request.model.clone(),
|
||||
backend: request.backend.clone(),
|
||||
num_ctx: request.num_ctx,
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
top_k: request.top_k,
|
||||
min_p: request.min_p,
|
||||
max_iterations: request.max_iterations,
|
||||
system_prompt: request.system_prompt.clone(),
|
||||
persona_id: request.persona_id.clone(),
|
||||
amend: request.amend,
|
||||
regenerate: request.regenerate,
|
||||
};
|
||||
|
||||
let service = app_state.insight_chat.clone();
|
||||
let registry = app_state.turn_registry.clone();
|
||||
|
||||
let turn_id = service.chat_turn_async(registry, chat_req).await;
|
||||
span.set_attribute(KeyValue::new("turn_id", turn_id.clone()));
|
||||
span.set_status(Status::Ok);
|
||||
HttpResponse::Accepted().json(serde_json::json!({
|
||||
"turn_id": turn_id,
|
||||
"status": "running"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Query params for the SSE replay stream.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ReplayQuery {
|
||||
/// Replay events from this absolute sequence number (`seq`) onward.
|
||||
/// Absent or 0 replays from the beginning. On reconnect the client sends
|
||||
/// the `seq` of the last event it applied, plus one.
|
||||
pub skip_before: Option<u32>,
|
||||
}
|
||||
|
||||
/// GET /insights/chat/turn/{turn_id} — SSE replay stream.
|
||||
#[get("/insights/chat/turn/{turn_id}")]
|
||||
pub async fn turn_replay_handler(
|
||||
http_request: HttpRequest,
|
||||
path: web::Path<String>,
|
||||
query: web::Query<ReplayQuery>,
|
||||
app_state: web::Data<AppState>,
|
||||
) -> HttpResponse {
|
||||
use crate::ai::turn_registry::ReplayOutcome;
|
||||
|
||||
let turn_id = path.into_inner();
|
||||
let skip_before = query.skip_before.unwrap_or(0);
|
||||
|
||||
let parent_context = extract_context_from_request(&http_request);
|
||||
let tracer = global_tracer();
|
||||
let mut span = tracer.start_with_context("ai.chat.turn.replay", &parent_context);
|
||||
span.set_attribute(KeyValue::new("turn_id", turn_id.clone()));
|
||||
span.set_attribute(KeyValue::new("skip_before", skip_before as i64));
|
||||
|
||||
let registry = app_state.turn_registry.clone();
|
||||
let entry = match registry.get(&turn_id).await {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
span.set_status(Status::error("turn not found"));
|
||||
return HttpResponse::NotFound().json(serde_json::json!({
|
||||
"error": format!("turn {} not found", turn_id)
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
let info = entry.info().await;
|
||||
span.set_attribute(KeyValue::new("status", info.status.as_str()));
|
||||
span.set_attribute(KeyValue::new(
|
||||
"event_count",
|
||||
info.total_events_pushed as i64,
|
||||
));
|
||||
let turn_info_frame = render_turn_info_frame(&info);
|
||||
|
||||
// Initial buffered batch: events produced before this connection attached.
|
||||
// Stamp each frame with its absolute `seq` so the client can track
|
||||
// `skip_before` precisely across reconnects.
|
||||
let (initial_frames, start_skip) = match entry.replay_from(skip_before).await {
|
||||
ReplayOutcome::Gone => {
|
||||
span.set_status(Status::error("buffer evicted"));
|
||||
return HttpResponse::Gone().json(serde_json::json!({
|
||||
"error": "turn history has expired (buffer evicted)"
|
||||
}));
|
||||
}
|
||||
ReplayOutcome::CaughtUp { next_skip } => (Vec::new(), next_skip),
|
||||
ReplayOutcome::Events { events, next_skip } => {
|
||||
let frames: Vec<actix_web::web::Bytes> = events
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, ev)| {
|
||||
actix_web::web::Bytes::from(render_indexed_frame(&ev, skip_before + i as u32))
|
||||
})
|
||||
.collect();
|
||||
(frames, next_skip)
|
||||
}
|
||||
};
|
||||
|
||||
span.set_status(Status::Ok);
|
||||
let running = entry.is_running();
|
||||
|
||||
// Head: the `turn_info` event followed by any already-buffered events.
|
||||
let head = futures::stream::once(async move {
|
||||
Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(turn_info_frame))
|
||||
})
|
||||
.chain(futures::stream::iter(
|
||||
initial_frames.into_iter().map(Ok::<_, actix_web::Error>),
|
||||
));
|
||||
|
||||
if !running {
|
||||
// Completed turn: every event — including the terminal Done/Error — is
|
||||
// already in the buffered batch above. Emit it and close.
|
||||
return HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.insert_header(("Cache-Control", "no-cache"))
|
||||
.insert_header(("X-Accel-Buffering", "no"))
|
||||
.streaming(head);
|
||||
}
|
||||
|
||||
// In-progress turn: after the head, wait for new events. `next_batch`
|
||||
// drains every buffered event (including the terminal one) before it
|
||||
// reports the turn finished, so the final Done/Error is never dropped;
|
||||
// CaughtUp then closes the stream by returning None.
|
||||
let tail = futures::stream::unfold(
|
||||
(
|
||||
entry,
|
||||
start_skip,
|
||||
Vec::<actix_web::web::Bytes>::new(),
|
||||
false,
|
||||
),
|
||||
|(entry, skip, pending, finished)| async move {
|
||||
// Flush queued frames from a previous multi-event batch first.
|
||||
if let Some((first, rest)) = pending.split_first() {
|
||||
return Some((Ok(first.clone()), (entry, skip, rest.to_vec(), finished)));
|
||||
}
|
||||
if finished {
|
||||
return None;
|
||||
}
|
||||
|
||||
match entry.next_batch(skip).await {
|
||||
ReplayOutcome::Events { events, next_skip } => {
|
||||
let frames: Vec<actix_web::web::Bytes> = events
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, ev)| {
|
||||
actix_web::web::Bytes::from(render_indexed_frame(&ev, skip + i as u32))
|
||||
})
|
||||
.collect();
|
||||
// next_batch only returns Events for a non-empty batch.
|
||||
let (first, rest) = frames.split_first().expect("non-empty batch");
|
||||
Some((Ok(first.clone()), (entry, next_skip, rest.to_vec(), false)))
|
||||
}
|
||||
// Terminal reached and fully drained — close the connection.
|
||||
ReplayOutcome::CaughtUp { .. } => None,
|
||||
ReplayOutcome::Gone => {
|
||||
// Evicted mid-stream: emit one error frame, then close.
|
||||
let gone =
|
||||
actix_web::web::Bytes::from(render_sse_frame(&ChatStreamEvent::Error(
|
||||
"turn history has expired (buffer evicted)".to_string(),
|
||||
)));
|
||||
Some((Ok(gone), (entry, skip, Vec::new(), true)))
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.insert_header(("Cache-Control", "no-cache"))
|
||||
.insert_header(("X-Accel-Buffering", "no"))
|
||||
.streaming(head.chain(tail))
|
||||
}
|
||||
|
||||
fn render_turn_info_frame(info: &crate::ai::turn_registry::TurnInfo) -> String {
|
||||
let payload = serde_json::json!({
|
||||
"turn_id": info.turn_id,
|
||||
"file_path": info.file_path,
|
||||
"library_id": info.library_id,
|
||||
"status": info.status.as_str(),
|
||||
"total_events_pushed": info.total_events_pushed,
|
||||
"buffered_count": info.buffered_count,
|
||||
});
|
||||
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
|
||||
format!("event: turn_info\ndata: {}\n\n", data)
|
||||
}
|
||||
|
||||
/// DELETE /insights/chat/turn/{turn_id} — cancel a running turn.
|
||||
#[delete("/insights/chat/turn/{turn_id}")]
|
||||
pub async fn cancel_turn_handler(
|
||||
http_request: HttpRequest,
|
||||
path: web::Path<String>,
|
||||
app_state: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
let turn_id = path.into_inner();
|
||||
|
||||
let parent_context = extract_context_from_request(&http_request);
|
||||
let tracer = global_tracer();
|
||||
let mut span = tracer.start_with_context("ai.chat.turn.cancel", &parent_context);
|
||||
span.set_attribute(KeyValue::new("turn_id", turn_id.clone()));
|
||||
|
||||
let registry = app_state.turn_registry.clone();
|
||||
let entry = match registry.get(&turn_id).await {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
span.set_status(Status::error("turn not found"));
|
||||
return HttpResponse::NotFound().json(serde_json::json!({
|
||||
"error": format!("turn {} not found", turn_id)
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
// Abort the spawned task so it stops producing events promptly. The loop
|
||||
// also checks `is_running()` at each iteration boundary as a graceful
|
||||
// backstop in case the abort lands between await points.
|
||||
let aborted = entry.abort();
|
||||
span.set_attribute(KeyValue::new("aborted", aborted));
|
||||
|
||||
// Push the terminal event BEFORE flipping status: a replay reader treats a
|
||||
// terminal status with no buffered tail as "closed", so the Done must be
|
||||
// buffered first for in-progress connections to receive it.
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::Done {
|
||||
tool_calls_made: 0,
|
||||
iterations_used: 0,
|
||||
truncated: false,
|
||||
prompt_tokens: None,
|
||||
eval_tokens: None,
|
||||
num_ctx: None,
|
||||
amended_insight_id: None,
|
||||
backend_used: "cancelled".to_string(),
|
||||
model_used: "cancelled".to_string(),
|
||||
cancelled: true,
|
||||
})
|
||||
.await;
|
||||
entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Cancelled);
|
||||
span.set_status(Status::Ok);
|
||||
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"cancelled": true
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod turn_replay_tests {
    use super::{cancel_turn_handler, render_indexed_frame, turn_replay_handler};
    use crate::ai::insight_chat::ChatStreamEvent;
    use crate::ai::turn_registry::{TurnEntry, TurnStatus};
    use crate::state::AppState;
    use actix_web::test as actix_test;
    use actix_web::{App, web::Data};
    use std::sync::Arc;

    /// Gate around `AppState::test_state()` construction. Every build opens
    /// ~10 DAO connections against the single shared `DATABASE_URL` file, and
    /// concurrent opens race the WAL `journal_mode` switch into a spurious
    /// "database is locked". Only the open is serialized; the test bodies
    /// still run in parallel.
    static DB_INIT: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Starts an in-process test service exposing one handler over `$state`.
    macro_rules! service_for {
        ($handler:ident, $state:expr) => {
            actix_test::init_service(App::new().service($handler).app_data($state.clone())).await
        };
    }

    /// Builds a fresh test `AppState` while holding the init gate.
    fn build_state() -> Data<AppState> {
        // A poisoned gate is still usable: it guards no data of its own.
        let _gate = match DB_INIT.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        Data::new(AppState::test_state())
    }

    /// A terminal `Done` event with fixed token counts.
    fn done(cancelled: bool) -> ChatStreamEvent {
        ChatStreamEvent::Done {
            tool_calls_made: 0,
            iterations_used: 1,
            truncated: false,
            prompt_tokens: Some(10),
            eval_tokens: Some(20),
            num_ctx: None,
            amended_insight_id: None,
            backend_used: "local".into(),
            model_used: "m".into(),
            cancelled,
        }
    }

    /// Registers a finished turn: `text_events` text deltas, then a terminal
    /// `Done`, with the status flipped to Done.
    async fn seed_completed(state: &AppState, id: &str, text_events: usize) {
        let turn = Arc::new(TurnEntry::new(id.into(), "/p.jpg".into(), 1));
        for n in 0..text_events {
            turn.push_event(ChatStreamEvent::TextDelta(format!("d{n}")))
                .await;
        }
        turn.push_event(done(false)).await;
        turn.set_terminal_status(TurnStatus::Done);
        state.turn_registry.insert(turn).await;
    }

    #[test]
    fn indexed_frame_stamps_seq_without_clobbering_tool_index() {
        // `tool_call` has its own pairing `index`; `seq` is added next to it.
        let event = ChatStreamEvent::ToolCall {
            index: 3,
            name: "geo".into(),
            arguments: serde_json::json!({}),
        };
        let rendered = render_indexed_frame(&event, 42);
        assert!(rendered.contains("event: tool_call"));
        assert!(rendered.contains("\"index\":3"));
        assert!(rendered.contains("\"seq\":42"));
    }

    #[actix_rt::test]
    async fn replay_unknown_turn_is_404() {
        let state = build_state();
        let service = service_for!(turn_replay_handler, state);

        let request = actix_test::TestRequest::get()
            .uri("/insights/chat/turn/nope")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 404);
    }

    #[actix_rt::test]
    async fn replay_completed_turn_emits_turn_info_and_done_with_seq() {
        let state = build_state();
        seed_completed(&state, "t1", 2).await;
        let service = service_for!(turn_replay_handler, state);

        let request = actix_test::TestRequest::get()
            .uri("/insights/chat/turn/t1")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 200);

        let raw = actix_test::read_body(response).await;
        let body = String::from_utf8(raw.to_vec()).unwrap();
        assert!(body.contains("event: turn_info"));
        assert!(body.contains("event: text"));
        assert!(body.contains("event: done"));
        // The two text events get seq 0 and 1; the done event gets seq 2.
        assert!(body.contains("\"seq\":0"));
        assert!(body.contains("\"seq\":2"));
        // The done payload exposes the renamed token fields the client reads.
        assert!(body.contains("\"prompt_tokens\":10"));
    }

    #[actix_rt::test]
    async fn replay_skip_before_query_skips_applied_events() {
        let state = build_state();
        // Text events at seq 0, 1, 2; done at seq 3.
        seed_completed(&state, "t2", 3).await;
        let service = service_for!(turn_replay_handler, state);

        let request = actix_test::TestRequest::get()
            .uri("/insights/chat/turn/t2?skip_before=2")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 200);

        let raw = actix_test::read_body(response).await;
        let body = String::from_utf8(raw.to_vec()).unwrap();
        // Expect only the last text event (seq 2) and the done event (seq 3).
        assert!(body.contains("\"seq\":2"));
        assert!(body.contains("\"seq\":3"));
        assert!(!body.contains("\"seq\":0"));
        assert!(!body.contains("\"seq\":1"));
    }

    #[actix_rt::test]
    async fn replay_evicted_index_is_410() {
        let state = build_state();
        let turn = Arc::new(TurnEntry::new("t3".into(), "/p.jpg".into(), 1));
        // Overfill the buffer so the oldest events are evicted and the base
        // sequence number moves forward.
        for n in 0..600 {
            turn.push_event(ChatStreamEvent::TextDelta(format!("d{n}")))
                .await;
        }
        turn.set_terminal_status(TurnStatus::Done);
        state.turn_registry.insert(turn).await;

        let service = service_for!(turn_replay_handler, state);

        let request = actix_test::TestRequest::get()
            .uri("/insights/chat/turn/t3?skip_before=0")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 410);
    }

    #[actix_rt::test]
    async fn cancel_unknown_turn_is_404() {
        let state = build_state();
        let service = service_for!(cancel_turn_handler, state);

        let request = actix_test::TestRequest::delete()
            .uri("/insights/chat/turn/nope")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 404);
    }

    #[actix_rt::test]
    async fn cancel_running_turn_marks_cancelled_and_buffers_terminal() {
        let state = build_state();
        let turn = Arc::new(TurnEntry::new("t4".into(), "/p.jpg".into(), 1));
        turn.push_event(ChatStreamEvent::TextDelta("partial".into()))
            .await;
        state.turn_registry.insert(turn.clone()).await;

        let service = service_for!(cancel_turn_handler, state);

        let request = actix_test::TestRequest::delete()
            .uri("/insights/chat/turn/t4")
            .to_request();
        let response = actix_test::call_service(&service, request).await;
        assert_eq!(response.status(), 200);

        // The status is now Cancelled, and a terminal Done(cancelled) sits
        // behind the pre-existing event so a late replay reader still finishes.
        let raw_status = turn.status.load(std::sync::atomic::Ordering::Relaxed);
        assert_eq!(TurnStatus::from(raw_status), TurnStatus::Cancelled);
        let info = turn.info().await;
        assert_eq!(info.total_events_pushed, 2);
    }
}
|
||||
|
||||
@@ -9,11 +9,14 @@ use tokio::sync::Mutex as TokioMutex;
|
||||
use crate::ai::backend::{BackendKind, ResolvedBackend, SamplingOverrides};
|
||||
use crate::ai::insight_generator::InsightGenerator;
|
||||
use crate::ai::llm_client::{ChatMessage, LlmStreamEvent, Tool};
|
||||
use crate::ai::turn_registry::TurnEntry;
|
||||
use crate::ai::turn_registry::TurnRegistry;
|
||||
use crate::database::InsightDao;
|
||||
use crate::database::models::InsertPhotoInsight;
|
||||
use crate::otel::global_tracer;
|
||||
use crate::utils::normalize_path;
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
const DEFAULT_MAX_ITERATIONS: usize = 6;
|
||||
const DEFAULT_NUM_CTX: i32 = 8192;
|
||||
@@ -678,6 +681,626 @@ impl InsightChatService {
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
/// Async turn dispatch: creates a TurnEntry in the registry, spawns the
|
||||
/// agentic loop on a Tokio task, and returns the turn_id immediately.
|
||||
/// Events are buffered in the TurnEntry for SSE replay.
|
||||
pub async fn chat_turn_async(
|
||||
self: Arc<Self>,
|
||||
registry: Arc<TurnRegistry>,
|
||||
req: ChatTurnRequest,
|
||||
) -> String {
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
turn_id.clone(),
|
||||
req.file_path.clone(),
|
||||
req.library_id,
|
||||
));
|
||||
registry.insert(entry.clone()).await;
|
||||
|
||||
let svc = self.clone();
|
||||
let entry_clone = entry.clone();
|
||||
let turn_id_for_span = turn_id.clone();
|
||||
let library_id = req.library_id;
|
||||
let handle = tokio::spawn(async move {
|
||||
// Span covering the whole spawned turn execution. Created here (not
|
||||
// in the HTTP handler) because the dispatch span ends at the 202
|
||||
// response, long before this work runs.
|
||||
let tracer = global_tracer();
|
||||
let mut span = tracer.start("ai.chat.turn.execute");
|
||||
span.set_attribute(KeyValue::new("turn_id", turn_id_for_span));
|
||||
span.set_attribute(KeyValue::new("library_id", library_id as i64));
|
||||
|
||||
let result = svc
|
||||
.run_streaming_turn_with_entry(req, entry_clone.clone())
|
||||
.await;
|
||||
if let Err(ref e) = result {
|
||||
span.set_attribute(KeyValue::new("status", "error"));
|
||||
span.set_status(Status::error(format!("{e}")));
|
||||
// Push the terminal event BEFORE flipping status: a replay
|
||||
// reader treats a terminal status with no buffered tail as
|
||||
// "closed", so the Error must be in the buffer first.
|
||||
let _ = entry_clone
|
||||
.push_event(ChatStreamEvent::Error(format!("{}", e)))
|
||||
.await;
|
||||
entry_clone.set_terminal_status(crate::ai::turn_registry::TurnStatus::Error);
|
||||
} else {
|
||||
span.set_attribute(KeyValue::new("status", "done"));
|
||||
span.set_status(Status::Ok);
|
||||
}
|
||||
});
|
||||
|
||||
// Install the abort handle so DELETE can actually stop the task.
|
||||
entry.set_abort_handle(handle.abort_handle());
|
||||
|
||||
turn_id
|
||||
}
|
||||
|
||||
/// Variant of `run_streaming_turn` that pushes events to a `TurnEntry`
|
||||
/// buffer instead of an `mpsc::Sender`.
|
||||
async fn run_streaming_turn_with_entry(
|
||||
self: Arc<Self>,
|
||||
req: ChatTurnRequest,
|
||||
entry: Arc<TurnEntry>,
|
||||
) -> Result<()> {
|
||||
if req.user_message.trim().is_empty() {
|
||||
bail!("user_message must not be empty");
|
||||
}
|
||||
if req.user_message.len() > 8192 {
|
||||
bail!("user_message exceeds 8192 chars");
|
||||
}
|
||||
let normalized = normalize_path(&req.file_path);
|
||||
|
||||
let lock_key = (req.library_id, normalized.clone());
|
||||
let entry_lock = {
|
||||
let mut locks = self.chat_locks.lock().await;
|
||||
locks
|
||||
.entry(lock_key.clone())
|
||||
.or_insert_with(|| Arc::new(TokioMutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
let _guard = entry_lock.lock().await;
|
||||
|
||||
// Look up existing insight scoped to this turn's library_id.
|
||||
let existing_insight = {
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
dao.get_current_insight_for_library(&cx, req.library_id, &normalized)
|
||||
.map_err(|e| anyhow!("failed to load insight: {:?}", e))?
|
||||
};
|
||||
|
||||
if req.regenerate || existing_insight.is_none() {
|
||||
return self
|
||||
.run_bootstrap_streaming_with_entry(req, normalized, entry)
|
||||
.await;
|
||||
}
|
||||
let insight = existing_insight.expect("just checked Some above");
|
||||
self.run_continuation_streaming_with_entry(req, normalized, insight, entry)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Continuation path with TurnEntry buffer.
|
||||
async fn run_continuation_streaming_with_entry(
|
||||
&self,
|
||||
req: ChatTurnRequest,
|
||||
normalized: String,
|
||||
insight: crate::database::models::PhotoInsight,
|
||||
entry: Arc<TurnEntry>,
|
||||
) -> Result<()> {
|
||||
let active_persona = req
|
||||
.persona_id
|
||||
.clone()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
let raw_history = insight.training_messages.as_ref().ok_or_else(|| {
|
||||
anyhow!("insight has no chat history; regenerate this insight in agentic mode")
|
||||
})?;
|
||||
let mut messages: Vec<ChatMessage> = serde_json::from_str(raw_history)
|
||||
.map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
|
||||
|
||||
let stored_backend = insight.backend.clone();
|
||||
let effective_backend = req
|
||||
.backend
|
||||
.as_deref()
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(|| stored_backend.clone());
|
||||
let kind = BackendKind::parse(&effective_backend)?;
|
||||
validate_cross_replay(&stored_backend, kind.as_str())?;
|
||||
|
||||
let max_iterations = req
|
||||
.max_iterations
|
||||
.unwrap_or(DEFAULT_MAX_ITERATIONS)
|
||||
.clamp(1, env_max_iterations());
|
||||
|
||||
let stored_model = insight.model_version.clone();
|
||||
let overrides = SamplingOverrides {
|
||||
model: req
|
||||
.model
|
||||
.clone()
|
||||
.or_else(|| Some(stored_model.clone()))
|
||||
.filter(|m| !m.is_empty()),
|
||||
num_ctx: req.num_ctx,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
min_p: req.min_p,
|
||||
};
|
||||
let backend = self.generator.resolve_backend(kind, &overrides).await?;
|
||||
let model_used = backend.model().to_string();
|
||||
|
||||
let local_first_user_has_image = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "user")
|
||||
.and_then(|m| m.images.as_ref())
|
||||
.map(|imgs| !imgs.is_empty())
|
||||
.unwrap_or(false);
|
||||
let offer_describe_tool = backend.images_inline && local_first_user_has_image;
|
||||
let gate_opts = self.generator.current_gate_opts_for_persona(
|
||||
offer_describe_tool,
|
||||
Some((req.user_id, &active_persona)),
|
||||
);
|
||||
let tools = InsightGenerator::build_tool_definitions(gate_opts);
|
||||
|
||||
let image_base64: Option<String> = if offer_describe_tool {
|
||||
self.generator.load_image_as_base64(&normalized).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
|
||||
.saturating_sub(RESPONSE_HEADROOM_TOKENS);
|
||||
let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
|
||||
let truncated = apply_context_budget(&mut messages, budget_bytes);
|
||||
if truncated {
|
||||
let _ = entry.push_event(ChatStreamEvent::Truncated).await;
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::user(req.user_message.clone()));
|
||||
|
||||
let override_stash =
|
||||
apply_system_prompt_override(&mut messages, req.system_prompt.as_deref());
|
||||
let original_system_content = annotate_system_with_budget(&mut messages, max_iterations);
|
||||
|
||||
let outcome = self
|
||||
.run_streaming_agentic_loop_with_entry(
|
||||
&backend,
|
||||
&mut messages,
|
||||
tools,
|
||||
&image_base64,
|
||||
&normalized,
|
||||
req.user_id,
|
||||
&active_persona,
|
||||
max_iterations,
|
||||
&entry,
|
||||
)
|
||||
.await?;
|
||||
let AgenticLoopOutcome {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
cancelled,
|
||||
} = outcome;
|
||||
|
||||
// Turn was cancelled mid-flight: the DELETE handler already pushed the
|
||||
// terminal event and flipped status. Don't persist a partial turn or
|
||||
// push a second terminal event.
|
||||
if cancelled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
restore_system_content(&mut messages, original_system_content);
|
||||
|
||||
if !req.amend {
|
||||
restore_system_prompt_override(&mut messages, override_stash);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string(&messages)
|
||||
.map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
|
||||
|
||||
let mut amended_insight_id: Option<i32> = None;
|
||||
if req.amend {
|
||||
let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content);
|
||||
let final_content = body;
|
||||
|
||||
let new_row = InsertPhotoInsight {
|
||||
library_id: req.library_id,
|
||||
file_path: normalized.clone(),
|
||||
title,
|
||||
summary: final_content.clone(),
|
||||
generated_at: Utc::now().timestamp(),
|
||||
model_version: model_used.clone(),
|
||||
is_current: true,
|
||||
training_messages: Some(json),
|
||||
backend: kind.as_str().to_string(),
|
||||
fewshot_source_ids: None,
|
||||
content_hash: None,
|
||||
num_ctx: req.num_ctx,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
min_p: req.min_p,
|
||||
system_prompt: req.system_prompt.clone(),
|
||||
persona_id: req.persona_id.clone(),
|
||||
prompt_eval_count: None,
|
||||
eval_count: None,
|
||||
};
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
let stored = dao
|
||||
.store_insight(&cx, new_row)
|
||||
.map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
|
||||
amended_insight_id = Some(stored.id);
|
||||
} else {
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
let rows = dao
|
||||
.update_training_messages(&cx, req.library_id, &normalized, &json)
|
||||
.map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
|
||||
if rows == 0 {
|
||||
log::warn!(
|
||||
"update_training_messages (stream) updated 0 rows for {} (lib {}), \
|
||||
concurrent regenerate likely flipped is_current",
|
||||
normalized,
|
||||
req.library_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::Done {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
truncated,
|
||||
prompt_tokens: last_prompt_eval_count,
|
||||
eval_tokens: last_eval_count,
|
||||
num_ctx: req.num_ctx,
|
||||
amended_insight_id,
|
||||
backend_used: kind.as_str().to_string(),
|
||||
model_used,
|
||||
cancelled: false,
|
||||
})
|
||||
.await;
|
||||
|
||||
entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bootstrap path with TurnEntry buffer.
|
||||
async fn run_bootstrap_streaming_with_entry(
|
||||
&self,
|
||||
req: ChatTurnRequest,
|
||||
normalized: String,
|
||||
entry: Arc<TurnEntry>,
|
||||
) -> Result<()> {
|
||||
let active_persona = req
|
||||
.persona_id
|
||||
.clone()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
let effective_backend = resolve_bootstrap_backend(req.backend.as_deref())?;
|
||||
let kind = BackendKind::parse(&effective_backend)?;
|
||||
|
||||
let max_iterations = req
|
||||
.max_iterations
|
||||
.unwrap_or(DEFAULT_MAX_ITERATIONS)
|
||||
.clamp(1, env_max_iterations());
|
||||
|
||||
let overrides = SamplingOverrides {
|
||||
model: req.model.clone().filter(|m| !m.is_empty()),
|
||||
num_ctx: req.num_ctx,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
min_p: req.min_p,
|
||||
};
|
||||
let backend = self.generator.resolve_backend(kind, &overrides).await?;
|
||||
let model_used = backend.model().to_string();
|
||||
|
||||
let image_base64: Option<String> = self.generator.load_image_as_base64(&normalized).ok();
|
||||
|
||||
let exif = self.generator.fetch_exif(&normalized);
|
||||
let date_taken_str = resolve_date_taken_for_context(&exif, &normalized);
|
||||
let gps = exif
|
||||
.as_ref()
|
||||
.and_then(|e| match (e.gps_latitude, e.gps_longitude) {
|
||||
(Some(lat), Some(lon)) => Some((lat as f64, lon as f64)),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let visual_block = if !backend.images_inline {
|
||||
match image_base64.as_deref() {
|
||||
Some(b64) => match backend.local().describe_image(b64).await {
|
||||
Ok(desc) => {
|
||||
format!("Visual description (from local vision model):\n{}\n", desc)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("{} bootstrap: describe_image failed: {}", kind.as_str(), e);
|
||||
String::new()
|
||||
}
|
||||
},
|
||||
None => String::new(),
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let offer_describe_tool = backend.images_inline && image_base64.is_some();
|
||||
let gate_opts = self.generator.current_gate_opts_for_persona(
|
||||
offer_describe_tool,
|
||||
Some((req.user_id, &active_persona)),
|
||||
);
|
||||
let tools = InsightGenerator::build_tool_definitions(gate_opts);
|
||||
|
||||
let persona = resolve_bootstrap_system_prompt(req.system_prompt.as_deref());
|
||||
let system_content = build_bootstrap_system_message(
|
||||
&persona,
|
||||
&normalized,
|
||||
date_taken_str.as_deref(),
|
||||
gps,
|
||||
&visual_block,
|
||||
);
|
||||
let system_msg = ChatMessage::system(system_content);
|
||||
let mut user_msg = ChatMessage::user(req.user_message.clone());
|
||||
if backend.images_inline
|
||||
&& let Some(ref img) = image_base64
|
||||
{
|
||||
user_msg.images = Some(vec![img.clone()]);
|
||||
}
|
||||
let mut messages = vec![system_msg, user_msg];
|
||||
|
||||
let outcome = self
|
||||
.run_streaming_agentic_loop_with_entry(
|
||||
&backend,
|
||||
&mut messages,
|
||||
tools,
|
||||
&image_base64,
|
||||
&normalized,
|
||||
req.user_id,
|
||||
&active_persona,
|
||||
max_iterations,
|
||||
&entry,
|
||||
)
|
||||
.await?;
|
||||
let AgenticLoopOutcome {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
cancelled,
|
||||
} = outcome;
|
||||
|
||||
// Turn was cancelled mid-flight: the DELETE handler already pushed the
|
||||
// terminal event and flipped status. Don't persist a partial turn or
|
||||
// push a second terminal event.
|
||||
if cancelled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content);
|
||||
|
||||
let json = serde_json::to_string(&messages)
|
||||
.map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
|
||||
let new_row = InsertPhotoInsight {
|
||||
library_id: req.library_id,
|
||||
file_path: normalized.clone(),
|
||||
title,
|
||||
summary: body,
|
||||
generated_at: Utc::now().timestamp(),
|
||||
model_version: model_used.clone(),
|
||||
is_current: true,
|
||||
training_messages: Some(json),
|
||||
backend: kind.as_str().to_string(),
|
||||
fewshot_source_ids: None,
|
||||
content_hash: None,
|
||||
num_ctx: req.num_ctx,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
min_p: req.min_p,
|
||||
system_prompt: req.system_prompt.clone(),
|
||||
persona_id: req.persona_id.clone(),
|
||||
prompt_eval_count: None,
|
||||
eval_count: None,
|
||||
};
|
||||
let stored = {
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
dao.store_insight(&cx, new_row)
|
||||
.map_err(|e| anyhow!("failed to store bootstrap insight: {:?}", e))?
|
||||
};
|
||||
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::Done {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
truncated: false,
|
||||
prompt_tokens: last_prompt_eval_count,
|
||||
eval_tokens: last_eval_count,
|
||||
num_ctx: req.num_ctx,
|
||||
amended_insight_id: Some(stored.id),
|
||||
backend_used: kind.as_str().to_string(),
|
||||
model_used,
|
||||
cancelled: false,
|
||||
})
|
||||
.await;
|
||||
|
||||
entry.set_terminal_status(crate::ai::turn_registry::TurnStatus::Done);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Agentic loop variant that pushes events to a `TurnEntry` buffer.
|
||||
async fn run_streaming_agentic_loop_with_entry(
|
||||
&self,
|
||||
backend: &ResolvedBackend,
|
||||
messages: &mut Vec<ChatMessage>,
|
||||
tools: Vec<Tool>,
|
||||
image_base64: &Option<String>,
|
||||
normalized: &str,
|
||||
user_id: i32,
|
||||
active_persona: &str,
|
||||
max_iterations: usize,
|
||||
entry: &Arc<TurnEntry>,
|
||||
) -> Result<AgenticLoopOutcome> {
|
||||
let mut tool_calls_made = 0usize;
|
||||
let mut iterations_used = 0usize;
|
||||
let mut last_prompt_eval_count: Option<i32> = None;
|
||||
let mut last_eval_count: Option<i32> = None;
|
||||
let mut final_content = String::new();
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
// Cooperative cancellation: a DELETE flips status out of Running
|
||||
// (and aborts this task). Check at the iteration boundary so an
|
||||
// in-flight tool round finishes cleanly rather than mid-write.
|
||||
if !entry.is_running() {
|
||||
return Ok(AgenticLoopOutcome {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
cancelled: true,
|
||||
});
|
||||
}
|
||||
|
||||
iterations_used = iteration + 1;
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::IterationStart {
|
||||
n: iterations_used,
|
||||
max: max_iterations,
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut stream = backend
|
||||
.chat()
|
||||
.chat_with_tools_stream(messages.clone(), tools.clone())
|
||||
.await?;
|
||||
|
||||
let mut final_message: Option<ChatMessage> = None;
|
||||
while let Some(ev) = stream.next().await {
|
||||
let ev = ev?;
|
||||
match ev {
|
||||
LlmStreamEvent::TextDelta(delta) => {
|
||||
let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await;
|
||||
}
|
||||
LlmStreamEvent::Done {
|
||||
message,
|
||||
prompt_eval_count,
|
||||
eval_count,
|
||||
} => {
|
||||
last_prompt_eval_count = prompt_eval_count;
|
||||
last_eval_count = eval_count;
|
||||
final_message = Some(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut response =
|
||||
final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;
|
||||
|
||||
if let Some(ref mut tcs) = response.tool_calls {
|
||||
for tc in tcs.iter_mut() {
|
||||
if !tc.function.arguments.is_object() {
|
||||
tc.function.arguments = serde_json::Value::Object(Default::default());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(response.clone());
|
||||
|
||||
if let Some(ref tool_calls) = response.tool_calls
|
||||
&& !tool_calls.is_empty()
|
||||
{
|
||||
for tool_call in tool_calls {
|
||||
tool_calls_made += 1;
|
||||
let call_index = tool_calls_made - 1;
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::ToolCall {
|
||||
index: call_index,
|
||||
name: tool_call.function.name.clone(),
|
||||
arguments: tool_call.function.arguments.clone(),
|
||||
})
|
||||
.await;
|
||||
let cx = opentelemetry::Context::new();
|
||||
let result = self
|
||||
.generator
|
||||
.execute_tool(
|
||||
&tool_call.function.name,
|
||||
&tool_call.function.arguments,
|
||||
backend,
|
||||
image_base64,
|
||||
normalized,
|
||||
user_id,
|
||||
active_persona,
|
||||
&cx,
|
||||
)
|
||||
.await;
|
||||
let (result_preview, result_truncated) = truncate_tool_result(&result);
|
||||
let _ = entry
|
||||
.push_event(ChatStreamEvent::ToolResult {
|
||||
index: call_index,
|
||||
name: tool_call.function.name.clone(),
|
||||
result: result_preview,
|
||||
result_truncated,
|
||||
})
|
||||
.await;
|
||||
messages.push(ChatMessage::tool_result(result));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
final_content = response.content;
|
||||
break;
|
||||
}
|
||||
|
||||
// No-tools fallback
|
||||
if final_content.is_empty() {
|
||||
let synthetic_idx = messages.len();
|
||||
messages.push(ChatMessage::user(
|
||||
"Please write your final answer now without calling any more tools.",
|
||||
));
|
||||
let mut stream = backend
|
||||
.chat()
|
||||
.chat_with_tools_stream(messages.clone(), vec![])
|
||||
.await?;
|
||||
let mut final_message: Option<ChatMessage> = None;
|
||||
while let Some(ev) = stream.next().await {
|
||||
let ev = ev?;
|
||||
match ev {
|
||||
LlmStreamEvent::TextDelta(delta) => {
|
||||
let _ = entry.push_event(ChatStreamEvent::TextDelta(delta)).await;
|
||||
}
|
||||
LlmStreamEvent::Done {
|
||||
message,
|
||||
prompt_eval_count,
|
||||
eval_count,
|
||||
} => {
|
||||
last_prompt_eval_count = prompt_eval_count;
|
||||
last_eval_count = eval_count;
|
||||
final_message = Some(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let final_response =
|
||||
final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
|
||||
final_content = final_response.content.clone();
|
||||
messages.push(final_response);
|
||||
messages.remove(synthetic_idx);
|
||||
}
|
||||
|
||||
Ok(AgenticLoopOutcome {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
cancelled: false,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_streaming_turn(
|
||||
self: Arc<Self>,
|
||||
req: ChatTurnRequest,
|
||||
@@ -836,6 +1459,8 @@ impl InsightChatService {
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
// The mpsc (legacy) path has no cancellation channel.
|
||||
cancelled: _,
|
||||
} = outcome;
|
||||
|
||||
// Drop the per-turn iteration-budget note before persisting so it
|
||||
@@ -916,6 +1541,7 @@ impl InsightChatService {
|
||||
amended_insight_id,
|
||||
backend_used: kind.as_str().to_string(),
|
||||
model_used,
|
||||
cancelled: false,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1052,6 +1678,8 @@ impl InsightChatService {
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
// The mpsc (legacy) path has no cancellation channel.
|
||||
cancelled: _,
|
||||
} = outcome;
|
||||
|
||||
let (title, body) = crate::ai::insight_generator::parse_title_body(&final_content);
|
||||
@@ -1101,6 +1729,7 @@ impl InsightChatService {
|
||||
amended_insight_id: Some(stored.id),
|
||||
backend_used: kind.as_str().to_string(),
|
||||
model_used,
|
||||
cancelled: false,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1274,6 +1903,7 @@ impl InsightChatService {
|
||||
last_prompt_eval_count,
|
||||
last_eval_count,
|
||||
final_content,
|
||||
cancelled: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1402,6 +2032,10 @@ struct AgenticLoopOutcome {
|
||||
last_prompt_eval_count: Option<i32>,
|
||||
last_eval_count: Option<i32>,
|
||||
final_content: String,
|
||||
/// True when the loop exited early because the turn was cancelled
|
||||
/// (status flipped out of `Running`). Callers skip persistence and the
|
||||
/// terminal `Done` push — the cancel handler owns the terminal event.
|
||||
cancelled: bool,
|
||||
}
|
||||
|
||||
/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
|
||||
@@ -1456,6 +2090,10 @@ pub enum ChatStreamEvent {
|
||||
amended_insight_id: Option<i32>,
|
||||
backend_used: String,
|
||||
model_used: String,
|
||||
/// True only for the synthetic terminal event emitted by the cancel
|
||||
/// handler, so clients can distinguish a user-cancelled turn from a
|
||||
/// natural completion. Always false on the normal success path.
|
||||
cancelled: bool,
|
||||
},
|
||||
/// Terminal failure event. No further events follow.
|
||||
Error(String),
|
||||
|
||||
+4
-3
@@ -11,6 +11,7 @@ pub mod llm_client;
|
||||
pub mod ollama;
|
||||
pub mod openrouter;
|
||||
pub mod sms_client;
|
||||
pub mod turn_registry;
|
||||
|
||||
// strip_summary_boilerplate is used by binaries (test_daily_summary), not the library
|
||||
#[allow(unused_imports)]
|
||||
@@ -19,11 +20,11 @@ pub use daily_summary_job::{
|
||||
generate_daily_summaries, strip_summary_boilerplate,
|
||||
};
|
||||
pub use handlers::{
|
||||
cancel_generation_handler, chat_history_handler, chat_rewind_handler, chat_stream_handler,
|
||||
chat_turn_handler, delete_insight_handler, export_training_data_handler,
|
||||
cancel_generation_handler, cancel_turn_handler, chat_history_handler, chat_rewind_handler,
|
||||
chat_stream_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler,
|
||||
generate_agentic_insight_handler, generate_insight_handler, generation_status_handler,
|
||||
get_all_insights_handler, get_available_models_handler, get_insight_handler,
|
||||
get_openrouter_models_handler, rate_insight_handler,
|
||||
get_openrouter_models_handler, rate_insight_handler, turn_async_handler, turn_replay_handler,
|
||||
};
|
||||
pub use insight_generator::InsightGenerator;
|
||||
pub use llamacpp::LlamaCppClient;
|
||||
|
||||
@@ -0,0 +1,748 @@
|
||||
use crate::ai::insight_chat::ChatStreamEvent;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
use tokio::task::AbortHandle;
|
||||
|
||||
/// Upper bound on the number of events retained per turn.
///
/// An agentic turn typically emits around 120 events, so 500 leaves roughly
/// 4× headroom. Once the bound is reached, the oldest buffered event is
/// dropped for every new event pushed.
const MAX_BUFFERED_EVENTS: usize = 500;
|
||||
|
||||
/// Turn status codes used by `TurnEntry::status`.
///
/// The discriminants are the raw values stored in the entry's atomic status
/// word; convert back with `TurnStatus::from(u32)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TurnStatus {
    /// The agentic task is still producing events.
    Running = 0,
    /// The turn completed normally.
    Done = 1,
    /// The turn failed.
    Error = 2,
    /// The turn was cancelled.
    Cancelled = 3,
}

impl From<u32> for TurnStatus {
    /// Decodes a raw status word. Values that do not correspond to a known
    /// status decode as `Running`.
    fn from(v: u32) -> Self {
        match v {
            1 => TurnStatus::Done,
            2 => TurnStatus::Error,
            3 => TurnStatus::Cancelled,
            // 0 and every unknown value.
            _ => TurnStatus::Running,
        }
    }
}

impl TurnStatus {
    /// Lower-case name of the status: "running", "done", "error" or
    /// "cancelled".
    pub fn as_str(&self) -> &'static str {
        match self {
            TurnStatus::Running => "running",
            TurnStatus::Done => "done",
            TurnStatus::Error => "error",
            TurnStatus::Cancelled => "cancelled",
        }
    }
}
|
||||
|
||||
/// Shared metadata about a turn, read by the SSE replay handler to emit
|
||||
/// the initial `turn_info` event and to decide whether to wait for new
|
||||
/// events or close immediately.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TurnInfo {
|
||||
pub turn_id: String,
|
||||
pub file_path: String,
|
||||
pub library_id: i32,
|
||||
pub status: TurnStatus,
|
||||
pub total_events_pushed: u32,
|
||||
pub buffered_count: u32,
|
||||
}
|
||||
|
||||
/// Result of reading events at or after an absolute `skip_before` index.
|
||||
#[derive(Debug)]
|
||||
pub enum ReplayOutcome {
|
||||
/// New events are available. `next_skip` is the absolute index to pass
|
||||
/// on the next read (i.e. one past the last event returned).
|
||||
Events {
|
||||
events: Vec<ChatStreamEvent>,
|
||||
next_skip: u32,
|
||||
},
|
||||
/// The reader is caught up to the live edge — no events past `skip_before`
|
||||
/// yet. `next_skip` is the current high-water mark.
|
||||
CaughtUp { next_skip: u32 },
|
||||
/// `skip_before` points below the buffer's base index: the requested
|
||||
/// events were evicted. Maps to HTTP 410 Gone.
|
||||
Gone,
|
||||
}
|
||||
|
||||
/// Per-turn state shared between the agentic loop (writer) and all SSE
|
||||
/// replay connections (readers).
|
||||
pub struct TurnEntry {
|
||||
pub turn_id: String,
|
||||
pub file_path: String,
|
||||
pub library_id: i32,
|
||||
/// Shared event buffer — multiple SSE connections can read independently.
|
||||
/// Each connection tracks its own `skip_before` offset.
|
||||
events: Mutex<Vec<ChatStreamEvent>>,
|
||||
/// Monotonic counter: total events pushed (may exceed events.len()
|
||||
/// due to eviction). Used for skip_before indexing.
|
||||
total_events_pushed: AtomicU32,
|
||||
/// The event index that this entry started with. Adjusts on eviction
|
||||
/// so that `skip_before` stays absolute across connections.
|
||||
base_index: AtomicU32,
|
||||
pub status: AtomicU32,
|
||||
/// Abort handle for the spawned agentic task, set once after spawn.
|
||||
/// Behind a std `Mutex` because the entry is shared via `Arc` and the
|
||||
/// handle is installed after the entry is already in the registry.
|
||||
abort_handle: StdMutex<Option<AbortHandle>>,
|
||||
pub created_at: Instant,
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl TurnEntry {
|
||||
pub fn new(turn_id: String, file_path: String, library_id: i32) -> Self {
|
||||
Self {
|
||||
turn_id,
|
||||
file_path,
|
||||
library_id,
|
||||
events: Mutex::new(Vec::new()),
|
||||
total_events_pushed: AtomicU32::new(0),
|
||||
base_index: AtomicU32::new(0),
|
||||
status: AtomicU32::new(TurnStatus::Running as u32),
|
||||
abort_handle: StdMutex::new(None),
|
||||
created_at: Instant::now(),
|
||||
notify: Arc::new(Notify::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Install the abort handle for the spawned agentic task. Called once,
|
||||
/// right after the task is spawned.
|
||||
pub fn set_abort_handle(&self, handle: AbortHandle) {
|
||||
*self.abort_handle.lock().expect("abort_handle poisoned") = Some(handle);
|
||||
}
|
||||
|
||||
/// Abort the spawned agentic task, if a handle was installed. Returns
|
||||
/// `true` if a task was aborted.
|
||||
pub fn abort(&self) -> bool {
|
||||
if let Some(handle) = self
|
||||
.abort_handle
|
||||
.lock()
|
||||
.expect("abort_handle poisoned")
|
||||
.take()
|
||||
{
|
||||
handle.abort();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Push an event into the buffer. Evicts oldest events if the buffer
|
||||
/// exceeds `MAX_BUFFERED_EVENTS`. Notifies all waiting SSE connections.
|
||||
pub async fn push_event(&self, event: ChatStreamEvent) {
|
||||
{
|
||||
let mut events = self.events.lock().await;
|
||||
|
||||
// Evict oldest events if we've hit the cap.
|
||||
if events.len() >= MAX_BUFFERED_EVENTS {
|
||||
// Drop the oldest event to make room and advance the base
|
||||
// index so skip_before stays absolute across connections.
|
||||
events.remove(0);
|
||||
self.base_index.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
events.push(event);
|
||||
// Increment while holding the buffer lock so the counter stays in
|
||||
// lock-step with the buffer even if multiple writers ever exist.
|
||||
self.total_events_pushed.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
|
||||
/// Get a snapshot of turn metadata for the `turn_info` SSE event.
|
||||
pub async fn info(&self) -> TurnInfo {
|
||||
let events = self.events.lock().await;
|
||||
let buffered = events.len() as u32;
|
||||
let total = self.total_events_pushed.load(Ordering::Relaxed);
|
||||
drop(events);
|
||||
|
||||
TurnInfo {
|
||||
turn_id: self.turn_id.clone(),
|
||||
file_path: self.file_path.clone(),
|
||||
library_id: self.library_id,
|
||||
status: self.status.load(Ordering::Relaxed).into(),
|
||||
total_events_pushed: total,
|
||||
buffered_count: buffered,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the terminal status and notify all waiters.
|
||||
pub fn set_terminal_status(&self, status: TurnStatus) {
|
||||
self.status.store(status as u32, Ordering::Relaxed);
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
|
||||
/// Read buffered events at or after absolute index `skip_before` without
|
||||
/// waiting. Distinguishes "evicted" (Gone) from "caught up" (no new
|
||||
/// events yet) — the previous boolean/`Option` API conflated the two.
|
||||
pub async fn replay_from(&self, skip_before: u32) -> ReplayOutcome {
|
||||
let events = self.events.lock().await;
|
||||
let base = self.base_index.load(Ordering::Relaxed);
|
||||
|
||||
// The buffer holds absolute indices [base, base + len). A request
|
||||
// below `base` asked for events that have been evicted.
|
||||
if skip_before < base {
|
||||
return ReplayOutcome::Gone;
|
||||
}
|
||||
|
||||
let offset = (skip_before - base) as usize;
|
||||
let next_skip = base + events.len() as u32;
|
||||
if offset >= events.len() {
|
||||
// Caught up to (or past) the live edge — nothing new yet.
|
||||
return ReplayOutcome::CaughtUp { next_skip };
|
||||
}
|
||||
|
||||
ReplayOutcome::Events {
|
||||
events: events[offset..].to_vec(),
|
||||
next_skip,
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for the next batch of events past `skip_before`, the turn to
|
||||
/// finish, or eviction. Returns:
|
||||
/// - `Events` when new events are available (drained before any terminal
|
||||
/// signal so the final `Done`/`Error` is never dropped),
|
||||
/// - `CaughtUp` only when the turn has reached a terminal status and the
|
||||
/// reader is fully drained (the caller should close the stream),
|
||||
/// - `Gone` when `skip_before` points into evicted territory.
|
||||
pub async fn next_batch(&self, skip_before: u32) -> ReplayOutcome {
|
||||
loop {
|
||||
// Register interest BEFORE inspecting state so a push/terminal that
|
||||
// races between our read and our await can't be lost (Notify's
|
||||
// `notify_waiters` does not store a permit).
|
||||
let notified = self.notify.notified();
|
||||
tokio::pin!(notified);
|
||||
notified.as_mut().enable();
|
||||
|
||||
match self.replay_from(skip_before).await {
|
||||
ReplayOutcome::CaughtUp { next_skip } => {
|
||||
// No new events. If the turn is finished, every event
|
||||
// (including the terminal one) has already been drained
|
||||
// above on a prior call, so signal the caller to close.
|
||||
if !self.is_running() {
|
||||
return ReplayOutcome::CaughtUp { next_skip };
|
||||
}
|
||||
// Still running — wait for the next push or terminal.
|
||||
}
|
||||
other => return other, // Events or Gone
|
||||
}
|
||||
|
||||
notified.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this turn is still running.
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.status.load(Ordering::Relaxed) == TurnStatus::Running as u32
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory registry of all active chat turns. Injected into `AppState`
|
||||
/// and shared across all handlers.
|
||||
pub struct TurnRegistry {
|
||||
entries: Mutex<HashMap<String, Arc<TurnEntry>>>,
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl TurnRegistry {
|
||||
pub fn new(timeout_secs: u64) -> Self {
|
||||
Self {
|
||||
entries: Mutex::new(HashMap::new()),
|
||||
timeout_secs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the cleanup timeout in seconds.
|
||||
pub fn timeout_secs(&self) -> u64 {
|
||||
self.timeout_secs
|
||||
}
|
||||
|
||||
/// Insert a new turn entry. Returns the turn_id.
|
||||
pub async fn insert(&self, entry: Arc<TurnEntry>) -> String {
|
||||
let turn_id = entry.turn_id.clone();
|
||||
let mut entries = self.entries.lock().await;
|
||||
entries.insert(turn_id.clone(), entry);
|
||||
turn_id
|
||||
}
|
||||
|
||||
/// Look up a turn by id. Returns None if not found or expired.
|
||||
pub async fn get(&self, turn_id: &str) -> Option<Arc<TurnEntry>> {
|
||||
let entries = self.entries.lock().await;
|
||||
entries.get(turn_id).cloned()
|
||||
}
|
||||
|
||||
/// Clean up stale entries older than the timeout. Returns the count of
|
||||
/// entries removed.
|
||||
pub async fn cleanup_stale(&self) -> usize {
|
||||
let mut entries = self.entries.lock().await;
|
||||
let _now = Instant::now();
|
||||
let stale: Vec<String> = entries
|
||||
.iter()
|
||||
.filter(|(_, entry)| entry.created_at.elapsed().as_secs() > self.timeout_secs)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
|
||||
for id in &stale {
|
||||
entries.remove(id);
|
||||
}
|
||||
|
||||
if !stale.is_empty() {
|
||||
log::info!(
|
||||
"TurnRegistry: cleaned up {} stale entries (timeout={}s)",
|
||||
stale.len(),
|
||||
self.timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
stale.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ai::insight_chat::ChatStreamEvent;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Unwrap the events from a `ReplayOutcome::Events`, panicking otherwise.
|
||||
fn events_of(outcome: ReplayOutcome) -> Vec<ChatStreamEvent> {
|
||||
match outcome {
|
||||
ReplayOutcome::Events { events, .. } => events,
|
||||
other => panic!("expected Events, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── TurnStatus ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn turn_status_from_u32_valid_values() {
|
||||
assert_eq!(TurnStatus::from(0), TurnStatus::Running);
|
||||
assert_eq!(TurnStatus::from(1), TurnStatus::Done);
|
||||
assert_eq!(TurnStatus::from(2), TurnStatus::Error);
|
||||
assert_eq!(TurnStatus::from(3), TurnStatus::Cancelled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn turn_status_from_u32_unknown_defaults_to_running() {
|
||||
assert_eq!(TurnStatus::from(4), TurnStatus::Running);
|
||||
assert_eq!(TurnStatus::from(u32::MAX), TurnStatus::Running);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn turn_status_as_str() {
|
||||
assert_eq!(TurnStatus::Running.as_str(), "running");
|
||||
assert_eq!(TurnStatus::Done.as_str(), "done");
|
||||
assert_eq!(TurnStatus::Error.as_str(), "error");
|
||||
assert_eq!(TurnStatus::Cancelled.as_str(), "cancelled");
|
||||
}
|
||||
|
||||
// ── TurnEntry ───────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_push_and_replay() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta("hello".to_string()))
|
||||
.await;
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(" world".to_string()))
|
||||
.await;
|
||||
|
||||
let events = events_of(entry.replay_from(0).await);
|
||||
assert_eq!(events.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_replay_with_skip() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
|
||||
for i in 0..5 {
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(format!("e{i}")))
|
||||
.await;
|
||||
}
|
||||
|
||||
// skip_before=0 → all 5 events
|
||||
let all = events_of(entry.replay_from(0).await);
|
||||
assert_eq!(all.len(), 5);
|
||||
|
||||
// skip_before=2 → events 2,3,4 (3 events)
|
||||
let skipped = events_of(entry.replay_from(2).await);
|
||||
assert_eq!(skipped.len(), 3);
|
||||
|
||||
// skip_before=5 → caught up to the live edge (not Gone).
|
||||
assert!(matches!(
|
||||
entry.replay_from(5).await,
|
||||
ReplayOutcome::CaughtUp { next_skip: 5 }
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_replay_empty_by_default() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
// Empty buffer with skip_before=0 → caught up (nothing to replay yet).
|
||||
assert!(matches!(
|
||||
entry.replay_from(0).await,
|
||||
ReplayOutcome::CaughtUp { next_skip: 0 }
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_is_running_initially() {
|
||||
let entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1);
|
||||
assert!(entry.is_running());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_set_terminal_status() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
assert!(entry.is_running());
|
||||
entry.set_terminal_status(TurnStatus::Done);
|
||||
assert!(!entry.is_running());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_info() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
42,
|
||||
));
|
||||
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta("x".to_string()))
|
||||
.await;
|
||||
entry.set_terminal_status(TurnStatus::Done);
|
||||
|
||||
let info = entry.info().await;
|
||||
assert_eq!(info.turn_id, "t1");
|
||||
assert_eq!(info.file_path, "/photo.jpg");
|
||||
assert_eq!(info.library_id, 42);
|
||||
assert_eq!(info.status, TurnStatus::Done);
|
||||
assert_eq!(info.total_events_pushed, 1);
|
||||
assert_eq!(info.buffered_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_eviction_caps_buffer() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
|
||||
// Push MAX_BUFFERED_EVENTS + 10 events.
|
||||
for i in 0..(MAX_BUFFERED_EVENTS + 10) {
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(format!("e{i}")))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Asking from absolute 0 after eviction is Gone (0-9 were dropped).
|
||||
assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone));
|
||||
|
||||
// Reading from the new base (10) returns the full capped buffer.
|
||||
let events = events_of(entry.replay_from(10).await);
|
||||
assert_eq!(events.len(), MAX_BUFFERED_EVENTS);
|
||||
|
||||
// First event should be at index 10 (0-9 were evicted).
|
||||
if let ChatStreamEvent::TextDelta(s) = &events[0] {
|
||||
assert_eq!(s, "e10");
|
||||
} else {
|
||||
panic!("expected TextDelta");
|
||||
}
|
||||
|
||||
// Last event should be at index MAX_BUFFERED_EVENTS + 9.
|
||||
if let ChatStreamEvent::TextDelta(s) = &events[events.len() - 1] {
|
||||
assert_eq!(s, &format!("e{}", MAX_BUFFERED_EVENTS + 9));
|
||||
} else {
|
||||
panic!("expected TextDelta");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_entry_replay_evicted_index_is_gone() {
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
|
||||
// Push one past the cap so exactly one event (index 0) is evicted.
|
||||
for i in 0..=MAX_BUFFERED_EVENTS {
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(format!("e{i}")))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Base is now 1; asking from absolute 0 is evicted territory → Gone.
|
||||
assert!(matches!(entry.replay_from(0).await, ReplayOutcome::Gone));
|
||||
|
||||
// skip_before = MAX_BUFFERED_EVENTS → last event only (index valid).
|
||||
let last = events_of(entry.replay_from(MAX_BUFFERED_EVENTS as u32).await);
|
||||
assert_eq!(last.len(), 1);
|
||||
|
||||
// skip_before = MAX_BUFFERED_EVENTS + 1 → caught up to the live edge.
|
||||
assert!(matches!(
|
||||
entry.replay_from((MAX_BUFFERED_EVENTS + 1) as u32).await,
|
||||
ReplayOutcome::CaughtUp { .. }
|
||||
));
|
||||
}
|
||||
|
||||
// ── TurnRegistry ────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_insert_and_get() {
|
||||
let registry = TurnRegistry::new(300);
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
let id = registry.insert(entry).await;
|
||||
assert_eq!(id, "t1");
|
||||
|
||||
let retrieved = registry.get("t1").await;
|
||||
assert!(retrieved.is_some());
|
||||
assert_eq!(retrieved.unwrap().turn_id, "t1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_get_nonexistent_returns_none() {
|
||||
let registry = TurnRegistry::new(300);
|
||||
assert!(registry.get("nonexistent").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_cleanup_stale_removes_old_entries() {
|
||||
let registry = TurnRegistry::new(0);
|
||||
let mut entry = TurnEntry::new("t1".to_string(), "/photo.jpg".to_string(), 1);
|
||||
entry.created_at = Instant::now() - Duration::from_secs(1);
|
||||
registry.insert(Arc::new(entry)).await;
|
||||
|
||||
let cleaned = registry.cleanup_stale().await;
|
||||
assert_eq!(cleaned, 1);
|
||||
assert!(registry.get("t1").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_cleanup_stale_preserves_recent() {
|
||||
let registry = TurnRegistry::new(3600); // 1 hour
|
||||
let entry = Arc::new(TurnEntry::new(
|
||||
"t1".to_string(),
|
||||
"/photo.jpg".to_string(),
|
||||
1,
|
||||
));
|
||||
registry.insert(entry).await;
|
||||
|
||||
let cleaned = registry.cleanup_stale().await;
|
||||
assert_eq!(cleaned, 0);
|
||||
assert!(registry.get("t1").await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_cleanup_stale_multiple() {
|
||||
let registry = TurnRegistry::new(0);
|
||||
|
||||
for i in 0..5 {
|
||||
let mut entry = TurnEntry::new(format!("t{i}"), "/photo.jpg".to_string(), 1);
|
||||
entry.created_at = Instant::now() - Duration::from_secs(1);
|
||||
registry.insert(Arc::new(entry)).await;
|
||||
}
|
||||
|
||||
let cleaned = registry.cleanup_stale().await;
|
||||
assert_eq!(cleaned, 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_registry_timeout_secs() {
|
||||
let registry = TurnRegistry::new(600);
|
||||
assert_eq!(registry.timeout_secs(), 600);
|
||||
}
|
||||
|
||||
// ── next_batch / live replay ────────────────────────────────────
|
||||
|
||||
/// Drain a turn the way the SSE replay handler does: pull batches via
|
||||
/// `next_batch` until the turn is finished and fully drained.
|
||||
async fn drain_to_end(entry: Arc<TurnEntry>) -> Vec<ChatStreamEvent> {
|
||||
let mut out = Vec::new();
|
||||
let mut skip = 0u32;
|
||||
while let ReplayOutcome::Events { events, next_skip } = entry.next_batch(skip).await {
|
||||
out.extend(events);
|
||||
skip = next_skip;
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn is_terminal(ev: &ChatStreamEvent) -> bool {
|
||||
matches!(ev, ChatStreamEvent::Done { .. } | ChatStreamEvent::Error(_))
|
||||
}
|
||||
|
||||
/// The core guarantee behind the replay rewrite: a reader waiting on
|
||||
/// `next_batch` always receives the terminal event, even though the
|
||||
/// writer flips status to terminal immediately after pushing it.
|
||||
#[tokio::test]
|
||||
async fn next_batch_always_delivers_terminal_event() {
|
||||
for _ in 0..50 {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
|
||||
let writer = entry.clone();
|
||||
let w = tokio::spawn(async move {
|
||||
writer
|
||||
.push_event(ChatStreamEvent::IterationStart { n: 1, max: 6 })
|
||||
.await;
|
||||
writer
|
||||
.push_event(ChatStreamEvent::TextDelta("hi".into()))
|
||||
.await;
|
||||
// Push terminal then flip status with no await between — the
|
||||
// race that previously dropped the Done on the reader side.
|
||||
writer
|
||||
.push_event(ChatStreamEvent::Done {
|
||||
tool_calls_made: 0,
|
||||
iterations_used: 1,
|
||||
truncated: false,
|
||||
prompt_tokens: None,
|
||||
eval_tokens: None,
|
||||
num_ctx: None,
|
||||
amended_insight_id: None,
|
||||
backend_used: "local".into(),
|
||||
model_used: "m".into(),
|
||||
cancelled: false,
|
||||
})
|
||||
.await;
|
||||
writer.set_terminal_status(TurnStatus::Done);
|
||||
});
|
||||
|
||||
let events = drain_to_end(entry).await;
|
||||
w.await.unwrap();
|
||||
|
||||
assert!(
|
||||
events.last().is_some_and(is_terminal),
|
||||
"terminal event missing; got {} events",
|
||||
events.len()
|
||||
);
|
||||
assert_eq!(events.len(), 3, "expected IterationStart, TextDelta, Done");
|
||||
}
|
||||
}
|
||||
|
||||
/// A reader that connects before any event is pushed blocks in
|
||||
/// `next_batch` and then receives events as the writer produces them.
|
||||
#[tokio::test]
|
||||
async fn next_batch_waits_for_late_events() {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
|
||||
let writer = entry.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::task::yield_now().await;
|
||||
writer
|
||||
.push_event(ChatStreamEvent::TextDelta("late".into()))
|
||||
.await;
|
||||
writer.set_terminal_status(TurnStatus::Done);
|
||||
});
|
||||
|
||||
// First call blocks until the writer pushes, rather than returning
|
||||
// CaughtUp on the empty buffer of a running turn.
|
||||
match entry.next_batch(0).await {
|
||||
ReplayOutcome::Events { events, next_skip } => {
|
||||
assert_eq!(events.len(), 1);
|
||||
assert_eq!(next_skip, 1);
|
||||
}
|
||||
other => panic!("expected Events, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn next_batch_closes_on_terminal_when_caught_up() {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta("x".into()))
|
||||
.await;
|
||||
entry.set_terminal_status(TurnStatus::Done);
|
||||
|
||||
// Caught up (skip past the one buffered event) on a finished turn →
|
||||
// CaughtUp so the handler closes the stream rather than hanging.
|
||||
assert!(matches!(
|
||||
entry.next_batch(1).await,
|
||||
ReplayOutcome::CaughtUp { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn next_batch_reports_gone_for_evicted_index() {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
for i in 0..=MAX_BUFFERED_EVENTS {
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(format!("e{i}")))
|
||||
.await;
|
||||
}
|
||||
// Index 0 was evicted (base advanced to 1).
|
||||
assert!(matches!(entry.next_batch(0).await, ReplayOutcome::Gone));
|
||||
}
|
||||
|
||||
// ── abort handle (#1 cancellation) ──────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_handle_aborts_task_once() {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
|
||||
// No handle installed yet → abort is a no-op.
|
||||
assert!(!entry.abort());
|
||||
|
||||
let handle = tokio::spawn(async {
|
||||
// Long-lived task that only ends via abort.
|
||||
futures::future::pending::<()>().await;
|
||||
});
|
||||
entry.set_abort_handle(handle.abort_handle());
|
||||
|
||||
assert!(entry.abort(), "first abort should fire");
|
||||
assert!(!entry.abort(), "handle is taken; second abort is a no-op");
|
||||
|
||||
// The aborted task resolves to a cancellation JoinError.
|
||||
let join = handle.await;
|
||||
assert!(join.unwrap_err().is_cancelled());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn base_index_tracks_eviction() {
|
||||
let entry = Arc::new(TurnEntry::new("t".into(), "/p.jpg".into(), 1));
|
||||
for i in 0..(MAX_BUFFERED_EVENTS + 5) {
|
||||
entry
|
||||
.push_event(ChatStreamEvent::TextDelta(format!("e{i}")))
|
||||
.await;
|
||||
}
|
||||
let info = entry.info().await;
|
||||
// 5 events evicted; total keeps climbing, buffer stays capped.
|
||||
assert_eq!(info.total_events_pushed, (MAX_BUFFERED_EVENTS + 5) as u32);
|
||||
assert_eq!(info.buffered_count, MAX_BUFFERED_EVENTS as u32);
|
||||
// First live index is 5: reading from there yields the full buffer.
|
||||
let from_base = events_of(entry.replay_from(5).await);
|
||||
assert_eq!(from_base.len(), MAX_BUFFERED_EVENTS);
|
||||
}
|
||||
}
|
||||
+25
@@ -197,6 +197,28 @@ fn main() -> std::io::Result<()> {
|
||||
app_state.library_health.clone(),
|
||||
);
|
||||
|
||||
// Periodically clean up stale turn entries from the in-memory
|
||||
// registry. Sweeps once per configured timeout, capped at every
|
||||
// 5 minutes, and drops entries older than that timeout.
|
||||
{
|
||||
let registry = app_state.turn_registry.clone();
|
||||
let timeout_secs = registry.timeout_secs();
|
||||
tokio::spawn(async move {
|
||||
// Sweep at least once every 5 minutes, and never less often than the
|
||||
// timeout itself — otherwise entries could linger up to ~2× the
|
||||
// configured timeout before being reclaimed. The 1-second floor
|
||||
// keeps a zero timeout from producing a zero-length sleep.
|
||||
let interval_secs = timeout_secs.clamp(1, 300);
|
||||
let interval = tokio::time::Duration::from_secs(interval_secs);
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let cleaned = registry.cleanup_stale().await;
|
||||
if cleaned > 0 {
|
||||
log::info!("TurnRegistry: cleaned up {cleaned} stale entries");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn background job to generate daily conversation summaries
|
||||
{
|
||||
use crate::ai::generate_daily_summaries;
|
||||
@@ -335,6 +357,9 @@ fn main() -> std::io::Result<()> {
|
||||
.service(ai::chat_stream_handler)
|
||||
.service(ai::chat_history_handler)
|
||||
.service(ai::chat_rewind_handler)
|
||||
.service(ai::turn_async_handler)
|
||||
.service(ai::turn_replay_handler)
|
||||
.service(ai::cancel_turn_handler)
|
||||
.service(ai::rate_insight_handler)
|
||||
.service(ai::export_training_data_handler)
|
||||
.service(libraries::list_libraries)
|
||||
|
||||
+17
-10
@@ -4,6 +4,7 @@ use crate::ai::face_client::FaceClient;
|
||||
use crate::ai::insight_chat::{ChatLockMap, InsightChatService};
|
||||
use crate::ai::llamacpp::LlamaCppClient;
|
||||
use crate::ai::openrouter::OpenRouterClient;
|
||||
use crate::ai::turn_registry::TurnRegistry;
|
||||
use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient};
|
||||
use crate::database::{
|
||||
CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, InsightGenerationJobDao, KnowledgeDao,
|
||||
@@ -78,19 +79,10 @@ pub struct AppState {
|
||||
pub insight_generator: InsightGenerator,
|
||||
/// Chat continuation service. Hold an Arc so handlers can clone cheaply.
|
||||
pub insight_chat: Arc<InsightChatService>,
|
||||
/// Face inference client (calls Apollo's `/api/internal/faces/*`).
|
||||
/// Disabled (`is_enabled() == false`) when neither `APOLLO_FACE_API_BASE_URL`
|
||||
/// nor `APOLLO_API_BASE_URL` is set; the file-watch hook (Phase 3) and
|
||||
/// manual-face-create handler short-circuit in that case.
|
||||
pub turn_registry: Arc<TurnRegistry>,
|
||||
pub face_client: FaceClient,
|
||||
/// CLIP inference client (calls Apollo's `/api/internal/clip/*`).
|
||||
/// Same disabled semantics as `face_client`: unset env → no-op
|
||||
/// backlog drain, /photos/search returns an empty result.
|
||||
pub clip_client: ClipClient,
|
||||
/// Tracks async insight generation jobs (spawned by generate endpoints).
|
||||
pub insight_job_dao: Arc<Mutex<Box<dyn InsightGenerationJobDao>>>,
|
||||
/// In-memory map from job_id → tokio AbortHandle for running tasks.
|
||||
/// Used to abort server-side tasks on cancel or regenerate.
|
||||
pub insight_job_handles: Arc<Mutex<HashMap<i32, tokio::task::AbortHandle>>>,
|
||||
}
|
||||
|
||||
@@ -127,6 +119,7 @@ impl AppState {
|
||||
sms_client: SmsApiClient,
|
||||
insight_generator: InsightGenerator,
|
||||
insight_chat: Arc<InsightChatService>,
|
||||
turn_registry: Arc<TurnRegistry>,
|
||||
preview_dao: Arc<Mutex<Box<dyn PreviewDao>>>,
|
||||
face_client: FaceClient,
|
||||
clip_client: ClipClient,
|
||||
@@ -171,6 +164,7 @@ impl AppState {
|
||||
sms_client,
|
||||
insight_generator,
|
||||
insight_chat,
|
||||
turn_registry,
|
||||
face_client,
|
||||
clip_client,
|
||||
insight_job_dao,
|
||||
@@ -310,6 +304,14 @@ impl Default for AppState {
|
||||
chat_locks,
|
||||
));
|
||||
|
||||
// Turn registry for reconnectable chat turns. The stale-turn timeout
|
||||
// defaults to 5 minutes and can be overridden via
|
||||
// `INSIGHT_CHAT_TURN_TIMEOUT_SECS` (background cleaner drops entries
|
||||
// older than this).
|
||||
let timeout_secs: u64 = env::var("INSIGHT_CHAT_TURN_TIMEOUT_SECS")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(300);
|
||||
let turn_registry = Arc::new(TurnRegistry::new(timeout_secs));
|
||||
|
||||
// Ensure preview clips directory exists
|
||||
let preview_clips_path =
|
||||
env::var("PREVIEW_CLIPS_DIRECTORY").unwrap_or_else(|_| "preview_clips".to_string());
|
||||
@@ -332,6 +334,7 @@ impl Default for AppState {
|
||||
sms_client,
|
||||
insight_generator,
|
||||
insight_chat,
|
||||
turn_registry,
|
||||
preview_dao,
|
||||
face_client,
|
||||
clip_client,
|
||||
@@ -490,6 +493,9 @@ impl AppState {
|
||||
chat_locks,
|
||||
));
|
||||
|
||||
// Turn registry for test state.
|
||||
let turn_registry = Arc::new(TurnRegistry::new(300));
|
||||
|
||||
// Initialize test preview DAO
|
||||
let preview_dao: Arc<Mutex<Box<dyn PreviewDao>>> =
|
||||
Arc::new(Mutex::new(Box::new(SqlitePreviewDao::new())));
|
||||
@@ -518,6 +524,7 @@ impl AppState {
|
||||
sms_client,
|
||||
insight_generator,
|
||||
insight_chat,
|
||||
turn_registry,
|
||||
preview_dao,
|
||||
FaceClient::new(None), // disabled in test
|
||||
ClipClient::new(None), // disabled in test
|
||||
|
||||
Reference in New Issue
Block a user