feat(ai): streaming chat endpoint with live tool events
Add LlmClient::chat_with_tools_stream and an SSE endpoint, POST /insights/chat/stream, that emits text deltas, tool_call / tool_result pairs, a truncated notice, and a terminal done frame as the agentic loop runs.
- Ollama: parses NDJSON from the /api/chat stream, accumulates content deltas, and emits Done with tool_calls from the final chunk.
- OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call argument deltas by index, and asks for stream_options.include_usage.
- InsightChatService: spawns the loop on a tokio task, feeds events through an mpsc channel, and persists training_messages at the end.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,13 +7,14 @@ use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
|
||||
use crate::ai::insight_generator::InsightGenerator;
|
||||
use crate::ai::llm_client::{ChatMessage, LlmClient};
|
||||
use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent};
|
||||
use crate::ai::ollama::OllamaClient;
|
||||
use crate::ai::openrouter::OpenRouterClient;
|
||||
use crate::database::InsightDao;
|
||||
use crate::database::models::InsertPhotoInsight;
|
||||
use crate::otel::global_tracer;
|
||||
use crate::utils::normalize_path;
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
|
||||
/// Iteration cap applied by the streaming chat loop when the request leaves
/// `max_iterations` unset. The chosen value is subsequently clamped to the
/// range `1..=env_max_iterations()`.
const DEFAULT_MAX_ITERATIONS: usize = 6;
|
||||
/// Context-window size (in tokens) assumed when the request does not supply
/// `num_ctx`. The streaming chat loop derives its history-truncation budget
/// from this value.
const DEFAULT_NUM_CTX: i32 = 8192;
|
||||
@@ -583,6 +584,442 @@ impl InsightChatService {
|
||||
.map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Streaming variant of `chat_turn`. Emits user-facing events as the
|
||||
/// conversation progresses: iteration starts, tool dispatch + result,
|
||||
/// text deltas from the final assistant reply, and a terminal `Done`
|
||||
/// frame. Persistence happens inside the stream after the loop ends.
|
||||
///
|
||||
/// The stream takes ownership of the service via `Arc<Self>` (passed by
|
||||
/// the caller) so it can live past the handler's await boundary.
|
||||
pub fn chat_turn_stream(
|
||||
self: Arc<Self>,
|
||||
req: ChatTurnRequest,
|
||||
) -> BoxStream<'static, ChatStreamEvent> {
|
||||
let svc = self;
|
||||
let s = async_stream::stream! {
|
||||
match svc.chat_turn_stream_inner(req, |ev| Ok(ev)).await {
|
||||
Ok(mut rx) => {
|
||||
while let Some(ev) = rx.recv().await {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
yield ChatStreamEvent::Error(format!("{}", e));
|
||||
}
|
||||
}
|
||||
};
|
||||
Box::pin(s)
|
||||
}
|
||||
|
||||
/// Internal: drives the streaming loop on a background task, returning
|
||||
/// a receiver the caller drains. Keeping the work on a spawned task
|
||||
/// decouples the HTTP request lifetime from the chat execution, which
|
||||
/// matters because the chat may run longer than any single network hop
|
||||
/// and we want clean cancellation semantics via the channel close.
|
||||
async fn chat_turn_stream_inner<F>(
|
||||
self: Arc<Self>,
|
||||
req: ChatTurnRequest,
|
||||
_ev_mapper: F,
|
||||
) -> Result<tokio::sync::mpsc::Receiver<ChatStreamEvent>>
|
||||
where
|
||||
F: Fn(ChatStreamEvent) -> Result<ChatStreamEvent> + Send + 'static,
|
||||
{
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<ChatStreamEvent>(64);
|
||||
let svc = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = svc.run_streaming_turn(req, tx.clone()).await;
|
||||
if let Err(e) = result {
|
||||
let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await;
|
||||
}
|
||||
});
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
async fn run_streaming_turn(
|
||||
self: Arc<Self>,
|
||||
req: ChatTurnRequest,
|
||||
tx: tokio::sync::mpsc::Sender<ChatStreamEvent>,
|
||||
) -> Result<()> {
|
||||
if req.user_message.trim().is_empty() {
|
||||
bail!("user_message must not be empty");
|
||||
}
|
||||
if req.user_message.len() > 8192 {
|
||||
bail!("user_message exceeds 8192 chars");
|
||||
}
|
||||
let normalized = normalize_path(&req.file_path);
|
||||
|
||||
let lock_key = (req.library_id, normalized.clone());
|
||||
let entry_lock = {
|
||||
let mut locks = self.chat_locks.lock().await;
|
||||
locks
|
||||
.entry(lock_key.clone())
|
||||
.or_insert_with(|| Arc::new(TokioMutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
let _guard = entry_lock.lock().await;
|
||||
|
||||
let insight = {
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
dao.get_insight(&cx, &normalized)
|
||||
.map_err(|e| anyhow!("failed to load insight: {:?}", e))?
|
||||
.ok_or_else(|| anyhow!("no insight found for path"))?
|
||||
};
|
||||
let raw_history = insight
|
||||
.training_messages
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
anyhow!("insight has no chat history; regenerate this insight in agentic mode")
|
||||
})?
|
||||
.clone();
|
||||
let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
|
||||
.map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
|
||||
|
||||
// Backend selection — same rules as non-streaming chat_turn.
|
||||
let stored_backend = insight.backend.clone();
|
||||
let effective_backend = req
|
||||
.backend
|
||||
.as_deref()
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(|| stored_backend.clone());
|
||||
if !matches!(effective_backend.as_str(), "local" | "hybrid") {
|
||||
bail!(
|
||||
"unknown backend '{}'; expected 'local' or 'hybrid'",
|
||||
effective_backend
|
||||
);
|
||||
}
|
||||
if stored_backend == "local" && effective_backend == "hybrid" {
|
||||
bail!(
|
||||
"switching from local to hybrid mid-chat isn't supported yet; \
|
||||
regenerate the insight in hybrid mode if you want OpenRouter chat"
|
||||
);
|
||||
}
|
||||
let is_hybrid = effective_backend == "hybrid";
|
||||
|
||||
let max_iterations = req
|
||||
.max_iterations
|
||||
.unwrap_or(DEFAULT_MAX_ITERATIONS)
|
||||
.clamp(1, env_max_iterations());
|
||||
|
||||
let stored_model = insight.model_version.clone();
|
||||
let custom_model = req
|
||||
.model
|
||||
.clone()
|
||||
.or_else(|| Some(stored_model.clone()))
|
||||
.filter(|m| !m.is_empty());
|
||||
|
||||
let mut ollama_client = self.ollama.clone();
|
||||
let mut openrouter_client: Option<OpenRouterClient> = None;
|
||||
|
||||
if is_hybrid {
|
||||
let arc = self.openrouter.as_ref().ok_or_else(|| {
|
||||
anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
|
||||
})?;
|
||||
let mut c: OpenRouterClient = (**arc).clone();
|
||||
if let Some(ref m) = custom_model {
|
||||
c.primary_model = m.clone();
|
||||
}
|
||||
if req.temperature.is_some()
|
||||
|| req.top_p.is_some()
|
||||
|| req.top_k.is_some()
|
||||
|| req.min_p.is_some()
|
||||
{
|
||||
c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
|
||||
}
|
||||
if let Some(ctx) = req.num_ctx {
|
||||
c.set_num_ctx(Some(ctx));
|
||||
}
|
||||
openrouter_client = Some(c);
|
||||
} else {
|
||||
if let Some(ref m) = custom_model
|
||||
&& m != &self.ollama.primary_model
|
||||
{
|
||||
ollama_client = OllamaClient::new(
|
||||
self.ollama.primary_url.clone(),
|
||||
self.ollama.fallback_url.clone(),
|
||||
m.clone(),
|
||||
Some(m.clone()),
|
||||
);
|
||||
}
|
||||
if req.temperature.is_some()
|
||||
|| req.top_p.is_some()
|
||||
|| req.top_k.is_some()
|
||||
|| req.min_p.is_some()
|
||||
{
|
||||
ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
|
||||
}
|
||||
if let Some(ctx) = req.num_ctx {
|
||||
ollama_client.set_num_ctx(Some(ctx));
|
||||
}
|
||||
}
|
||||
|
||||
let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
|
||||
c
|
||||
} else {
|
||||
&ollama_client
|
||||
};
|
||||
let model_used = chat_backend.primary_model().to_string();
|
||||
|
||||
// Tool set.
|
||||
let local_first_user_has_image = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "user")
|
||||
.and_then(|m| m.images.as_ref())
|
||||
.map(|imgs| !imgs.is_empty())
|
||||
.unwrap_or(false);
|
||||
let offer_describe_tool = !is_hybrid && local_first_user_has_image;
|
||||
let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);
|
||||
|
||||
let image_base64: Option<String> = if offer_describe_tool {
|
||||
self.generator.load_image_as_base64(&normalized).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Truncate before appending the new user turn.
|
||||
let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
|
||||
.saturating_sub(RESPONSE_HEADROOM_TOKENS);
|
||||
let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
|
||||
let truncated = apply_context_budget(&mut messages, budget_bytes);
|
||||
if truncated {
|
||||
let _ = tx.send(ChatStreamEvent::Truncated).await;
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::user(req.user_message.clone()));
|
||||
|
||||
let mut tool_calls_made = 0usize;
|
||||
let mut iterations_used = 0usize;
|
||||
let mut last_prompt_eval_count: Option<i32> = None;
|
||||
let mut last_eval_count: Option<i32> = None;
|
||||
let mut final_content = String::new();
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
iterations_used = iteration + 1;
|
||||
let _ = tx
|
||||
.send(ChatStreamEvent::IterationStart {
|
||||
n: iterations_used,
|
||||
max: max_iterations,
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut stream = chat_backend
|
||||
.chat_with_tools_stream(messages.clone(), tools.clone())
|
||||
.await?;
|
||||
|
||||
let mut final_message: Option<ChatMessage> = None;
|
||||
while let Some(ev) = stream.next().await {
|
||||
let ev = ev?;
|
||||
match ev {
|
||||
LlmStreamEvent::TextDelta(delta) => {
|
||||
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
|
||||
}
|
||||
LlmStreamEvent::Done {
|
||||
message,
|
||||
prompt_eval_count,
|
||||
eval_count,
|
||||
} => {
|
||||
last_prompt_eval_count = prompt_eval_count;
|
||||
last_eval_count = eval_count;
|
||||
final_message = Some(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut response =
|
||||
final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;
|
||||
|
||||
// Normalize non-object tool arguments (same as non-streaming path).
|
||||
if let Some(ref mut tcs) = response.tool_calls {
|
||||
for tc in tcs.iter_mut() {
|
||||
if !tc.function.arguments.is_object() {
|
||||
tc.function.arguments = serde_json::Value::Object(Default::default());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(response.clone());
|
||||
|
||||
if let Some(ref tool_calls) = response.tool_calls
|
||||
&& !tool_calls.is_empty()
|
||||
{
|
||||
for (i, tool_call) in tool_calls.iter().enumerate() {
|
||||
tool_calls_made += 1;
|
||||
let call_index = tool_calls_made - 1;
|
||||
let _ = tx
|
||||
.send(ChatStreamEvent::ToolCall {
|
||||
index: call_index,
|
||||
name: tool_call.function.name.clone(),
|
||||
arguments: tool_call.function.arguments.clone(),
|
||||
})
|
||||
.await;
|
||||
let cx = opentelemetry::Context::new();
|
||||
let result = self
|
||||
.generator
|
||||
.execute_tool(
|
||||
&tool_call.function.name,
|
||||
&tool_call.function.arguments,
|
||||
&ollama_client,
|
||||
&image_base64,
|
||||
&normalized,
|
||||
&cx,
|
||||
)
|
||||
.await;
|
||||
let (result_preview, result_truncated) = truncate_tool_result(&result);
|
||||
let _ = tx
|
||||
.send(ChatStreamEvent::ToolResult {
|
||||
index: call_index,
|
||||
name: tool_call.function.name.clone(),
|
||||
result: result_preview,
|
||||
result_truncated,
|
||||
})
|
||||
.await;
|
||||
messages.push(ChatMessage::tool_result(result));
|
||||
let _ = i; // reserved for per-call ordering if needed
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
final_content = response.content;
|
||||
break;
|
||||
}
|
||||
|
||||
if final_content.is_empty() {
|
||||
messages.push(ChatMessage::user(
|
||||
"Please write your final answer now without calling any more tools.",
|
||||
));
|
||||
let mut stream = chat_backend
|
||||
.chat_with_tools_stream(messages.clone(), vec![])
|
||||
.await?;
|
||||
let mut final_message: Option<ChatMessage> = None;
|
||||
while let Some(ev) = stream.next().await {
|
||||
let ev = ev?;
|
||||
match ev {
|
||||
LlmStreamEvent::TextDelta(delta) => {
|
||||
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
|
||||
}
|
||||
LlmStreamEvent::Done {
|
||||
message,
|
||||
prompt_eval_count,
|
||||
eval_count,
|
||||
} => {
|
||||
last_prompt_eval_count = prompt_eval_count;
|
||||
last_eval_count = eval_count;
|
||||
final_message = Some(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let final_response =
|
||||
final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
|
||||
final_content = final_response.content.clone();
|
||||
messages.push(final_response);
|
||||
}
|
||||
|
||||
// Persist.
|
||||
let json = serde_json::to_string(&messages)
|
||||
.map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
|
||||
|
||||
let mut amended_insight_id: Option<i32> = None;
|
||||
if req.amend {
|
||||
let title_prompt = format!(
|
||||
"Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
|
||||
Capture the key moment or theme. Return ONLY the title, nothing else.",
|
||||
final_content
|
||||
);
|
||||
let title_raw = chat_backend
|
||||
.generate(
|
||||
&title_prompt,
|
||||
Some(
|
||||
"You are my long term memory assistant. Use only the information provided. Do not invent details.",
|
||||
),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
let title = title_raw.trim().trim_matches('"').to_string();
|
||||
|
||||
let new_row = InsertPhotoInsight {
|
||||
library_id: req.library_id,
|
||||
file_path: normalized.clone(),
|
||||
title,
|
||||
summary: final_content.clone(),
|
||||
generated_at: Utc::now().timestamp(),
|
||||
model_version: model_used.clone(),
|
||||
is_current: true,
|
||||
training_messages: Some(json),
|
||||
backend: effective_backend.clone(),
|
||||
};
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
let stored = dao
|
||||
.store_insight(&cx, new_row)
|
||||
.map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
|
||||
amended_insight_id = Some(stored.id);
|
||||
} else {
|
||||
let cx = opentelemetry::Context::new();
|
||||
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||
dao.update_training_messages(&cx, req.library_id, &normalized, &json)
|
||||
.map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(ChatStreamEvent::Done {
|
||||
tool_calls_made,
|
||||
iterations_used,
|
||||
truncated,
|
||||
prompt_eval_count: last_prompt_eval_count,
|
||||
eval_count: last_eval_count,
|
||||
amended_insight_id,
|
||||
backend_used: effective_backend,
|
||||
model_used,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
|
||||
/// `Done` or `Error`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ChatStreamEvent {
|
||||
/// Starting iteration `n` of up to `max` (1-based).
|
||||
IterationStart { n: usize, max: usize },
|
||||
/// History was trimmed to fit the context budget before the turn ran.
|
||||
/// Emitted at most once, before any tool or text events.
|
||||
Truncated,
|
||||
/// Incremental content from the final assistant reply. Concatenate to
|
||||
/// reconstruct the reply body. Tool-dispatch turns don't produce these.
|
||||
TextDelta(String),
|
||||
/// The model requested this tool call. Emitted just before execution.
|
||||
/// `index` is a monotonically-increasing counter across the turn so the
|
||||
/// client can pair `ToolCall` with its matching `ToolResult`.
|
||||
ToolCall {
|
||||
index: usize,
|
||||
name: String,
|
||||
arguments: serde_json::Value,
|
||||
},
|
||||
/// The tool finished; `result` is the (possibly truncated) output.
|
||||
ToolResult {
|
||||
index: usize,
|
||||
name: String,
|
||||
result: String,
|
||||
result_truncated: bool,
|
||||
},
|
||||
/// Terminal success event with counters + persistence result.
|
||||
Done {
|
||||
tool_calls_made: usize,
|
||||
iterations_used: usize,
|
||||
truncated: bool,
|
||||
prompt_eval_count: Option<i32>,
|
||||
eval_count: Option<i32>,
|
||||
amended_insight_id: Option<i32>,
|
||||
backend_used: String,
|
||||
model_used: String,
|
||||
},
|
||||
/// Terminal failure event. No further events follow.
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Is this raw message visible in the rendered transcript? Must match
|
||||
|
||||
Reference in New Issue
Block a user