diff --git a/Cargo.lock b/Cargo.lock
index 2e210ea..a404abc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -486,6 +486,28 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
 
+[[package]]
+name = "async-stream"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
+dependencies = [
+ "async-stream-impl",
+ "futures-core",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "async-stream-impl"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.89"
@@ -1843,10 +1865,12 @@ dependencies = [
  "actix-web",
  "actix-web-prom",
  "anyhow",
+ "async-stream",
  "async-trait",
  "base64",
  "bcrypt",
  "blake3",
+ "bytes",
  "chrono",
  "clap",
  "diesel",
@@ -1878,6 +1902,7 @@ dependencies = [
  "serde_json",
  "tempfile",
  "tokio",
+ "tokio-util",
  "urlencoding",
  "walkdir",
  "zerocopy",
@@ -3125,12 +3150,14 @@ dependencies = [
  "sync_wrapper",
  "tokio",
  "tokio-native-tls",
+ "tokio-util",
  "tower",
  "tower-http",
  "tower-service",
  "url",
  "wasm-bindgen",
  "wasm-bindgen-futures",
+ "wasm-streams",
  "web-sys",
 ]
 
@@ -4219,6 +4246,19 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "wasm-streams"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
+dependencies = [
+ "futures-util",
+ "js-sys",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
 [[package]]
 name = "web-sys"
 version = "0.3.77"
diff --git a/Cargo.toml b/Cargo.toml
index be60128..2b966b7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -49,7 +49,10 @@ opentelemetry-appender-log = "0.31.0"
 tempfile = "3.20.0"
 regex = "1.11.1"
 exif = { package = "kamadak-exif", version = "0.6.1" }
-reqwest = { version = "0.12", features = ["json"] }
+reqwest = { version = "0.12", features = ["json", "stream"] }
+async-stream = "0.3"
+tokio-util = { version = "0.7", features = ["io"] }
+bytes = "1"
 urlencoding = "2.1"
 zerocopy = "0.8"
 ical = "0.11"
diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs
index 6bcb592..d24927f 100644
--- a/src/ai/handlers.rs
+++ b/src/ai/handlers.rs
@@ -3,7 +3,7 @@ use opentelemetry::KeyValue;
 use opentelemetry::trace::{Span, Status, Tracer};
 use serde::{Deserialize, Serialize};
 
-use crate::ai::insight_chat::ChatTurnRequest;
+use crate::ai::insight_chat::{ChatStreamEvent, ChatTurnRequest};
 use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient};
 use crate::data::Claims;
 use crate::database::{ExifDao, InsightDao};
@@ -826,3 +826,122 @@ pub async fn chat_history_handler(
         }
     }
 }
+
+/// POST /insights/chat/stream — streaming variant of /insights/chat.
+/// Returns `text/event-stream`, one SSE frame per `ChatStreamEvent`.
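+///
+/// Frames follow the `render_sse_frame` shape below; for example, a text
+/// delta and a terminal event (payload values illustrative):
+///
+/// ```text
+/// event: text
+/// data: {"delta":"Golden hour at "}
+///
+/// event: done
+/// data: {"tool_calls_made":1,"iterations_used":2,"truncated":false,...}
+/// ```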
+#[post("/insights/chat/stream")]
+pub async fn chat_stream_handler(
+    _claims: Claims,
+    // Param types assumed (not visible in this diff): the JSON body mirrors
+    // the non-streaming /insights/chat handler.
+    request: web::Json<ChatTurnRequestBody>,
+    app_state: web::Data<AppState>,
+) -> HttpResponse {
+    let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
+        Ok(Some(lib)) => lib,
+        Ok(None) => app_state.primary_library(),
+        Err(e) => {
+            return HttpResponse::BadRequest().json(serde_json::json!({
+                "error": format!("invalid library: {}", e)
+            }));
+        }
+    };
+
+    let chat_req = ChatTurnRequest {
+        library_id: library.id,
+        file_path: request.file_path.clone(),
+        user_message: request.user_message.clone(),
+        model: request.model.clone(),
+        backend: request.backend.clone(),
+        num_ctx: request.num_ctx,
+        temperature: request.temperature,
+        top_p: request.top_p,
+        top_k: request.top_k,
+        min_p: request.min_p,
+        max_iterations: request.max_iterations,
+        amend: request.amend,
+    };
+
+    let service = app_state.insight_chat.clone();
+    let events = service.chat_turn_stream(chat_req);
+
+    // Map ChatStreamEvent → SSE frame bytes.
+    let sse_stream = futures::stream::StreamExt::map(events, |ev| {
+        let frame = render_sse_frame(&ev);
+        Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame))
+    });
+
+    HttpResponse::Ok()
+        .content_type("text/event-stream")
+        .insert_header(("Cache-Control", "no-cache"))
+        .insert_header(("X-Accel-Buffering", "no")) // nginx: disable response buffering
+        .streaming(sse_stream)
+}
+
+fn render_sse_frame(ev: &ChatStreamEvent) -> String {
+    let (event_name, payload) = match ev {
+        ChatStreamEvent::IterationStart { n, max } => {
+            ("iteration_start", serde_json::json!({ "n": n, "max": max }))
+        }
+        ChatStreamEvent::Truncated => ("truncated", serde_json::json!({})),
+        ChatStreamEvent::TextDelta(delta) => ("text", serde_json::json!({ "delta": delta })),
+        ChatStreamEvent::ToolCall {
+            index,
+            name,
+            arguments,
+        } => (
+            "tool_call",
+            serde_json::json!({ "index": index, "name": name, "arguments": arguments }),
+        ),
+        ChatStreamEvent::ToolResult {
+            index,
+            name,
+            result,
+            result_truncated,
+        } => (
+            "tool_result",
+            serde_json::json!({
+                "index": index,
+                "name": name,
+                "result": result,
+                "result_truncated": result_truncated,
+            }),
+        ),
+        ChatStreamEvent::Done {
+            tool_calls_made,
+            iterations_used,
+            truncated,
+            prompt_eval_count,
+            eval_count,
+            amended_insight_id,
+            backend_used,
+            model_used,
+        } => (
+            "done",
+            serde_json::json!({
+                "tool_calls_made": tool_calls_made,
+                "iterations_used": iterations_used,
+                "truncated": truncated,
+                "prompt_eval_count": prompt_eval_count,
+                "eval_count": eval_count,
+                "amended_insight_id": amended_insight_id,
+                "backend": backend_used,
+                "model": model_used,
+            }),
+        ),
+        ChatStreamEvent::Error(msg) => ("error", serde_json::json!({ "message": msg })),
+    };
+    let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
+    format!("event: {}\ndata: {}\n\n", event_name, data)
+}
diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs
index 2cc1da7..a4d21ea 100644
--- a/src/ai/insight_chat.rs
+++ b/src/ai/insight_chat.rs
@@ -7,13 +7,14 @@ use std::sync::{Arc, Mutex};
 use tokio::sync::Mutex as TokioMutex;
 
 use crate::ai::insight_generator::InsightGenerator;
-use crate::ai::llm_client::{ChatMessage, LlmClient};
+use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent};
 use crate::ai::ollama::OllamaClient;
 use crate::ai::openrouter::OpenRouterClient;
 use crate::database::InsightDao;
 use crate::database::models::InsertPhotoInsight;
 use crate::otel::global_tracer;
 use crate::utils::normalize_path;
+use futures::stream::{BoxStream, StreamExt};
 
 const DEFAULT_MAX_ITERATIONS: usize = 6;
 const DEFAULT_NUM_CTX: i32 = 8192;
@@ -583,6 +584,449 @@
             .map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
         Ok(())
     }
+
+    /// Streaming variant of `chat_turn`. Emits user-facing events as the
+    /// conversation progresses: iteration starts, tool dispatch + result,
+    /// text deltas from the final assistant reply, and a terminal `Done`
+    /// frame. Persistence happens inside the stream after the loop ends.
+    ///
+    /// The stream takes ownership of the service via `Arc<Self>` (passed by
+    /// the caller) so it can live past the handler's await boundary.
+    pub fn chat_turn_stream(
+        self: Arc<Self>,
+        req: ChatTurnRequest,
+    ) -> BoxStream<'static, ChatStreamEvent> {
+        let svc = self;
+        let s = async_stream::stream! {
+            match svc.chat_turn_stream_inner(req, |ev| Ok(ev)).await {
+                Ok(mut rx) => {
+                    while let Some(ev) = rx.recv().await {
+                        yield ev;
+                    }
+                }
+                Err(e) => {
+                    yield ChatStreamEvent::Error(format!("{}", e));
+                }
+            }
+        };
+        Box::pin(s)
+    }
+
+    /// Internal: drives the streaming loop on a background task, returning
+    /// a receiver the caller drains. Keeping the work on a spawned task
+    /// decouples the HTTP request lifetime from the chat execution, which
+    /// matters because the chat may run longer than any single network hop
+    /// and we want clean cancellation semantics via the channel close.
+    async fn chat_turn_stream_inner<F>(
+        self: Arc<Self>,
+        req: ChatTurnRequest,
+        _ev_mapper: F,
+    ) -> Result<tokio::sync::mpsc::Receiver<ChatStreamEvent>>
+    where
+        F: Fn(ChatStreamEvent) -> Result<ChatStreamEvent> + Send + 'static,
+    {
+        let (tx, rx) = tokio::sync::mpsc::channel::<ChatStreamEvent>(64);
+        let svc = self.clone();
+        tokio::spawn(async move {
+            let result = svc.run_streaming_turn(req, tx.clone()).await;
+            if let Err(e) = result {
+                let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await;
+            }
+        });
+        Ok(rx)
+    }
+
+    async fn run_streaming_turn(
+        self: Arc<Self>,
+        req: ChatTurnRequest,
+        tx: tokio::sync::mpsc::Sender<ChatStreamEvent>,
+    ) -> Result<()> {
+        if req.user_message.trim().is_empty() {
+            bail!("user_message must not be empty");
+        }
+        if req.user_message.len() > 8192 {
+            bail!("user_message exceeds 8192 chars");
+        }
+        let normalized = normalize_path(&req.file_path);
+
+        let lock_key = (req.library_id, normalized.clone());
+        let entry_lock = {
+            let mut locks = self.chat_locks.lock().await;
+            locks
+                .entry(lock_key.clone())
+                .or_insert_with(|| Arc::new(TokioMutex::new(())))
+                .clone()
+        };
+        let _guard = entry_lock.lock().await;
+
+        let insight = {
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            dao.get_insight(&cx, &normalized)
+                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
+                .ok_or_else(|| anyhow!("no insight found for path"))?
+        };
+        let raw_history = insight
+            .training_messages
+            .as_ref()
+            .ok_or_else(|| {
+                anyhow!("insight has no chat history; regenerate this insight in agentic mode")
+            })?
+            .clone();
+        let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
+            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
+
+        // Backend selection — same rules as non-streaming chat_turn.
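+        // In short: an explicit request override beats the stored backend;
+        // only "local" and "hybrid" are accepted; and a local → hybrid
+        // switch mid-conversation is rejected below.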
+        let stored_backend = insight.backend.clone();
+        let effective_backend = req
+            .backend
+            .as_deref()
+            .map(|s| s.trim().to_lowercase())
+            .filter(|s| !s.is_empty())
+            .unwrap_or_else(|| stored_backend.clone());
+        if !matches!(effective_backend.as_str(), "local" | "hybrid") {
+            bail!(
+                "unknown backend '{}'; expected 'local' or 'hybrid'",
+                effective_backend
+            );
+        }
+        if stored_backend == "local" && effective_backend == "hybrid" {
+            bail!(
+                "switching from local to hybrid mid-chat isn't supported yet; \
+                 regenerate the insight in hybrid mode if you want OpenRouter chat"
+            );
+        }
+        let is_hybrid = effective_backend == "hybrid";
+
+        let max_iterations = req
+            .max_iterations
+            .unwrap_or(DEFAULT_MAX_ITERATIONS)
+            .clamp(1, env_max_iterations());
+
+        let stored_model = insight.model_version.clone();
+        let custom_model = req
+            .model
+            .clone()
+            .or_else(|| Some(stored_model.clone()))
+            .filter(|m| !m.is_empty());
+
+        let mut ollama_client = self.ollama.clone();
+        let mut openrouter_client: Option<OpenRouterClient> = None;
+
+        if is_hybrid {
+            let arc = self.openrouter.as_ref().ok_or_else(|| {
+                anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
+            })?;
+            let mut c: OpenRouterClient = (**arc).clone();
+            if let Some(ref m) = custom_model {
+                c.primary_model = m.clone();
+            }
+            if req.temperature.is_some()
+                || req.top_p.is_some()
+                || req.top_k.is_some()
+                || req.min_p.is_some()
+            {
+                c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
+            }
+            if let Some(ctx) = req.num_ctx {
+                c.set_num_ctx(Some(ctx));
+            }
+            openrouter_client = Some(c);
+        } else {
+            if let Some(ref m) = custom_model
+                && m != &self.ollama.primary_model
+            {
+                ollama_client = OllamaClient::new(
+                    self.ollama.primary_url.clone(),
+                    self.ollama.fallback_url.clone(),
+                    m.clone(),
+                    Some(m.clone()),
+                );
+            }
+            if req.temperature.is_some()
+                || req.top_p.is_some()
+                || req.top_k.is_some()
+                || req.min_p.is_some()
+            {
+                ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
+            }
+            if let Some(ctx) = req.num_ctx {
+                ollama_client.set_num_ctx(Some(ctx));
+            }
+        }
+
+        let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
+            c
+        } else {
+            &ollama_client
+        };
+        let model_used = chat_backend.primary_model().to_string();
+
+        // Tool set.
+        let local_first_user_has_image = messages
+            .iter()
+            .find(|m| m.role == "user")
+            .and_then(|m| m.images.as_ref())
+            .map(|imgs| !imgs.is_empty())
+            .unwrap_or(false);
+        let offer_describe_tool = !is_hybrid && local_first_user_has_image;
+        let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);
+
+        let image_base64: Option<String> = if offer_describe_tool {
+            self.generator.load_image_as_base64(&normalized).ok()
+        } else {
+            None
+        };
+
+        // Truncate before appending the new user turn.
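+        // Worked example (headroom and bytes-per-token values here are
+        // illustrative, not the crate's actual constants): num_ctx = 8192,
+        // RESPONSE_HEADROOM_TOKENS = 1024, BYTES_PER_TOKEN = 4 →
+        // budget_bytes = (8192 - 1024) * 4 = 28_672.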
+        let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
+            .saturating_sub(RESPONSE_HEADROOM_TOKENS);
+        let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
+        let truncated = apply_context_budget(&mut messages, budget_bytes);
+        if truncated {
+            let _ = tx.send(ChatStreamEvent::Truncated).await;
+        }
+
+        messages.push(ChatMessage::user(req.user_message.clone()));
+
+        let mut tool_calls_made = 0usize;
+        let mut iterations_used = 0usize;
+        let mut last_prompt_eval_count: Option<i32> = None;
+        let mut last_eval_count: Option<i32> = None;
+        let mut final_content = String::new();
+
+        for iteration in 0..max_iterations {
+            iterations_used = iteration + 1;
+            let _ = tx
+                .send(ChatStreamEvent::IterationStart {
+                    n: iterations_used,
+                    max: max_iterations,
+                })
+                .await;
+
+            let mut stream = chat_backend
+                .chat_with_tools_stream(messages.clone(), tools.clone())
+                .await?;
+
+            let mut final_message: Option<ChatMessage> = None;
+            while let Some(ev) = stream.next().await {
+                let ev = ev?;
+                match ev {
+                    LlmStreamEvent::TextDelta(delta) => {
+                        let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
+                    }
+                    LlmStreamEvent::Done {
+                        message,
+                        prompt_eval_count,
+                        eval_count,
+                    } => {
+                        last_prompt_eval_count = prompt_eval_count;
+                        last_eval_count = eval_count;
+                        final_message = Some(message);
+                        break;
+                    }
+                }
+            }
+            let mut response =
+                final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;
+
+            // Normalize non-object tool arguments (same as non-streaming path).
+            if let Some(ref mut tcs) = response.tool_calls {
+                for tc in tcs.iter_mut() {
+                    if !tc.function.arguments.is_object() {
+                        tc.function.arguments = serde_json::Value::Object(Default::default());
+                    }
+                }
+            }
+
+            messages.push(response.clone());
+
+            if let Some(ref tool_calls) = response.tool_calls
+                && !tool_calls.is_empty()
+            {
+                for (i, tool_call) in tool_calls.iter().enumerate() {
+                    tool_calls_made += 1;
+                    let call_index = tool_calls_made - 1;
+                    let _ = tx
+                        .send(ChatStreamEvent::ToolCall {
+                            index: call_index,
+                            name: tool_call.function.name.clone(),
+                            arguments: tool_call.function.arguments.clone(),
+                        })
+                        .await;
+                    let cx = opentelemetry::Context::new();
+                    let result = self
+                        .generator
+                        .execute_tool(
+                            &tool_call.function.name,
+                            &tool_call.function.arguments,
+                            &ollama_client,
+                            &image_base64,
+                            &normalized,
+                            &cx,
+                        )
+                        .await;
+                    let (result_preview, result_truncated) = truncate_tool_result(&result);
+                    let _ = tx
+                        .send(ChatStreamEvent::ToolResult {
+                            index: call_index,
+                            name: tool_call.function.name.clone(),
+                            result: result_preview,
+                            result_truncated,
+                        })
+                        .await;
+                    messages.push(ChatMessage::tool_result(result));
+                    let _ = i; // reserved for per-call ordering if needed
+                }
+                continue;
+            }
+
+            final_content = response.content;
+            break;
+        }
+
+        if final_content.is_empty() {
+            messages.push(ChatMessage::user(
+                "Please write your final answer now without calling any more tools.",
+            ));
+            let mut stream = chat_backend
+                .chat_with_tools_stream(messages.clone(), vec![])
+                .await?;
+            let mut final_message: Option<ChatMessage> = None;
+            while let Some(ev) = stream.next().await {
+                let ev = ev?;
+                match ev {
+                    LlmStreamEvent::TextDelta(delta) => {
+                        let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
+                    }
+                    LlmStreamEvent::Done {
+                        message,
+                        prompt_eval_count,
+                        eval_count,
+                    } => {
+                        last_prompt_eval_count = prompt_eval_count;
+                        last_eval_count = eval_count;
+                        final_message = Some(message);
+                        break;
+                    }
+                }
+            }
+            let final_response =
+                final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
+            final_content = final_response.content.clone();
+            messages.push(final_response);
+        }
+
+        // Persist.
+        let json = serde_json::to_string(&messages)
+            .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
+
+        let mut amended_insight_id: Option<i32> = None;
+        if req.amend {
+            let title_prompt = format!(
+                "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
+                 Capture the key moment or theme. Return ONLY the title, nothing else.",
+                final_content
+            );
+            let title_raw = chat_backend
+                .generate(
+                    &title_prompt,
+                    Some(
+                        "You are my long term memory assistant. Use only the information provided. Do not invent details.",
+                    ),
+                    None,
+                )
+                .await?;
+            let title = title_raw.trim().trim_matches('"').to_string();
+
+            let new_row = InsertPhotoInsight {
+                library_id: req.library_id,
+                file_path: normalized.clone(),
+                title,
+                summary: final_content.clone(),
+                generated_at: Utc::now().timestamp(),
+                model_version: model_used.clone(),
+                is_current: true,
+                training_messages: Some(json),
+                backend: effective_backend.clone(),
+            };
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            let stored = dao
+                .store_insight(&cx, new_row)
+                .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
+            amended_insight_id = Some(stored.id);
+        } else {
+            let cx = opentelemetry::Context::new();
+            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
+            dao.update_training_messages(&cx, req.library_id, &normalized, &json)
+                .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
+        }
+
+        let _ = tx
+            .send(ChatStreamEvent::Done {
+                tool_calls_made,
+                iterations_used,
+                truncated,
+                prompt_eval_count: last_prompt_eval_count,
+                eval_count: last_eval_count,
+                amended_insight_id,
+                backend_used: effective_backend,
+                model_used,
+            })
+            .await;
+
+        Ok(())
+    }
+}
+
+/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
+/// `Done` or `Error`.
+#[derive(Debug, Clone)]
+pub enum ChatStreamEvent {
+    /// Starting iteration `n` of up to `max` (1-based).
+    IterationStart { n: usize, max: usize },
+    /// History was trimmed to fit the context budget before the turn ran.
+    /// Emitted at most once, before any tool or text events.
+    Truncated,
+    /// Incremental content from the final assistant reply. Concatenate to
+    /// reconstruct the reply body. Tool-dispatch turns don't produce these.
+    TextDelta(String),
+    /// The model requested this tool call. Emitted just before execution.
+    /// `index` is a monotonically increasing counter across the turn so the
+    /// client can pair `ToolCall` with its matching `ToolResult`.
+    ToolCall {
+        index: usize,
+        name: String,
+        arguments: serde_json::Value,
+    },
+    /// The tool finished; `result` is the (possibly truncated) output.
+    ToolResult {
+        index: usize,
+        name: String,
+        result: String,
+        result_truncated: bool,
+    },
+    /// Terminal success event with counters + persistence result.
+    Done {
+        tool_calls_made: usize,
+        iterations_used: usize,
+        truncated: bool,
+        prompt_eval_count: Option<i32>,
+        eval_count: Option<i32>,
+        amended_insight_id: Option<i32>,
+        backend_used: String,
+        model_used: String,
+    },
+    /// Terminal failure event. No further events follow.
+    Error(String),
 }
 
 /// Is this raw message visible in the rendered transcript? Must match
diff --git a/src/ai/llm_client.rs b/src/ai/llm_client.rs
index c1f1bca..8d68978 100644
--- a/src/ai/llm_client.rs
+++ b/src/ai/llm_client.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use async_trait::async_trait;
+use futures::stream::BoxStream;
 use serde::{Deserialize, Serialize};
 
 /// Provider-agnostic surface for LLM backends (Ollama, OpenRouter, …).
@@ -30,6 +31,18 @@ pub trait LlmClient: Send + Sync {
         tools: Vec<Tool>,
     ) -> Result<(ChatMessage, Option<i32>, Option<i32>)>;
 
+    /// Streaming variant of `chat_with_tools`. The returned stream yields
+    /// `TextDelta` items as content is produced, then a single terminal
+    /// `Done` carrying the complete assembled message (with tool_calls, if
+    /// any) plus token usage counts. Implementations that can't stream may
+    /// fall back to calling `chat_with_tools` and emitting the full reply
+    /// as one `Done` event.
+    async fn chat_with_tools_stream(
+        &self,
+        messages: Vec<ChatMessage>,
+        tools: Vec<Tool>,
+    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>>;
+
     /// Batch embedding generation. Dimensionality is provider/model specific.
     async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
 
@@ -47,6 +60,25 @@ pub trait LlmClient: Send + Sync {
     fn primary_model(&self) -> &str;
 }
 
+/// Events emitted by streaming `chat_with_tools_stream`. A stream is a
+/// sequence of zero or more `TextDelta` events followed by exactly one
+/// `Done`. Callers should treat `Done` as terminal — further items (if any
+/// slip through due to upstream misbehavior) are safe to ignore.
+#[derive(Debug, Clone)]
+pub enum LlmStreamEvent {
+    /// Incremental content token(s) from the model. Concatenate in order to
+    /// reconstruct the assistant's final text.
+    TextDelta(String),
+    /// Terminal event with the full assembled message (content + any
+    /// tool_calls). `message.content` equals the concatenation of every
+    /// preceding `TextDelta` payload.
+    Done {
+        message: ChatMessage,
+        prompt_eval_count: Option<i32>,
+        eval_count: Option<i32>,
+    },
+}
+
 /// Tool definition sent to the model (OpenAI-compatible function schema).
 #[derive(Serialize, Clone, Debug)]
 pub struct Tool {
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index 3c58b2a..8e38930 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -11,10 +11,10 @@ pub mod sms_client;
 #[allow(unused_imports)]
 pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
 pub use handlers::{
-    chat_history_handler, chat_rewind_handler, chat_turn_handler, delete_insight_handler,
-    export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler,
-    get_all_insights_handler, get_available_models_handler, get_insight_handler,
-    get_openrouter_models_handler, rate_insight_handler,
+    chat_history_handler, chat_rewind_handler, chat_stream_handler, chat_turn_handler,
+    delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler,
+    generate_insight_handler, get_all_insights_handler, get_available_models_handler,
+    get_insight_handler, get_openrouter_models_handler, rate_insight_handler,
 };
 pub use insight_generator::InsightGenerator;
 #[allow(unused_imports)]
diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs
index 8c487c5..1dc67f8 100644
--- a/src/ai/ollama.rs
+++ b/src/ai/ollama.rs
@@ -7,7 +7,8 @@ use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 
-use crate::ai::llm_client::LlmClient;
+use crate::ai::llm_client::{LlmClient, LlmStreamEvent};
+use futures::stream::{BoxStream, StreamExt};
 
 // Re-export shared types so existing `crate::ai::ollama::{...}` imports
 // continue to resolve.
@@ -634,6 +635,180 @@ Analyze the image and use specific details from both the visual content and the
         }
     }
 
+    /// Streaming variant of `chat_with_tools`. Tries primary, then falls
+    /// back if the initial connection fails; once the stream has begun
+    /// emitting, mid-stream errors propagate to the caller. Emits
+    /// `TextDelta` events as content tokens arrive and a single terminal
+    /// `Done` event when the model marks the turn complete (tool_calls, if
+    /// any, live on the final message).
+    pub async fn chat_with_tools_stream(
+        &self,
+        messages: Vec<ChatMessage>,
+        tools: Vec<Tool>,
+    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
+        // Attempt primary. If it can't be opened at all, try fallback.
+        match self
+            .try_chat_with_tools_stream(&self.primary_url, messages.clone(), tools.clone())
+            .await
+        {
+            Ok(s) => Ok(s),
+            Err(e) => {
+                if let Some(fallback_url) = self.fallback_url.clone() {
+                    log::warn!(
+                        "Streaming chat primary failed ({}); trying fallback {}",
+                        e,
+                        fallback_url
+                    );
+                    self.try_chat_with_tools_stream(&fallback_url, messages, tools)
+                        .await
+                } else {
+                    Err(e)
+                }
+            }
+        }
+    }
+
+    async fn try_chat_with_tools_stream(
+        &self,
+        base_url: &str,
+        messages: Vec<ChatMessage>,
+        tools: Vec<Tool>,
+    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
+        let url = format!("{}/api/chat", base_url);
+        let model = if base_url == self.primary_url {
+            &self.primary_model
+        } else {
+            self.fallback_model
+                .as_deref()
+                .unwrap_or(&self.primary_model)
+        };
+        let options = self.build_options();
+
+        let request_body = OllamaChatRequest {
+            model,
+            messages: &messages,
+            stream: true,
+            tools,
+            options,
+        };
+
+        let response = self
+            .client
+            .post(&url)
+            .json(&request_body)
+            .send()
+            .await
+            .with_context(|| format!("Failed to connect to Ollama at {}", url))?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            anyhow::bail!(
+                "Ollama stream request failed with status {}: {}",
+                status,
+                body
+            );
+        }
+
+        // Ollama streams NDJSON: each line is a full `OllamaStreamChunk`.
+        // We buffer partial lines across chunks from the byte stream.
+        let byte_stream = response.bytes_stream();
+        let stream = async_stream::stream! {
+            let mut buf: Vec<u8> = Vec::new();
+            let mut accumulated = String::new();
+            let mut tool_calls: Option<Vec<ToolCall>> = None;
+            let mut role = "assistant".to_string();
+            let mut prompt_eval_count: Option<i32> = None;
+            let mut eval_count: Option<i32> = None;
+            let mut prompt_eval_duration: Option<u64> = None;
+            let mut eval_duration: Option<u64> = None;
+            let mut done_seen = false;
+
+            let mut byte_stream = byte_stream;
+            while let Some(chunk) = byte_stream.next().await {
+                let chunk = match chunk {
+                    Ok(b) => b,
+                    Err(e) => {
+                        yield Err(anyhow::anyhow!("stream read failed: {}", e));
+                        return;
+                    }
+                };
+                buf.extend_from_slice(&chunk);
+
+                // Drain complete lines; hold any trailing partial.
+                while let Some(nl) = buf.iter().position(|b| *b == b'\n') {
+                    let line = buf.drain(..=nl).collect::<Vec<u8>>();
+                    let line_str = match std::str::from_utf8(&line) {
+                        Ok(s) => s.trim(),
+                        Err(_) => continue,
+                    };
+                    if line_str.is_empty() {
+                        continue;
+                    }
+                    match serde_json::from_str::<OllamaStreamChunk>(line_str) {
+                        Ok(chunk) => {
+                            // Accumulate content delta.
+                            if !chunk.message.content.is_empty() {
+                                accumulated.push_str(&chunk.message.content);
+                                yield Ok(LlmStreamEvent::TextDelta(chunk.message.content));
+                            }
+                            if !chunk.message.role.is_empty() {
+                                role = chunk.message.role;
+                            }
+                            // Ollama only attaches tool_calls on the final chunk.
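+                            // Illustrative final chunk (shape per Ollama's
+                            // /api/chat streaming format; tool name and values
+                            // made up):
+                            //   {"message":{"role":"assistant","content":"",
+                            //     "tool_calls":[{"function":{"name":"describe_image","arguments":{}}}]},
+                            //    "done":true,"prompt_eval_count":512,"eval_count":64}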
+                            if let Some(tcs) = chunk.message.tool_calls
+                                && !tcs.is_empty()
+                            {
+                                tool_calls = Some(tcs);
+                            }
+                            if chunk.done {
+                                prompt_eval_count = chunk.prompt_eval_count;
+                                eval_count = chunk.eval_count;
+                                prompt_eval_duration = chunk.prompt_eval_duration;
+                                eval_duration = chunk.eval_duration;
+                                done_seen = true;
+                                break;
+                            }
+                        }
+                        Err(e) => {
+                            log::warn!("malformed Ollama stream line: {} ({})", line_str, e);
+                        }
+                    }
+                }
+                if done_seen {
+                    break;
+                }
+            }
+
+            // Emit the terminal Done event with the assembled message.
+            log_chat_metrics(
+                prompt_eval_count,
+                prompt_eval_duration,
+                eval_count,
+                eval_duration,
+            );
+            let message = ChatMessage {
+                role,
+                content: accumulated,
+                tool_calls,
+                images: None,
+            };
+            yield Ok(LlmStreamEvent::Done {
+                message,
+                prompt_eval_count,
+                eval_count,
+            });
+        };
+
+        Ok(Box::pin(stream))
+    }
+
     async fn try_chat_with_tools(
         &self,
         base_url: &str,
@@ -857,6 +1026,14 @@ impl LlmClient for OllamaClient {
         OllamaClient::chat_with_tools(self, messages, tools).await
     }
 
+    async fn chat_with_tools_stream(
+        &self,
+        messages: Vec<ChatMessage>,
+        tools: Vec<Tool>,
+    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
+        OllamaClient::chat_with_tools_stream(self, messages, tools).await
+    }
+
     async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
         OllamaClient::generate_embeddings(self, texts).await
     }
@@ -936,6 +1113,35 @@ struct OllamaChatResponse {
     eval_duration: Option<u64>,
 }
 
+/// One chunk in the NDJSON stream from `/api/chat` with `stream: true`.
+/// Early chunks carry content deltas in `message.content`; the final chunk
+/// has `done: true`, optional `tool_calls`, and usage counters.
+#[derive(Deserialize, Debug)]
+struct OllamaStreamChunk {
+    #[serde(default)]
+    message: OllamaStreamMessage,
+    #[serde(default)]
+    done: bool,
+    #[serde(default)]
+    prompt_eval_count: Option<i32>,
+    #[serde(default)]
+    prompt_eval_duration: Option<u64>,
+    #[serde(default)]
+    eval_count: Option<i32>,
+    #[serde(default)]
+    eval_duration: Option<u64>,
+}
+
+#[derive(Deserialize, Debug, Default)]
+struct OllamaStreamMessage {
+    #[serde(default)]
+    role: String,
+    #[serde(default)]
+    content: String,
+    #[serde(default)]
+    tool_calls: Option<Vec<ToolCall>>,
+}
+
 #[derive(Deserialize)]
 struct OllamaResponse {
     response: String,
diff --git a/src/ai/openrouter.rs b/src/ai/openrouter.rs
index 2c46852..1559479 100644
--- a/src/ai/openrouter.rs
+++ b/src/ai/openrouter.rs
@@ -12,8 +12,9 @@ use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 
 use crate::ai::llm_client::{
-    ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
+    ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
 };
+use futures::stream::{BoxStream, StreamExt};
 
 const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
 const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
@@ -378,6 +379,223 @@
         Ok((chat_msg, prompt_tokens, completion_tokens))
     }
 
+    async fn chat_with_tools_stream(
+        &self,
+        messages: Vec<ChatMessage>,
+        tools: Vec<Tool>,
+    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
+        let url = format!("{}/chat/completions", self.base_url);
+        let mut body = serde_json::Map::new();
+        body.insert("model".into(), Value::String(self.primary_model.clone()));
+        body.insert(
+            "messages".into(),
+            Value::Array(Self::messages_to_openai(&messages)),
+        );
+        body.insert("stream".into(), Value::Bool(true));
+        // Ask for usage data in the final chunk (OpenAI + OpenRouter
+        // both honor this options bag).
+        body.insert(
+            "stream_options".into(),
+            serde_json::json!({ "include_usage": true }),
+        );
+        if !tools.is_empty() {
+            body.insert(
+                "tools".into(),
+                serde_json::to_value(&tools).context("serializing tools")?,
+            );
+        }
+        for (k, v) in self.build_options() {
+            body.insert(k.into(), v);
+        }
+
+        let resp = self
+            .authed(self.client.post(&url))
+            .json(&Value::Object(body))
+            .send()
+            .await
+            .with_context(|| format!("POST {} failed", url))?;
+
+        if !resp.status().is_success() {
+            let status = resp.status();
+            let body = resp.text().await.unwrap_or_default();
+            bail!("OpenRouter stream request failed: {} — {}", status, body);
+        }
+
+        // OpenAI-compatible SSE stream. Each event is `data: <json>\n\n`,
+        // with `data: [DONE]` signalling completion. Tool calls arrive as
+        // `delta.tool_calls[i]` chunks that must be concatenated by index.
+        let byte_stream = resp.bytes_stream();
+        let stream = async_stream::stream! {
+            let mut byte_stream = byte_stream;
+            let mut buf: Vec<u8> = Vec::new();
+            let mut accumulated_content = String::new();
+            // Tool call state: index -> (id, name, args_string).
+            let mut tool_state: std::collections::BTreeMap<
+                usize,
+                (Option<String>, Option<String>, String),
+            > = std::collections::BTreeMap::new();
+            let mut role = "assistant".to_string();
+            let mut prompt_tokens: Option<i32> = None;
+            let mut completion_tokens: Option<i32> = None;
+            let mut done_seen = false;
+
+            while let Some(chunk) = byte_stream.next().await {
+                let chunk = match chunk {
+                    Ok(b) => b,
+                    Err(e) => {
+                        yield Err(anyhow!("stream read failed: {}", e));
+                        return;
+                    }
+                };
+                buf.extend_from_slice(&chunk);
+
+                // SSE frames are delimited by a blank line. Walk the buffer
+                // for "\n\n" markers; anything before them is a complete
+                // frame (possibly multi-line).
+                loop {
+                    let Some(sep) = find_double_newline(&buf) else { break };
+                    let frame = buf.drain(..sep + 2).collect::<Vec<u8>>();
+                    let frame_str = match std::str::from_utf8(&frame) {
+                        Ok(s) => s,
+                        Err(_) => continue,
+                    };
+                    // A frame is one or more lines; the payload is on `data:`
+                    // lines. Ignore comments and other fields.
+                    for line in frame_str.lines() {
+                        let line = line.trim_end_matches('\r');
+                        let payload = match line.strip_prefix("data: ") {
+                            Some(p) => p,
+                            None => continue,
+                        };
+                        if payload == "[DONE]" {
+                            done_seen = true;
+                            break;
+                        }
+                        let v: Value = match serde_json::from_str(payload) {
+                            Ok(v) => v,
+                            Err(e) => {
+                                log::warn!(
+                                    "malformed OpenRouter SSE frame: {} ({})",
+                                    payload,
+                                    e
+                                );
+                                continue;
+                            }
+                        };
+
+                        // Usage can arrive in a dedicated final frame with
+                        // empty choices.
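+                        // Illustrative usage frame (values made up):
+                        //   data: {"choices":[],
+                        //          "usage":{"prompt_tokens":512,"completion_tokens":64}}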
+                        if let Some(usage) = v.get("usage") {
+                            prompt_tokens = usage
+                                .get("prompt_tokens")
+                                .and_then(|n| n.as_i64())
+                                .map(|n| n as i32);
+                            completion_tokens = usage
+                                .get("completion_tokens")
+                                .and_then(|n| n.as_i64())
+                                .map(|n| n as i32);
+                        }
+
+                        let Some(choices) = v.get("choices").and_then(|c| c.as_array())
+                        else {
+                            continue;
+                        };
+                        let Some(choice) = choices.first() else { continue };
+                        let delta = match choice.get("delta") {
+                            Some(d) => d,
+                            None => continue,
+                        };
+                        if let Some(r) = delta.get("role").and_then(|v| v.as_str()) {
+                            role = r.to_string();
+                        }
+                        if let Some(content) =
+                            delta.get("content").and_then(|v| v.as_str())
+                            && !content.is_empty()
+                        {
+                            accumulated_content.push_str(content);
+                            yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
+                        }
+                        if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
+                            for tc_delta in tcs {
+                                let idx = tc_delta
+                                    .get("index")
+                                    .and_then(|n| n.as_u64())
+                                    .unwrap_or(0) as usize;
+                                let entry = tool_state
+                                    .entry(idx)
+                                    .or_insert((None, None, String::new()));
+                                if let Some(id) =
+                                    tc_delta.get("id").and_then(|v| v.as_str())
+                                {
+                                    entry.0 = Some(id.to_string());
+                                }
+                                if let Some(func) = tc_delta.get("function") {
+                                    if let Some(name) =
+                                        func.get("name").and_then(|v| v.as_str())
+                                    {
+                                        entry.1 = Some(name.to_string());
+                                    }
+                                    if let Some(args) =
+                                        func.get("arguments").and_then(|v| v.as_str())
+                                    {
+                                        entry.2.push_str(args);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    if done_seen {
+                        break;
+                    }
+                }
+                if done_seen {
+                    break;
+                }
+            }
+
+            // Finalize tool calls: parse accumulated argument strings.
+            let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
+                None
+            } else {
+                let mut v = Vec::with_capacity(tool_state.len());
+                for (_idx, (id, name, args)) in tool_state {
+                    let arguments: Value = if args.trim().is_empty() {
+                        Value::Object(Default::default())
+                    } else {
+                        serde_json::from_str(&args).unwrap_or_else(|_| {
+                            Value::Object(Default::default())
+                        })
+                    };
+                    v.push(ToolCall {
+                        id,
+                        function: ToolCallFunction {
+                            name: name.unwrap_or_default(),
+                            arguments,
+                        },
+                    });
+                }
+                Some(v)
+            };
+
+            let message = ChatMessage {
+                role,
+                content: accumulated_content,
+                tool_calls,
+                images: None,
+            };
+            yield Ok(LlmStreamEvent::Done {
+                message,
+                prompt_eval_count: prompt_tokens,
+                eval_count: completion_tokens,
+            });
+        };
+
+        Ok(Box::pin(stream))
+    }
+
     async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
         let url = format!("{}/embeddings", self.base_url);
         let body = json!({
@@ -473,6 +688,28 @@
     }
 }
 
+/// Find the end-of-frame separator (blank line) in an SSE byte buffer.
+/// Returns an index `idx` such that draining `..idx + 2` consumes the frame
+/// and its full separator; handles both `\n\n` and `\r\n\r\n`.
+fn find_double_newline(buf: &[u8]) -> Option<usize> {
+    for i in 0..buf.len().saturating_sub(1) {
+        if buf[i] == b'\n' && buf[i + 1] == b'\n' {
+            return Some(i);
+        }
+        // \r\n\r\n: point two bytes past the start of the separator so the
+        // caller's `drain(..idx + 2)` consumes all four separator bytes.
+        if i + 3 < buf.len()
+            && buf[i] == b'\r'
+            && buf[i + 1] == b'\n'
+            && buf[i + 2] == b'\r'
+            && buf[i + 3] == b'\n'
+        {
+            return Some(i + 2);
+        }
+    }
+    None
+}
+
 /// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
 fn image_to_data_url(img: &str) -> String {
     if img.starts_with("data:") {
diff --git a/src/main.rs b/src/main.rs
index 8be397f..da0c4dd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1357,6 +1357,7 @@ fn main() -> std::io::Result<()> {
             .service(ai::get_available_models_handler)
             .service(ai::get_openrouter_models_handler)
             .service(ai::chat_turn_handler)
+            .service(ai::chat_stream_handler)
             .service(ai::chat_history_handler)
             .service(ai::chat_rewind_handler)
             .service(ai::rate_insight_handler)