OpenRouter Support, Insight Chat and User injection #56

Merged
cameron merged 24 commits from 005-llm-client-trait into master 2026-04-26 23:01:35 +00:00
9 changed files with 1071 additions and 9 deletions
Showing only changes of commit 079cd4c5b9

Cargo.lock generated
View File

@@ -486,6 +486,28 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "async-stream"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -1843,10 +1865,12 @@ dependencies = [
"actix-web",
"actix-web-prom",
"anyhow",
"async-stream",
"async-trait",
"base64",
"bcrypt",
"blake3",
"bytes",
"chrono",
"clap",
"diesel",
@@ -1878,6 +1902,7 @@ dependencies = [
"serde_json",
"tempfile",
"tokio",
"tokio-util",
"urlencoding",
"walkdir",
"zerocopy",
@@ -3125,12 +3150,14 @@ dependencies = [
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
]
@@ -4219,6 +4246,19 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"

View File

@@ -49,7 +49,10 @@ opentelemetry-appender-log = "0.31.0"
tempfile = "3.20.0"
regex = "1.11.1"
exif = { package = "kamadak-exif", version = "0.6.1" }
reqwest = { version = "0.12", features = ["json"] }
reqwest = { version = "0.12", features = ["json", "stream"] }
async-stream = "0.3"
tokio-util = { version = "0.7", features = ["io"] }
bytes = "1"
urlencoding = "2.1"
zerocopy = "0.8"
ical = "0.11"

View File

@@ -3,7 +3,7 @@ use opentelemetry::KeyValue;
use opentelemetry::trace::{Span, Status, Tracer};
use serde::{Deserialize, Serialize};
use crate::ai::insight_chat::ChatTurnRequest;
use crate::ai::insight_chat::{ChatStreamEvent, ChatTurnRequest};
use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient};
use crate::data::Claims;
use crate::database::{ExifDao, InsightDao};
@@ -826,3 +826,109 @@ pub async fn chat_history_handler(
}
}
}
/// POST /insights/chat/stream — streaming variant of /insights/chat.
/// Returns `text/event-stream` with one event per chat stream event.
#[post("/insights/chat/stream")]
pub async fn chat_stream_handler(
_claims: Claims,
request: web::Json<ChatTurnHttpRequest>,
app_state: web::Data<AppState>,
) -> HttpResponse {
let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
Ok(Some(lib)) => lib,
Ok(None) => app_state.primary_library(),
Err(e) => {
return HttpResponse::BadRequest().json(serde_json::json!({
"error": format!("invalid library: {}", e)
}));
}
};
let chat_req = ChatTurnRequest {
library_id: library.id,
file_path: request.file_path.clone(),
user_message: request.user_message.clone(),
model: request.model.clone(),
backend: request.backend.clone(),
num_ctx: request.num_ctx,
temperature: request.temperature,
top_p: request.top_p,
top_k: request.top_k,
min_p: request.min_p,
max_iterations: request.max_iterations,
amend: request.amend,
};
let service = app_state.insight_chat.clone();
let events = service.chat_turn_stream(chat_req);
// Map ChatStreamEvent → SSE frame bytes.
let sse_stream = futures::stream::StreamExt::map(events, |ev| {
let frame = render_sse_frame(&ev);
Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame))
});
HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no")) // nginx: disable response buffering
.streaming(sse_stream)
}
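For anyone poking at the new endpoint by hand, a minimal client sketch (hedged: assumes the server on localhost:8080, a bearer token in `token`, and a hypothetical `photos/2024/beach.jpg` path; frame splitting is simplified to plain `\n\n`):

use futures::StreamExt;

async fn follow_chat_stream(token: &str) -> anyhow::Result<()> {
    let resp = reqwest::Client::new()
        .post("http://localhost:8080/insights/chat/stream")
        .bearer_auth(token)
        .json(&serde_json::json!({
            "file_path": "photos/2024/beach.jpg", // hypothetical path
            "user_message": "What else happened that day?",
            "amend": false
        }))
        .send()
        .await?;
    let mut body = resp.bytes_stream();
    let mut buf = String::new();
    while let Some(chunk) = body.next().await {
        buf.push_str(std::str::from_utf8(&chunk?)?);
        // Each SSE frame ends with a blank line; drain complete frames.
        while let Some(end) = buf.find("\n\n") {
            let frame: String = buf.drain(..end + 2).collect();
            for line in frame.lines() {
                if let Some(data) = line.strip_prefix("data: ") {
                    println!("{data}");
                }
            }
        }
    }
    Ok(())
}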
fn render_sse_frame(ev: &ChatStreamEvent) -> String {
let (event_name, payload) = match ev {
ChatStreamEvent::IterationStart { n, max } => {
("iteration_start", serde_json::json!({ "n": n, "max": max }))
}
ChatStreamEvent::Truncated => ("truncated", serde_json::json!({})),
ChatStreamEvent::TextDelta(delta) => ("text", serde_json::json!({ "delta": delta })),
ChatStreamEvent::ToolCall {
index,
name,
arguments,
} => (
"tool_call",
serde_json::json!({ "index": index, "name": name, "arguments": arguments }),
),
ChatStreamEvent::ToolResult {
index,
name,
result,
result_truncated,
} => (
"tool_result",
serde_json::json!({
"index": index,
"name": name,
"result": result,
"result_truncated": result_truncated,
}),
),
ChatStreamEvent::Done {
tool_calls_made,
iterations_used,
truncated,
prompt_eval_count,
eval_count,
amended_insight_id,
backend_used,
model_used,
} => (
"done",
serde_json::json!({
"tool_calls_made": tool_calls_made,
"iterations_used": iterations_used,
"truncated": truncated,
"prompt_eval_count": prompt_eval_count,
"eval_count": eval_count,
"amended_insight_id": amended_insight_id,
"backend": backend_used,
"model": model_used,
}),
),
ChatStreamEvent::Error(msg) => ("error", serde_json::json!({ "message": msg })),
};
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
format!("event: {}\ndata: {}\n\n", event_name, data)
}
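A quick shape check on the renderer (sketch; could live in this module's tests):

#[cfg(test)]
mod sse_render_tests {
    use super::*;

    #[test]
    fn text_delta_frame_shape() {
        // `event:` names the variant, `data:` carries the JSON payload,
        // and a blank line terminates the frame.
        let frame = render_sse_frame(&ChatStreamEvent::TextDelta("Hello".into()));
        assert_eq!(frame, "event: text\ndata: {\"delta\":\"Hello\"}\n\n");
    }
}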

View File

@@ -7,13 +7,14 @@ use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as TokioMutex;
use crate::ai::insight_generator::InsightGenerator;
use crate::ai::llm_client::{ChatMessage, LlmClient};
use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent};
use crate::ai::ollama::OllamaClient;
use crate::ai::openrouter::OpenRouterClient;
use crate::database::InsightDao;
use crate::database::models::InsertPhotoInsight;
use crate::otel::global_tracer;
use crate::utils::normalize_path;
use futures::stream::{BoxStream, StreamExt};
const DEFAULT_MAX_ITERATIONS: usize = 6;
const DEFAULT_NUM_CTX: i32 = 8192;
@@ -583,6 +584,442 @@ impl InsightChatService {
.map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
Ok(())
}
/// Streaming variant of `chat_turn`. Emits user-facing events as the
/// conversation progresses: iteration starts, tool dispatch + result,
/// text deltas from the final assistant reply, and a terminal `Done`
/// frame. Persistence happens on the background task after the loop
/// ends, just before `Done` is emitted.
///
/// The stream takes ownership of the service via `Arc<Self>` (passed by
/// the caller) so it can live past the handler's await boundary.
pub fn chat_turn_stream(
self: Arc<Self>,
req: ChatTurnRequest,
) -> BoxStream<'static, ChatStreamEvent> {
let svc = self;
let s = async_stream::stream! {
match svc.chat_turn_stream_inner(req, |ev| Ok(ev)).await {
Ok(mut rx) => {
while let Some(ev) = rx.recv().await {
yield ev;
}
}
Err(e) => {
yield ChatStreamEvent::Error(format!("{}", e));
}
}
};
Box::pin(s)
}
/// Internal: drives the streaming loop on a background task, returning
/// a receiver the caller drains. Keeping the work on a spawned task
/// decouples the HTTP request lifetime from the chat execution, which
/// matters because the chat may run longer than any single network hop
/// and we want clean cancellation semantics via the channel close.
async fn chat_turn_stream_inner<F>(
self: Arc<Self>,
req: ChatTurnRequest,
_ev_mapper: F,
) -> Result<tokio::sync::mpsc::Receiver<ChatStreamEvent>>
where
F: Fn(ChatStreamEvent) -> Result<ChatStreamEvent> + Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::channel::<ChatStreamEvent>(64);
let svc = self.clone();
tokio::spawn(async move {
let result = svc.run_streaming_turn(req, tx.clone()).await;
if let Err(e) = result {
let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await;
}
});
Ok(rx)
}
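The cancellation claim is load-bearing here: when the HTTP client disconnects, actix drops the SSE stream, which drops this `Receiver`, so subsequent `tx.send` calls fail immediately instead of buffering (the turn still runs to completion, but nothing accumulates unbounded). A minimal illustration of that mpsc behavior (sketch):

#[tokio::test]
async fn sender_errors_once_receiver_is_dropped() {
    let (tx, rx) = tokio::sync::mpsc::channel::<u32>(1);
    drop(rx); // simulates the client going away mid-stream
    // A spawned task can use this error as its signal that nobody is listening.
    assert!(tx.send(1).await.is_err());
}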
async fn run_streaming_turn(
self: Arc<Self>,
req: ChatTurnRequest,
tx: tokio::sync::mpsc::Sender<ChatStreamEvent>,
) -> Result<()> {
if req.user_message.trim().is_empty() {
bail!("user_message must not be empty");
}
if req.user_message.len() > 8192 {
bail!("user_message exceeds 8192 bytes");
}
let normalized = normalize_path(&req.file_path);
let lock_key = (req.library_id, normalized.clone());
let entry_lock = {
let mut locks = self.chat_locks.lock().await;
locks
.entry(lock_key.clone())
.or_insert_with(|| Arc::new(TokioMutex::new(())))
.clone()
};
let _guard = entry_lock.lock().await;
let insight = {
let cx = opentelemetry::Context::new();
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
dao.get_insight(&cx, &normalized)
.map_err(|e| anyhow!("failed to load insight: {:?}", e))?
.ok_or_else(|| anyhow!("no insight found for path"))?
};
let raw_history = insight
.training_messages
.as_ref()
.ok_or_else(|| {
anyhow!("insight has no chat history; regenerate this insight in agentic mode")
})?
.clone();
let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
.map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
// Backend selection — same rules as non-streaming chat_turn.
let stored_backend = insight.backend.clone();
let effective_backend = req
.backend
.as_deref()
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty())
.unwrap_or_else(|| stored_backend.clone());
if !matches!(effective_backend.as_str(), "local" | "hybrid") {
bail!(
"unknown backend '{}'; expected 'local' or 'hybrid'",
effective_backend
);
}
if stored_backend == "local" && effective_backend == "hybrid" {
bail!(
"switching from local to hybrid mid-chat isn't supported yet; \
regenerate the insight in hybrid mode if you want OpenRouter chat"
);
}
let is_hybrid = effective_backend == "hybrid";
let max_iterations = req
.max_iterations
.unwrap_or(DEFAULT_MAX_ITERATIONS)
.clamp(1, env_max_iterations());
let stored_model = insight.model_version.clone();
let custom_model = req
.model
.clone()
.or_else(|| Some(stored_model.clone()))
.filter(|m| !m.is_empty());
let mut ollama_client = self.ollama.clone();
let mut openrouter_client: Option<OpenRouterClient> = None;
if is_hybrid {
let arc = self.openrouter.as_ref().ok_or_else(|| {
anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
})?;
let mut c: OpenRouterClient = (**arc).clone();
if let Some(ref m) = custom_model {
c.primary_model = m.clone();
}
if req.temperature.is_some()
|| req.top_p.is_some()
|| req.top_k.is_some()
|| req.min_p.is_some()
{
c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
}
if let Some(ctx) = req.num_ctx {
c.set_num_ctx(Some(ctx));
}
openrouter_client = Some(c);
} else {
if let Some(ref m) = custom_model
&& m != &self.ollama.primary_model
{
ollama_client = OllamaClient::new(
self.ollama.primary_url.clone(),
self.ollama.fallback_url.clone(),
m.clone(),
Some(m.clone()),
);
}
if req.temperature.is_some()
|| req.top_p.is_some()
|| req.top_k.is_some()
|| req.min_p.is_some()
{
ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
}
if let Some(ctx) = req.num_ctx {
ollama_client.set_num_ctx(Some(ctx));
}
}
let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
c
} else {
&ollama_client
};
let model_used = chat_backend.primary_model().to_string();
// Tool set.
let local_first_user_has_image = messages
.iter()
.find(|m| m.role == "user")
.and_then(|m| m.images.as_ref())
.map(|imgs| !imgs.is_empty())
.unwrap_or(false);
let offer_describe_tool = !is_hybrid && local_first_user_has_image;
let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);
let image_base64: Option<String> = if offer_describe_tool {
self.generator.load_image_as_base64(&normalized).ok()
} else {
None
};
// Truncate before appending the new user turn.
let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
.saturating_sub(RESPONSE_HEADROOM_TOKENS);
let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
let truncated = apply_context_budget(&mut messages, budget_bytes);
if truncated {
let _ = tx.send(ChatStreamEvent::Truncated).await;
}
messages.push(ChatMessage::user(req.user_message.clone()));
let mut tool_calls_made = 0usize;
let mut iterations_used = 0usize;
let mut last_prompt_eval_count: Option<i32> = None;
let mut last_eval_count: Option<i32> = None;
let mut final_content = String::new();
for iteration in 0..max_iterations {
iterations_used = iteration + 1;
let _ = tx
.send(ChatStreamEvent::IterationStart {
n: iterations_used,
max: max_iterations,
})
.await;
let mut stream = chat_backend
.chat_with_tools_stream(messages.clone(), tools.clone())
.await?;
let mut final_message: Option<ChatMessage> = None;
while let Some(ev) = stream.next().await {
let ev = ev?;
match ev {
LlmStreamEvent::TextDelta(delta) => {
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
}
LlmStreamEvent::Done {
message,
prompt_eval_count,
eval_count,
} => {
last_prompt_eval_count = prompt_eval_count;
last_eval_count = eval_count;
final_message = Some(message);
break;
}
}
}
let mut response =
final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;
// Normalize non-object tool arguments (same as non-streaming path).
if let Some(ref mut tcs) = response.tool_calls {
for tc in tcs.iter_mut() {
if !tc.function.arguments.is_object() {
tc.function.arguments = serde_json::Value::Object(Default::default());
}
}
}
messages.push(response.clone());
if let Some(ref tool_calls) = response.tool_calls
&& !tool_calls.is_empty()
{
for (i, tool_call) in tool_calls.iter().enumerate() {
tool_calls_made += 1;
let call_index = tool_calls_made - 1;
let _ = tx
.send(ChatStreamEvent::ToolCall {
index: call_index,
name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone(),
})
.await;
let cx = opentelemetry::Context::new();
let result = self
.generator
.execute_tool(
&tool_call.function.name,
&tool_call.function.arguments,
&ollama_client,
&image_base64,
&normalized,
&cx,
)
.await;
let (result_preview, result_truncated) = truncate_tool_result(&result);
let _ = tx
.send(ChatStreamEvent::ToolResult {
index: call_index,
name: tool_call.function.name.clone(),
result: result_preview,
result_truncated,
})
.await;
messages.push(ChatMessage::tool_result(result));
let _ = i; // reserved for per-call ordering if needed
}
continue;
}
final_content = response.content;
break;
}
if final_content.is_empty() {
messages.push(ChatMessage::user(
"Please write your final answer now without calling any more tools.",
));
let mut stream = chat_backend
.chat_with_tools_stream(messages.clone(), vec![])
.await?;
let mut final_message: Option<ChatMessage> = None;
while let Some(ev) = stream.next().await {
let ev = ev?;
match ev {
LlmStreamEvent::TextDelta(delta) => {
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
}
LlmStreamEvent::Done {
message,
prompt_eval_count,
eval_count,
} => {
last_prompt_eval_count = prompt_eval_count;
last_eval_count = eval_count;
final_message = Some(message);
break;
}
}
}
let final_response =
final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
final_content = final_response.content.clone();
messages.push(final_response);
}
// Persist.
let json = serde_json::to_string(&messages)
.map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
let mut amended_insight_id: Option<i32> = None;
if req.amend {
let title_prompt = format!(
"Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
Capture the key moment or theme. Return ONLY the title, nothing else.",
final_content
);
let title_raw = chat_backend
.generate(
&title_prompt,
Some(
"You are my long term memory assistant. Use only the information provided. Do not invent details.",
),
None,
)
.await?;
let title = title_raw.trim().trim_matches('"').to_string();
let new_row = InsertPhotoInsight {
library_id: req.library_id,
file_path: normalized.clone(),
title,
summary: final_content.clone(),
generated_at: Utc::now().timestamp(),
model_version: model_used.clone(),
is_current: true,
training_messages: Some(json),
backend: effective_backend.clone(),
};
let cx = opentelemetry::Context::new();
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
let stored = dao
.store_insight(&cx, new_row)
.map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
amended_insight_id = Some(stored.id);
} else {
let cx = opentelemetry::Context::new();
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
dao.update_training_messages(&cx, req.library_id, &normalized, &json)
.map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
}
let _ = tx
.send(ChatStreamEvent::Done {
tool_calls_made,
iterations_used,
truncated,
prompt_eval_count: last_prompt_eval_count,
eval_count: last_eval_count,
amended_insight_id,
backend_used: effective_backend,
model_used,
})
.await;
Ok(())
}
}
/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
/// `Done` or `Error`.
#[derive(Debug, Clone)]
pub enum ChatStreamEvent {
/// Starting iteration `n` of up to `max` (1-based).
IterationStart { n: usize, max: usize },
/// History was trimmed to fit the context budget before the turn ran.
/// Emitted at most once, before any tool or text events.
Truncated,
/// Incremental content from the final assistant reply. Concatenate to
/// reconstruct the reply body. Tool-dispatch turns don't produce these.
TextDelta(String),
/// The model requested this tool call. Emitted just before execution.
/// `index` is a monotonically-increasing counter across the turn so the
/// client can pair `ToolCall` with its matching `ToolResult`.
ToolCall {
index: usize,
name: String,
arguments: serde_json::Value,
},
/// The tool finished; `result` is the (possibly truncated) output.
ToolResult {
index: usize,
name: String,
result: String,
result_truncated: bool,
},
/// Terminal success event with counters + persistence result.
Done {
tool_calls_made: usize,
iterations_used: usize,
truncated: bool,
prompt_eval_count: Option<i32>,
eval_count: Option<i32>,
amended_insight_id: Option<i32>,
backend_used: String,
model_used: String,
},
/// Terminal failure event. No further events follow.
Error(String),
}
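Since `ToolCall` and `ToolResult` pair by `index`, a consumer can correlate them with a small map. A hedged in-process sketch (assumes a constructed `service: Arc<InsightChatService>` and a valid `req`):

use futures::StreamExt;
use std::collections::HashMap;

async fn drain_turn(service: std::sync::Arc<InsightChatService>, req: ChatTurnRequest) {
    let mut events = service.chat_turn_stream(req);
    let mut in_flight: HashMap<usize, String> = HashMap::new();
    while let Some(ev) = events.next().await {
        match ev {
            ChatStreamEvent::TextDelta(d) => print!("{d}"),
            ChatStreamEvent::ToolCall { index, name, .. } => {
                in_flight.insert(index, name); // hold until the result lands
            }
            ChatStreamEvent::ToolResult { index, result, .. } => {
                let name = in_flight.remove(&index).unwrap_or_default();
                println!("[{name}] {result}");
            }
            ChatStreamEvent::Done { .. } | ChatStreamEvent::Error(_) => break,
            _ => {} // IterationStart / Truncated are informational
        }
    }
}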
/// Is this raw message visible in the rendered transcript? Must match

View File

@@ -1,5 +1,6 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
/// Provider-agnostic surface for LLM backends (Ollama, OpenRouter, …).
@@ -30,6 +31,18 @@ pub trait LlmClient: Send + Sync {
tools: Vec<Tool>,
) -> Result<(ChatMessage, Option<i32>, Option<i32>)>;
/// Streaming variant of `chat_with_tools`. The returned stream yields
/// `TextDelta` items as content is produced, then a single terminal
/// `Done` carrying the complete assembled message (with tool_calls, if
/// any) plus token usage counts. Implementations that can't stream may
/// fall back to calling `chat_with_tools` and emitting the full reply
/// as one `Done` event.
async fn chat_with_tools_stream(
&self,
messages: Vec<ChatMessage>,
tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>>;
/// Batch embedding generation. Dimensionality is provider/model specific.
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
@@ -47,6 +60,25 @@ pub trait LlmClient: Send + Sync {
fn primary_model(&self) -> &str;
}
/// Events emitted by streaming `chat_with_tools_stream`. A stream is a
/// sequence of zero or more `TextDelta` events followed by exactly one
/// `Done`. Callers should treat `Done` as terminal — further items (if any
/// slip through due to upstream misbehavior) are safe to ignore.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
/// Incremental content token(s) from the model. Concatenate in order to
/// reconstruct the assistant's final text.
TextDelta(String),
/// Terminal event with the full assembled message (content + any
/// tool_calls). `message.content` equals the concatenation of every
/// preceding `TextDelta.0`.
Done {
message: ChatMessage,
prompt_eval_count: Option<i32>,
eval_count: Option<i32>,
},
}
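The "fall back to a single `Done`" escape hatch the trait docs mention looks roughly like this (hypothetical `NonStreamingBackend`, not part of this PR):

// Sketch: satisfy the streaming method by running the blocking call and
// wrapping the complete reply in one terminal `Done` event.
#[async_trait]
impl LlmClient for NonStreamingBackend {
    async fn chat_with_tools_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
        let (message, prompt_eval_count, eval_count) =
            self.chat_with_tools(messages, tools).await?;
        Ok(Box::pin(futures::stream::once(async move {
            Ok::<_, anyhow::Error>(LlmStreamEvent::Done {
                message,
                prompt_eval_count,
                eval_count,
            })
        })))
    }
    // ...remaining trait methods elided
}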
/// Tool definition sent to the model (OpenAI-compatible function schema).
#[derive(Serialize, Clone, Debug)]
pub struct Tool {

View File

@@ -11,10 +11,10 @@ pub mod sms_client;
#[allow(unused_imports)]
pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
pub use handlers::{
chat_history_handler, chat_rewind_handler, chat_turn_handler, delete_insight_handler,
export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler,
get_all_insights_handler, get_available_models_handler, get_insight_handler,
get_openrouter_models_handler, rate_insight_handler,
chat_history_handler, chat_rewind_handler, chat_stream_handler, chat_turn_handler,
delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler,
generate_insight_handler, get_all_insights_handler, get_available_models_handler,
get_insight_handler, get_openrouter_models_handler, rate_insight_handler,
};
pub use insight_generator::InsightGenerator;
#[allow(unused_imports)]

View File

@@ -7,7 +7,8 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::ai::llm_client::LlmClient;
use crate::ai::llm_client::{LlmClient, LlmStreamEvent};
use futures::stream::{BoxStream, StreamExt};
// Re-export shared types so existing `crate::ai::ollama::{...}` imports
// continue to resolve.
@@ -634,6 +635,174 @@ Analyze the image and use specific details from both the visual content and the
}
}
/// Streaming variant of `chat_with_tools`. Tries primary, then falls
/// back if the initial connection fails; once the stream has begun
/// emitting, mid-stream errors propagate to the caller. Emits
/// `TextDelta` events as content tokens arrive and a single terminal
/// `Done` event when the model marks the turn complete (tool_calls, if
/// any, live on the final message).
pub async fn chat_with_tools_stream(
&self,
messages: Vec<ChatMessage>,
tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
// Attempt primary. If it can't be opened at all, try fallback.
match self
.try_chat_with_tools_stream(&self.primary_url, messages.clone(), tools.clone())
.await
{
Ok(s) => Ok(s),
Err(e) => {
if let Some(fallback_url) = self.fallback_url.clone() {
log::warn!(
"Streaming chat primary failed ({}); trying fallback {}",
e,
fallback_url
);
self.try_chat_with_tools_stream(&fallback_url, messages, tools)
.await
} else {
Err(e)
}
}
}
}
async fn try_chat_with_tools_stream(
&self,
base_url: &str,
messages: Vec<ChatMessage>,
tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
let url = format!("{}/api/chat", base_url);
let model = if base_url == self.primary_url {
&self.primary_model
} else {
self.fallback_model
.as_deref()
.unwrap_or(&self.primary_model)
};
let options = self.build_options();
let request_body = OllamaChatRequest {
model,
messages: &messages,
stream: true,
tools,
options,
};
let response = self
.client
.post(&url)
.json(&request_body)
.send()
.await
.with_context(|| format!("Failed to connect to Ollama at {}", url))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!(
"Ollama stream request failed with status {}: {}",
status,
body
);
}
// Ollama streams NDJSON: each line is a full `OllamaStreamChunk`.
// We buffer partial lines across chunks from the byte stream.
let byte_stream = response.bytes_stream();
let stream = async_stream::stream! {
let mut buf: Vec<u8> = Vec::new();
let mut accumulated = String::new();
let mut tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>> = None;
let mut role = "assistant".to_string();
let mut prompt_eval_count: Option<i32> = None;
let mut eval_count: Option<i32> = None;
let mut prompt_eval_duration: Option<u64> = None;
let mut eval_duration: Option<u64> = None;
let mut done_seen = false;
let mut byte_stream = byte_stream;
while let Some(chunk) = byte_stream.next().await {
let chunk = match chunk {
Ok(b) => b,
Err(e) => {
yield Err(anyhow::anyhow!("stream read failed: {}", e));
return;
}
};
buf.extend_from_slice(&chunk);
// Drain complete lines; hold any trailing partial.
while let Some(nl) = buf.iter().position(|b| *b == b'\n') {
let line = buf.drain(..=nl).collect::<Vec<_>>();
let line_str = match std::str::from_utf8(&line) {
Ok(s) => s.trim(),
Err(_) => continue,
};
if line_str.is_empty() {
continue;
}
match serde_json::from_str::<OllamaStreamChunk>(line_str) {
Ok(chunk) => {
// Accumulate content delta.
if !chunk.message.content.is_empty() {
accumulated.push_str(&chunk.message.content);
yield Ok(LlmStreamEvent::TextDelta(chunk.message.content));
}
if !chunk.message.role.is_empty() {
role = chunk.message.role;
}
// Ollama only attaches tool_calls on the final chunk.
if let Some(tcs) = chunk.message.tool_calls
&& !tcs.is_empty()
{
tool_calls = Some(tcs);
}
if chunk.done {
prompt_eval_count = chunk.prompt_eval_count;
eval_count = chunk.eval_count;
prompt_eval_duration = chunk.prompt_eval_duration;
eval_duration = chunk.eval_duration;
done_seen = true;
break;
}
}
Err(e) => {
log::warn!("malformed Ollama stream line: {} ({})", line_str, e);
}
}
}
if done_seen {
break;
}
}
// Emit the terminal Done event with the assembled message.
log_chat_metrics(
prompt_eval_count,
prompt_eval_duration,
eval_count,
eval_duration,
);
let message = ChatMessage {
role,
content: accumulated,
tool_calls,
images: None,
};
yield Ok(LlmStreamEvent::Done {
message,
prompt_eval_count,
eval_count,
});
};
Ok(Box::pin(stream))
}
async fn try_chat_with_tools(
&self,
base_url: &str,
@@ -857,6 +1026,14 @@ impl LlmClient for OllamaClient {
OllamaClient::chat_with_tools(self, messages, tools).await
}
async fn chat_with_tools_stream(
&self,
messages: Vec<ChatMessage>,
tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
OllamaClient::chat_with_tools_stream(self, messages, tools).await
}
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
OllamaClient::generate_embeddings(self, texts).await
}
@@ -936,6 +1113,35 @@ struct OllamaChatResponse {
eval_duration: Option<u64>,
}
/// One chunk in the NDJSON stream from `/api/chat` with `stream: true`.
/// Early chunks carry content deltas in `message.content`; the final chunk
/// has `done: true`, optional `tool_calls`, and usage counters.
#[derive(Deserialize, Debug)]
struct OllamaStreamChunk {
#[serde(default)]
message: OllamaStreamMessage,
#[serde(default)]
done: bool,
#[serde(default)]
prompt_eval_count: Option<i32>,
#[serde(default)]
prompt_eval_duration: Option<u64>,
#[serde(default)]
eval_count: Option<i32>,
#[serde(default)]
eval_duration: Option<u64>,
}
#[derive(Deserialize, Debug, Default)]
struct OllamaStreamMessage {
#[serde(default)]
role: String,
#[serde(default)]
content: String,
#[serde(default)]
tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>>,
}
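The `#[serde(default)]` spread matters because delta chunks omit most fields. A hedged deserialization check (sketch):

#[cfg(test)]
mod stream_chunk_tests {
    use super::OllamaStreamChunk;

    #[test]
    fn parses_delta_and_final_chunks() {
        // Mid-stream delta: content only, no usage counters.
        let delta: OllamaStreamChunk = serde_json::from_str(
            r#"{"message":{"role":"assistant","content":"Hel"},"done":false}"#,
        )
        .unwrap();
        assert!(!delta.done);
        assert_eq!(delta.message.content, "Hel");

        // Final chunk: `message` absent entirely, counters present.
        let last: OllamaStreamChunk =
            serde_json::from_str(r#"{"done":true,"eval_count":42}"#).unwrap();
        assert!(last.done);
        assert_eq!(last.eval_count, Some(42));
    }
}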
#[derive(Deserialize)]
struct OllamaResponse {
response: String,

View File

@@ -12,8 +12,9 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::ai::llm_client::{
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
};
use futures::stream::{BoxStream, StreamExt};
const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
@@ -378,6 +379,220 @@ impl LlmClient for OpenRouterClient {
Ok((chat_msg, prompt_tokens, completion_tokens))
}
async fn chat_with_tools_stream(
&self,
messages: Vec<ChatMessage>,
tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
let url = format!("{}/chat/completions", self.base_url);
let mut body = serde_json::Map::new();
body.insert("model".into(), Value::String(self.primary_model.clone()));
body.insert(
"messages".into(),
Value::Array(Self::messages_to_openai(&messages)),
);
body.insert("stream".into(), Value::Bool(true));
// Ask for usage data in the final chunk (OpenAI + OpenRouter
// both honor this options bag).
body.insert(
"stream_options".into(),
serde_json::json!({ "include_usage": true }),
);
if !tools.is_empty() {
body.insert(
"tools".into(),
serde_json::to_value(&tools).context("serializing tools")?,
);
}
for (k, v) in self.build_options() {
body.insert(k.into(), v);
}
let resp = self
.authed(self.client.post(&url))
.json(&Value::Object(body))
.send()
.await
.with_context(|| format!("POST {} failed", url))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
bail!("OpenRouter stream request failed: {} — {}", status, body);
}
// OpenAI-compat SSE stream. Each event is `data: <json>\n\n`, with
// `data: [DONE]` signalling completion. Tool calls arrive as
// `delta.tool_calls[i]` chunks that must be concatenated by index.
let byte_stream = resp.bytes_stream();
let stream = async_stream::stream! {
let mut byte_stream = byte_stream;
let mut buf: Vec<u8> = Vec::new();
let mut accumulated_content = String::new();
// tool call state: index -> (id, name, args_string)
let mut tool_state: std::collections::BTreeMap<
usize,
(Option<String>, Option<String>, String),
> = std::collections::BTreeMap::new();
let mut role = "assistant".to_string();
let mut prompt_tokens: Option<i32> = None;
let mut completion_tokens: Option<i32> = None;
let mut done_seen = false;
while let Some(chunk) = byte_stream.next().await {
let chunk = match chunk {
Ok(b) => b,
Err(e) => {
yield Err(anyhow!("stream read failed: {}", e));
return;
}
};
buf.extend_from_slice(&chunk);
// SSE frames are delimited by a blank line. Walk the buffer
// for "\n\n" markers; anything before them is a complete
// frame (possibly multi-line).
loop {
let Some(sep) = find_double_newline(&buf) else { break };
let frame = buf.drain(..sep + 2).collect::<Vec<_>>();
let frame_str = match std::str::from_utf8(&frame) {
Ok(s) => s,
Err(_) => continue,
};
// A frame is one or more lines; the payload is on data:
// lines. Ignore comments and other fields.
for line in frame_str.lines() {
let line = line.trim_end_matches('\r');
let payload = match line.strip_prefix("data: ") {
Some(p) => p,
None => continue,
};
if payload == "[DONE]" {
done_seen = true;
break;
}
let v: Value = match serde_json::from_str(payload) {
Ok(v) => v,
Err(e) => {
log::warn!(
"malformed OpenRouter SSE frame: {} ({})",
payload,
e
);
continue;
}
};
// Usage can arrive in a dedicated final frame with
// empty choices.
if let Some(usage) = v.get("usage") {
prompt_tokens = usage
.get("prompt_tokens")
.and_then(|n| n.as_i64())
.map(|n| n as i32);
completion_tokens = usage
.get("completion_tokens")
.and_then(|n| n.as_i64())
.map(|n| n as i32);
}
let Some(choices) = v.get("choices").and_then(|c| c.as_array())
else {
continue;
};
let Some(choice) = choices.first() else { continue };
let delta = match choice.get("delta") {
Some(d) => d,
None => continue,
};
if let Some(r) = delta.get("role").and_then(|v| v.as_str()) {
role = r.to_string();
}
if let Some(content) =
delta.get("content").and_then(|v| v.as_str())
&& !content.is_empty()
{
accumulated_content.push_str(content);
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
}
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
for tc_delta in tcs {
let idx = tc_delta
.get("index")
.and_then(|n| n.as_u64())
.unwrap_or(0) as usize;
let entry = tool_state
.entry(idx)
.or_insert((None, None, String::new()));
if let Some(id) =
tc_delta.get("id").and_then(|v| v.as_str())
{
entry.0 = Some(id.to_string());
}
if let Some(func) = tc_delta.get("function") {
if let Some(name) =
func.get("name").and_then(|v| v.as_str())
{
entry.1 = Some(name.to_string());
}
if let Some(args) =
func.get("arguments").and_then(|v| v.as_str())
{
entry.2.push_str(args);
}
}
}
}
}
if done_seen {
break;
}
}
if done_seen {
break;
}
}
// Finalize tool calls: parse accumulated argument strings.
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
None
} else {
let mut v = Vec::with_capacity(tool_state.len());
for (_idx, (id, name, args)) in tool_state {
let arguments: Value = if args.trim().is_empty() {
Value::Object(Default::default())
} else {
serde_json::from_str(&args).unwrap_or_else(|_| {
Value::Object(Default::default())
})
};
v.push(ToolCall {
id,
function: ToolCallFunction {
name: name.unwrap_or_default(),
arguments,
},
});
}
Some(v)
};
let message = ChatMessage {
role,
content: accumulated_content,
tool_calls,
images: None,
};
yield Ok(LlmStreamEvent::Done {
message,
prompt_eval_count: prompt_tokens,
eval_count: completion_tokens,
});
};
Ok(Box::pin(stream))
}
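Worth flagging for reviewers: OpenAI-style deltas split one tool call's `arguments` across frames, so correctness hinges on concatenating fragments per index before parsing. A tiny hedged check of that contract:

#[test]
fn tool_call_argument_fragments_concatenate() {
    // Two deltas for index 0 carrying halves of one JSON document,
    // mirroring the `entry.2.push_str(args)` accumulation above.
    let mut args = String::new();
    for fragment in [r#"{"query":"#, r#""beach"}"#] {
        args.push_str(fragment);
    }
    let parsed: serde_json::Value = serde_json::from_str(&args).unwrap();
    assert_eq!(parsed["query"], "beach");
}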
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let url = format!("{}/embeddings", self.base_url);
let body = json!({
@@ -473,6 +688,28 @@ impl LlmClient for OpenRouterClient {
}
}
/// Find the byte offset marking the end of the first SSE frame in `buf`.
/// Returns `sep` such that `drain(..sep + 2)` consumes the frame plus its
/// full separator. Handles `\r\n\r\n` as well as `\n\n`, since some servers
/// emit CRLF.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
for i in 0..buf.len().saturating_sub(1) {
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
return Some(i);
}
// \r\n\r\n: return i + 2 so the drain call (which consumes ..sep+2)
// takes all four separator bytes and leaves no stray `\n` behind.
if i + 3 < buf.len()
&& buf[i] == b'\r'
&& buf[i + 1] == b'\n'
&& buf[i + 2] == b'\r'
&& buf[i + 3] == b'\n'
{
return Some(i + 2);
}
}
None
}
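And a quick check of the scanner against both separator styles (sketch):

#[cfg(test)]
mod sse_frame_scan_tests {
    use super::find_double_newline;

    #[test]
    fn lf_lf_and_crlf_crlf_frames() {
        // "data: {}" is 8 bytes, so the separator starts at index 8 in both
        // cases; `drain(..sep + 2)` then consumes the whole frame.
        assert_eq!(find_double_newline(b"data: {}\n\nnext"), Some(8));
        assert_eq!(find_double_newline(b"data: {}\r\n\r\nnext"), Some(10));
        // Incomplete frame: keep buffering.
        assert_eq!(find_double_newline(b"data: {"), None);
    }
}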
/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
fn image_to_data_url(img: &str) -> String {
if img.starts_with("data:") {

View File

@@ -1357,6 +1357,7 @@ fn main() -> std::io::Result<()> {
.service(ai::get_available_models_handler)
.service(ai::get_openrouter_models_handler)
.service(ai::chat_turn_handler)
.service(ai::chat_stream_handler)
.service(ai::chat_history_handler)
.service(ai::chat_rewind_handler)
.service(ai::rate_insight_handler)