feat(ai): streaming chat endpoint with live tool events

Add LlmClient::chat_with_tools_stream and SSE endpoint
POST /insights/chat/stream that emits text deltas, tool_call /
tool_result pairs, a truncated notice, and a terminal done frame as the
agentic loop runs.

- Ollama: parses NDJSON from /api/chat stream, accumulates content
  deltas, emits Done with tool_calls from the final chunk.
- OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call
  argument deltas by index, asks for stream_options.include_usage.
- InsightChatService spawns the loop on a tokio task, feeds events
  through an mpsc channel, persists training_messages at the end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Cameron
2026-04-21 16:57:41 -04:00
parent c2bd3c08e1
commit 079cd4c5b9
9 changed files with 1071 additions and 9 deletions

View File

@@ -12,8 +12,9 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::ai::llm_client::{
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
};
use futures::stream::{BoxStream, StreamExt};
const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
@@ -378,6 +379,220 @@ impl LlmClient for OpenRouterClient {
Ok((chat_msg, prompt_tokens, completion_tokens))
}
/// Stream a chat completion (with tool support) from the OpenAI-compatible
/// `/chat/completions` endpoint.
///
/// Yields `LlmStreamEvent::TextDelta` for each non-empty content fragment as
/// it arrives, and a single terminal `LlmStreamEvent::Done` carrying the
/// fully accumulated assistant message, any reassembled tool calls, and the
/// usage counts reported by the server (if any).
///
/// Tool-call argument fragments are buffered per `delta.tool_calls[i].index`
/// and only parsed into JSON once the stream finishes, since the arguments
/// string is split across many chunks.
///
/// # Errors
/// Returns an error if the request cannot be built or sent, or if the server
/// responds with a non-success status. Mid-stream read failures are yielded
/// as an `Err` item on the stream, which then terminates.
async fn chat_with_tools_stream(
    &self,
    messages: Vec<ChatMessage>,
    tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
    let url = format!("{}/chat/completions", self.base_url);
    // Build the body as a Map so provider-specific options from
    // build_options() can be merged in before sending.
    let mut body = serde_json::Map::new();
    body.insert("model".into(), Value::String(self.primary_model.clone()));
    body.insert(
        "messages".into(),
        Value::Array(Self::messages_to_openai(&messages)),
    );
    body.insert("stream".into(), Value::Bool(true));
    // Ask for usage data in the final chunk (OpenAI + OpenRouter
    // both honor this options bag).
    body.insert(
        "stream_options".into(),
        serde_json::json!({ "include_usage": true }),
    );
    if !tools.is_empty() {
        body.insert(
            "tools".into(),
            serde_json::to_value(&tools).context("serializing tools")?,
        );
    }
    for (k, v) in self.build_options() {
        body.insert(k.into(), v);
    }
    let resp = self
        .authed(self.client.post(&url))
        .json(&Value::Object(body))
        .send()
        .await
        .with_context(|| format!("POST {} failed", url))?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        bail!("OpenRouter stream request failed: {} — {}", status, body);
    }
    // OpenAI-compat SSE stream. Each event is `data: <json>\n\n`, with
    // `data: [DONE]` signalling completion. Tool calls arrive as
    // `delta.tool_calls[i]` chunks that must be concatenated by index.
    let byte_stream = resp.bytes_stream();
    let stream = async_stream::stream! {
        let mut byte_stream = byte_stream;
        // Raw bytes not yet assembled into a complete SSE frame. Network
        // chunk boundaries do not align with frame boundaries, so we
        // accumulate here and cut frames out as separators appear.
        let mut buf: Vec<u8> = Vec::new();
        // Full assistant text so far; also emitted piecewise as TextDelta.
        let mut accumulated_content = String::new();
        // tool call state: index -> (id, name, args_string).
        // BTreeMap keeps tool calls ordered by their delta index.
        let mut tool_state: std::collections::BTreeMap<
            usize,
            (Option<String>, Option<String>, String),
        > = std::collections::BTreeMap::new();
        // Role defaults to "assistant" but is overwritten if a delta
        // carries an explicit role.
        let mut role = "assistant".to_string();
        let mut prompt_tokens: Option<i32> = None;
        let mut completion_tokens: Option<i32> = None;
        let mut done_seen = false;
        while let Some(chunk) = byte_stream.next().await {
            let chunk = match chunk {
                Ok(b) => b,
                Err(e) => {
                    // Surface the transport error to the consumer and end
                    // the stream; no Done event is produced in this case.
                    yield Err(anyhow!("stream read failed: {}", e));
                    return;
                }
            };
            buf.extend_from_slice(&chunk);
            // SSE frames are delimited by a blank line. Walk the buffer
            // for "\n\n" markers; anything before them is a complete
            // frame (possibly multi-line).
            loop {
                let Some(sep) = find_double_newline(&buf) else { break };
                let frame = buf.drain(..sep + 2).collect::<Vec<_>>();
                let frame_str = match std::str::from_utf8(&frame) {
                    Ok(s) => s,
                    Err(_) => continue, // drop undecodable frames silently
                };
                // A frame is one or more lines; the payload is on data:
                // lines. Ignore comments and other fields.
                for line in frame_str.lines() {
                    let line = line.trim_end_matches('\r');
                    let payload = match line.strip_prefix("data: ") {
                        Some(p) => p,
                        None => continue,
                    };
                    if payload == "[DONE]" {
                        done_seen = true;
                        break;
                    }
                    let v: Value = match serde_json::from_str(payload) {
                        Ok(v) => v,
                        Err(e) => {
                            // Malformed frames are logged and skipped so
                            // one bad chunk doesn't kill the stream.
                            log::warn!(
                                "malformed OpenRouter SSE frame: {} ({})",
                                payload,
                                e
                            );
                            continue;
                        }
                    };
                    // Usage can arrive in a dedicated final frame with
                    // empty choices.
                    if let Some(usage) = v.get("usage") {
                        prompt_tokens = usage
                            .get("prompt_tokens")
                            .and_then(|n| n.as_i64())
                            .map(|n| n as i32);
                        completion_tokens = usage
                            .get("completion_tokens")
                            .and_then(|n| n.as_i64())
                            .map(|n| n as i32);
                    }
                    let Some(choices) = v.get("choices").and_then(|c| c.as_array())
                    else {
                        continue;
                    };
                    // Only the first choice is consumed (n > 1 is not
                    // requested by this client).
                    let Some(choice) = choices.first() else { continue };
                    let delta = match choice.get("delta") {
                        Some(d) => d,
                        None => continue,
                    };
                    if let Some(r) = delta.get("role").and_then(|v| v.as_str()) {
                        role = r.to_string();
                    }
                    if let Some(content) =
                        delta.get("content").and_then(|v| v.as_str())
                        && !content.is_empty()
                    {
                        accumulated_content.push_str(content);
                        yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
                    }
                    if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
                        for tc_delta in tcs {
                            // Fragments belonging to the same tool call
                            // share an index; id/name usually arrive on
                            // the first fragment and arguments trickle in
                            // across subsequent ones.
                            let idx = tc_delta
                                .get("index")
                                .and_then(|n| n.as_u64())
                                .unwrap_or(0) as usize;
                            let entry = tool_state
                                .entry(idx)
                                .or_insert((None, None, String::new()));
                            if let Some(id) =
                                tc_delta.get("id").and_then(|v| v.as_str())
                            {
                                entry.0 = Some(id.to_string());
                            }
                            if let Some(func) = tc_delta.get("function") {
                                if let Some(name) =
                                    func.get("name").and_then(|v| v.as_str())
                                {
                                    entry.1 = Some(name.to_string());
                                }
                                if let Some(args) =
                                    func.get("arguments").and_then(|v| v.as_str())
                                {
                                    entry.2.push_str(args);
                                }
                            }
                        }
                    }
                }
                if done_seen {
                    break;
                }
            }
            if done_seen {
                break;
            }
        }
        // Finalize tool calls: parse accumulated argument strings.
        let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
            None
        } else {
            let mut v = Vec::with_capacity(tool_state.len());
            for (_idx, (id, name, args)) in tool_state {
                // Unparseable or empty argument strings degrade to an
                // empty JSON object rather than failing the whole turn.
                let arguments: Value = if args.trim().is_empty() {
                    Value::Object(Default::default())
                } else {
                    serde_json::from_str(&args).unwrap_or_else(|_| {
                        Value::Object(Default::default())
                    })
                };
                v.push(ToolCall {
                    id,
                    function: ToolCallFunction {
                        name: name.unwrap_or_default(),
                        arguments,
                    },
                });
            }
            Some(v)
        };
        let message = ChatMessage {
            role,
            content: accumulated_content,
            tool_calls,
            images: None,
        };
        // Terminal event: the complete message plus token counts (None if
        // the server never sent a usage frame).
        yield Ok(LlmStreamEvent::Done {
            message,
            prompt_eval_count: prompt_tokens,
            eval_count: completion_tokens,
        });
    };
    Ok(Box::pin(stream))
}
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let url = format!("{}/embeddings", self.base_url);
let body = json!({
@@ -473,6 +688,28 @@ impl LlmClient for OpenRouterClient {
}
}
/// Find the end of the first complete SSE frame separator in `buf`.
///
/// Returns an offset `sep` such that `buf[..sep + 2]` covers the frame plus
/// its entire terminating separator — the caller drains exactly `..sep + 2`.
/// For the plain `\n\n` separator that is the index of the first `\n`; for
/// the `\r\n\r\n` form some servers emit it is the index of the second `\r`,
/// so the drain consumes all four separator bytes.
///
/// Returns `None` when no complete separator is buffered yet.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    for i in 0..buf.len().saturating_sub(1) {
        // LF LF: separator is buf[i..=i+1]; drain(..i + 2) consumes it.
        if buf[i] == b'\n' && buf[i + 1] == b'\n' {
            return Some(i);
        }
        // CR LF CR LF: separator ends at i + 3 (inclusive). Return i + 2 so
        // the caller's drain(..sep + 2) consumes all four bytes — returning
        // less would leave a stray `\n` at the head of the buffer for every
        // CRLF-delimited frame.
        if i + 3 < buf.len()
            && buf[i] == b'\r'
            && buf[i + 1] == b'\n'
            && buf[i + 2] == b'\r'
            && buf[i + 3] == b'\n'
        {
            return Some(i + 2);
        }
    }
    None
}
/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
fn image_to_data_url(img: &str) -> String {
if img.starts_with("data:") {