feat(ai): streaming chat endpoint with live tool events
Add LlmClient::chat_with_tools_stream and SSE endpoint POST /insights/chat/stream that emits text deltas, tool_call / tool_result pairs, truncated notice, and a terminal done frame as the agentic loop runs. - Ollama: parses NDJSON from /api/chat stream, accumulates content deltas, emits Done with tool_calls from the final chunk. - OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call argument deltas by index, asks for stream_options.include_usage. - InsightChatService spawns the loop on a tokio task, feeds events through an mpsc channel, persists training_messages at the end. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -12,8 +12,9 @@ use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::ai::llm_client::{
|
||||
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
|
||||
ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
|
||||
};
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
|
||||
@@ -378,6 +379,220 @@ impl LlmClient for OpenRouterClient {
|
||||
Ok((chat_msg, prompt_tokens, completion_tokens))
|
||||
}
|
||||
|
||||
async fn chat_with_tools_stream(
|
||||
&self,
|
||||
messages: Vec<ChatMessage>,
|
||||
tools: Vec<Tool>,
|
||||
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
let mut body = serde_json::Map::new();
|
||||
body.insert("model".into(), Value::String(self.primary_model.clone()));
|
||||
body.insert(
|
||||
"messages".into(),
|
||||
Value::Array(Self::messages_to_openai(&messages)),
|
||||
);
|
||||
body.insert("stream".into(), Value::Bool(true));
|
||||
// Ask for usage data in the final chunk (OpenAI + OpenRouter
|
||||
// both honor this options bag).
|
||||
body.insert(
|
||||
"stream_options".into(),
|
||||
serde_json::json!({ "include_usage": true }),
|
||||
);
|
||||
if !tools.is_empty() {
|
||||
body.insert(
|
||||
"tools".into(),
|
||||
serde_json::to_value(&tools).context("serializing tools")?,
|
||||
);
|
||||
}
|
||||
for (k, v) in self.build_options() {
|
||||
body.insert(k.into(), v);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed(self.client.post(&url))
|
||||
.json(&Value::Object(body))
|
||||
.send()
|
||||
.await
|
||||
.with_context(|| format!("POST {} failed", url))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
bail!("OpenRouter stream request failed: {} — {}", status, body);
|
||||
}
|
||||
|
||||
// OpenAI-compat SSE stream. Each event is `data: <json>\n\n`, with
|
||||
// `data: [DONE]` signalling completion. Tool calls arrive as
|
||||
// `delta.tool_calls[i]` chunks that must be concatenated by index.
|
||||
let byte_stream = resp.bytes_stream();
|
||||
let stream = async_stream::stream! {
|
||||
let mut byte_stream = byte_stream;
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut accumulated_content = String::new();
|
||||
// tool call state: index -> (id, name, args_string)
|
||||
let mut tool_state: std::collections::BTreeMap<
|
||||
usize,
|
||||
(Option<String>, Option<String>, String),
|
||||
> = std::collections::BTreeMap::new();
|
||||
let mut role = "assistant".to_string();
|
||||
let mut prompt_tokens: Option<i32> = None;
|
||||
let mut completion_tokens: Option<i32> = None;
|
||||
let mut done_seen = false;
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
let chunk = match chunk {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
yield Err(anyhow!("stream read failed: {}", e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
buf.extend_from_slice(&chunk);
|
||||
|
||||
// SSE frames are delimited by a blank line. Walk the buffer
|
||||
// for "\n\n" markers; anything before them is a complete
|
||||
// frame (possibly multi-line).
|
||||
loop {
|
||||
let Some(sep) = find_double_newline(&buf) else { break };
|
||||
let frame = buf.drain(..sep + 2).collect::<Vec<_>>();
|
||||
let frame_str = match std::str::from_utf8(&frame) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
// A frame is one or more lines; the payload is on data:
|
||||
// lines. Ignore comments and other fields.
|
||||
for line in frame_str.lines() {
|
||||
let line = line.trim_end_matches('\r');
|
||||
let payload = match line.strip_prefix("data: ") {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
if payload == "[DONE]" {
|
||||
done_seen = true;
|
||||
break;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(payload) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"malformed OpenRouter SSE frame: {} ({})",
|
||||
payload,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Usage can arrive in a dedicated final frame with
|
||||
// empty choices.
|
||||
if let Some(usage) = v.get("usage") {
|
||||
prompt_tokens = usage
|
||||
.get("prompt_tokens")
|
||||
.and_then(|n| n.as_i64())
|
||||
.map(|n| n as i32);
|
||||
completion_tokens = usage
|
||||
.get("completion_tokens")
|
||||
.and_then(|n| n.as_i64())
|
||||
.map(|n| n as i32);
|
||||
}
|
||||
|
||||
let Some(choices) = v.get("choices").and_then(|c| c.as_array())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
let Some(choice) = choices.first() else { continue };
|
||||
let delta = match choice.get("delta") {
|
||||
Some(d) => d,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(r) = delta.get("role").and_then(|v| v.as_str()) {
|
||||
role = r.to_string();
|
||||
}
|
||||
if let Some(content) =
|
||||
delta.get("content").and_then(|v| v.as_str())
|
||||
&& !content.is_empty()
|
||||
{
|
||||
accumulated_content.push_str(content);
|
||||
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
|
||||
}
|
||||
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
for tc_delta in tcs {
|
||||
let idx = tc_delta
|
||||
.get("index")
|
||||
.and_then(|n| n.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let entry = tool_state
|
||||
.entry(idx)
|
||||
.or_insert((None, None, String::new()));
|
||||
if let Some(id) =
|
||||
tc_delta.get("id").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.0 = Some(id.to_string());
|
||||
}
|
||||
if let Some(func) = tc_delta.get("function") {
|
||||
if let Some(name) =
|
||||
func.get("name").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.1 = Some(name.to_string());
|
||||
}
|
||||
if let Some(args) =
|
||||
func.get("arguments").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.2.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if done_seen {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if done_seen {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize tool calls: parse accumulated argument strings.
|
||||
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut v = Vec::with_capacity(tool_state.len());
|
||||
for (_idx, (id, name, args)) in tool_state {
|
||||
let arguments: Value = if args.trim().is_empty() {
|
||||
Value::Object(Default::default())
|
||||
} else {
|
||||
serde_json::from_str(&args).unwrap_or_else(|_| {
|
||||
Value::Object(Default::default())
|
||||
})
|
||||
};
|
||||
v.push(ToolCall {
|
||||
id,
|
||||
function: ToolCallFunction {
|
||||
name: name.unwrap_or_default(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
Some(v)
|
||||
};
|
||||
|
||||
let message = ChatMessage {
|
||||
role,
|
||||
content: accumulated_content,
|
||||
tool_calls,
|
||||
images: None,
|
||||
};
|
||||
yield Ok(LlmStreamEvent::Done {
|
||||
message,
|
||||
prompt_eval_count: prompt_tokens,
|
||||
eval_count: completion_tokens,
|
||||
});
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
||||
let url = format!("{}/embeddings", self.base_url);
|
||||
let body = json!({
|
||||
@@ -473,6 +688,28 @@ impl LlmClient for OpenRouterClient {
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the byte offset of the first `\n\n` (end of an SSE frame) in `buf`.
|
||||
/// Returns the index of the first `\n` of the pair, so the full separator is
|
||||
/// `buf[idx..=idx+1]`. Also handles `\r\n\r\n` since some servers emit it.
|
||||
fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
||||
for i in 0..buf.len().saturating_sub(1) {
|
||||
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
|
||||
return Some(i);
|
||||
}
|
||||
// \r\n\r\n: the second \n of this pattern is at i+2; flag at i so the
|
||||
// drain call (which consumes ..sep+2) takes exactly the frame.
|
||||
if i + 3 < buf.len()
|
||||
&& buf[i] == b'\r'
|
||||
&& buf[i + 1] == b'\n'
|
||||
&& buf[i + 2] == b'\r'
|
||||
&& buf[i + 3] == b'\n'
|
||||
{
|
||||
return Some(i + 1);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
|
||||
fn image_to_data_url(img: &str) -> String {
|
||||
if img.starts_with("data:") {
|
||||
|
||||
Reference in New Issue
Block a user