feat(ai): streaming chat endpoint with live tool events
Add LlmClient::chat_with_tools_stream and SSE endpoint POST /insights/chat/stream that emits text deltas, tool_call / tool_result pairs, truncated notice, and a terminal done frame as the agentic loop runs.

- Ollama: parses NDJSON from /api/chat stream, accumulates content deltas, emits Done with tool_calls from the final chunk.
- OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call argument deltas by index, asks for stream_options.include_usage.
- InsightChatService spawns the loop on a tokio task, feeds events through an mpsc channel, persists training_messages at the end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -486,6 +486,28 @@ version = "0.7.6"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream-impl",
|
||||||
|
"futures-core",
|
||||||
|
"pin-project-lite",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream-impl"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-trait"
|
name = "async-trait"
|
||||||
version = "0.1.89"
|
version = "0.1.89"
|
||||||
@@ -1843,10 +1865,12 @@ dependencies = [
|
|||||||
"actix-web",
|
"actix-web",
|
||||||
"actix-web-prom",
|
"actix-web-prom",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"base64",
|
"base64",
|
||||||
"bcrypt",
|
"bcrypt",
|
||||||
"blake3",
|
"blake3",
|
||||||
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"diesel",
|
"diesel",
|
||||||
@@ -1878,6 +1902,7 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
"urlencoding",
|
"urlencoding",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
"zerocopy",
|
"zerocopy",
|
||||||
@@ -3125,12 +3150,14 @@ dependencies = [
|
|||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-native-tls",
|
"tokio-native-tls",
|
||||||
|
"tokio-util",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
|
"wasm-streams",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -4219,6 +4246,19 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-streams"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"wasm-bindgen-futures",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "web-sys"
|
name = "web-sys"
|
||||||
version = "0.3.77"
|
version = "0.3.77"
|
||||||
|
|||||||
@@ -49,7 +49,10 @@ opentelemetry-appender-log = "0.31.0"
|
|||||||
tempfile = "3.20.0"
|
tempfile = "3.20.0"
|
||||||
regex = "1.11.1"
|
regex = "1.11.1"
|
||||||
exif = { package = "kamadak-exif", version = "0.6.1" }
|
exif = { package = "kamadak-exif", version = "0.6.1" }
|
||||||
reqwest = { version = "0.12", features = ["json"] }
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
async-stream = "0.3"
|
||||||
|
tokio-util = { version = "0.7", features = ["io"] }
|
||||||
|
bytes = "1"
|
||||||
urlencoding = "2.1"
|
urlencoding = "2.1"
|
||||||
zerocopy = "0.8"
|
zerocopy = "0.8"
|
||||||
ical = "0.11"
|
ical = "0.11"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use opentelemetry::KeyValue;
|
|||||||
use opentelemetry::trace::{Span, Status, Tracer};
|
use opentelemetry::trace::{Span, Status, Tracer};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::ai::insight_chat::ChatTurnRequest;
|
use crate::ai::insight_chat::{ChatStreamEvent, ChatTurnRequest};
|
||||||
use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient};
|
use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient};
|
||||||
use crate::data::Claims;
|
use crate::data::Claims;
|
||||||
use crate::database::{ExifDao, InsightDao};
|
use crate::database::{ExifDao, InsightDao};
|
||||||
@@ -826,3 +826,109 @@ pub async fn chat_history_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// POST /insights/chat/stream — streaming variant of /insights/chat.
|
||||||
|
/// Returns `text/event-stream` with one event per chat stream event.
|
||||||
|
#[post("/insights/chat/stream")]
|
||||||
|
pub async fn chat_stream_handler(
|
||||||
|
_claims: Claims,
|
||||||
|
request: web::Json<ChatTurnHttpRequest>,
|
||||||
|
app_state: web::Data<AppState>,
|
||||||
|
) -> HttpResponse {
|
||||||
|
let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) {
|
||||||
|
Ok(Some(lib)) => lib,
|
||||||
|
Ok(None) => app_state.primary_library(),
|
||||||
|
Err(e) => {
|
||||||
|
return HttpResponse::BadRequest().json(serde_json::json!({
|
||||||
|
"error": format!("invalid library: {}", e)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let chat_req = ChatTurnRequest {
|
||||||
|
library_id: library.id,
|
||||||
|
file_path: request.file_path.clone(),
|
||||||
|
user_message: request.user_message.clone(),
|
||||||
|
model: request.model.clone(),
|
||||||
|
backend: request.backend.clone(),
|
||||||
|
num_ctx: request.num_ctx,
|
||||||
|
temperature: request.temperature,
|
||||||
|
top_p: request.top_p,
|
||||||
|
top_k: request.top_k,
|
||||||
|
min_p: request.min_p,
|
||||||
|
max_iterations: request.max_iterations,
|
||||||
|
amend: request.amend,
|
||||||
|
};
|
||||||
|
|
||||||
|
let service = app_state.insight_chat.clone();
|
||||||
|
let events = service.chat_turn_stream(chat_req);
|
||||||
|
|
||||||
|
// Map ChatStreamEvent → SSE frame bytes.
|
||||||
|
let sse_stream = futures::stream::StreamExt::map(events, |ev| {
|
||||||
|
let frame = render_sse_frame(&ev);
|
||||||
|
Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame))
|
||||||
|
});
|
||||||
|
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.content_type("text/event-stream")
|
||||||
|
.insert_header(("Cache-Control", "no-cache"))
|
||||||
|
.insert_header(("X-Accel-Buffering", "no")) // nginx: disable response buffering
|
||||||
|
.streaming(sse_stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_sse_frame(ev: &ChatStreamEvent) -> String {
|
||||||
|
let (event_name, payload) = match ev {
|
||||||
|
ChatStreamEvent::IterationStart { n, max } => {
|
||||||
|
("iteration_start", serde_json::json!({ "n": n, "max": max }))
|
||||||
|
}
|
||||||
|
ChatStreamEvent::Truncated => ("truncated", serde_json::json!({})),
|
||||||
|
ChatStreamEvent::TextDelta(delta) => ("text", serde_json::json!({ "delta": delta })),
|
||||||
|
ChatStreamEvent::ToolCall {
|
||||||
|
index,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
} => (
|
||||||
|
"tool_call",
|
||||||
|
serde_json::json!({ "index": index, "name": name, "arguments": arguments }),
|
||||||
|
),
|
||||||
|
ChatStreamEvent::ToolResult {
|
||||||
|
index,
|
||||||
|
name,
|
||||||
|
result,
|
||||||
|
result_truncated,
|
||||||
|
} => (
|
||||||
|
"tool_result",
|
||||||
|
serde_json::json!({
|
||||||
|
"index": index,
|
||||||
|
"name": name,
|
||||||
|
"result": result,
|
||||||
|
"result_truncated": result_truncated,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ChatStreamEvent::Done {
|
||||||
|
tool_calls_made,
|
||||||
|
iterations_used,
|
||||||
|
truncated,
|
||||||
|
prompt_eval_count,
|
||||||
|
eval_count,
|
||||||
|
amended_insight_id,
|
||||||
|
backend_used,
|
||||||
|
model_used,
|
||||||
|
} => (
|
||||||
|
"done",
|
||||||
|
serde_json::json!({
|
||||||
|
"tool_calls_made": tool_calls_made,
|
||||||
|
"iterations_used": iterations_used,
|
||||||
|
"truncated": truncated,
|
||||||
|
"prompt_eval_count": prompt_eval_count,
|
||||||
|
"eval_count": eval_count,
|
||||||
|
"amended_insight_id": amended_insight_id,
|
||||||
|
"backend": backend_used,
|
||||||
|
"model": model_used,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ChatStreamEvent::Error(msg) => ("error", serde_json::json!({ "message": msg })),
|
||||||
|
};
|
||||||
|
let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string());
|
||||||
|
format!("event: {}\ndata: {}\n\n", event_name, data)
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ use std::sync::{Arc, Mutex};
|
|||||||
use tokio::sync::Mutex as TokioMutex;
|
use tokio::sync::Mutex as TokioMutex;
|
||||||
|
|
||||||
use crate::ai::insight_generator::InsightGenerator;
|
use crate::ai::insight_generator::InsightGenerator;
|
||||||
use crate::ai::llm_client::{ChatMessage, LlmClient};
|
use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent};
|
||||||
use crate::ai::ollama::OllamaClient;
|
use crate::ai::ollama::OllamaClient;
|
||||||
use crate::ai::openrouter::OpenRouterClient;
|
use crate::ai::openrouter::OpenRouterClient;
|
||||||
use crate::database::InsightDao;
|
use crate::database::InsightDao;
|
||||||
use crate::database::models::InsertPhotoInsight;
|
use crate::database::models::InsertPhotoInsight;
|
||||||
use crate::otel::global_tracer;
|
use crate::otel::global_tracer;
|
||||||
use crate::utils::normalize_path;
|
use crate::utils::normalize_path;
|
||||||
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
|
||||||
const DEFAULT_MAX_ITERATIONS: usize = 6;
|
const DEFAULT_MAX_ITERATIONS: usize = 6;
|
||||||
const DEFAULT_NUM_CTX: i32 = 8192;
|
const DEFAULT_NUM_CTX: i32 = 8192;
|
||||||
@@ -583,6 +584,442 @@ impl InsightChatService {
|
|||||||
.map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
|
.map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming variant of `chat_turn`. Emits user-facing events as the
|
||||||
|
/// conversation progresses: iteration starts, tool dispatch + result,
|
||||||
|
/// text deltas from the final assistant reply, and a terminal `Done`
|
||||||
|
/// frame. Persistence happens inside the stream after the loop ends.
|
||||||
|
///
|
||||||
|
/// The stream takes ownership of the service via `Arc<Self>` (passed by
|
||||||
|
/// the caller) so it can live past the handler's await boundary.
|
||||||
|
pub fn chat_turn_stream(
|
||||||
|
self: Arc<Self>,
|
||||||
|
req: ChatTurnRequest,
|
||||||
|
) -> BoxStream<'static, ChatStreamEvent> {
|
||||||
|
let svc = self;
|
||||||
|
let s = async_stream::stream! {
|
||||||
|
match svc.chat_turn_stream_inner(req, |ev| Ok(ev)).await {
|
||||||
|
Ok(mut rx) => {
|
||||||
|
while let Some(ev) = rx.recv().await {
|
||||||
|
yield ev;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
yield ChatStreamEvent::Error(format!("{}", e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Box::pin(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal: drives the streaming loop on a background task, returning
|
||||||
|
/// a receiver the caller drains. Keeping the work on a spawned task
|
||||||
|
/// decouples the HTTP request lifetime from the chat execution, which
|
||||||
|
/// matters because the chat may run longer than any single network hop
|
||||||
|
/// and we want clean cancellation semantics via the channel close.
|
||||||
|
async fn chat_turn_stream_inner<F>(
|
||||||
|
self: Arc<Self>,
|
||||||
|
req: ChatTurnRequest,
|
||||||
|
_ev_mapper: F,
|
||||||
|
) -> Result<tokio::sync::mpsc::Receiver<ChatStreamEvent>>
|
||||||
|
where
|
||||||
|
F: Fn(ChatStreamEvent) -> Result<ChatStreamEvent> + Send + 'static,
|
||||||
|
{
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<ChatStreamEvent>(64);
|
||||||
|
let svc = self.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let result = svc.run_streaming_turn(req, tx.clone()).await;
|
||||||
|
if let Err(e) = result {
|
||||||
|
let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_streaming_turn(
|
||||||
|
self: Arc<Self>,
|
||||||
|
req: ChatTurnRequest,
|
||||||
|
tx: tokio::sync::mpsc::Sender<ChatStreamEvent>,
|
||||||
|
) -> Result<()> {
|
||||||
|
if req.user_message.trim().is_empty() {
|
||||||
|
bail!("user_message must not be empty");
|
||||||
|
}
|
||||||
|
if req.user_message.len() > 8192 {
|
||||||
|
bail!("user_message exceeds 8192 chars");
|
||||||
|
}
|
||||||
|
let normalized = normalize_path(&req.file_path);
|
||||||
|
|
||||||
|
let lock_key = (req.library_id, normalized.clone());
|
||||||
|
let entry_lock = {
|
||||||
|
let mut locks = self.chat_locks.lock().await;
|
||||||
|
locks
|
||||||
|
.entry(lock_key.clone())
|
||||||
|
.or_insert_with(|| Arc::new(TokioMutex::new(())))
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
let _guard = entry_lock.lock().await;
|
||||||
|
|
||||||
|
let insight = {
|
||||||
|
let cx = opentelemetry::Context::new();
|
||||||
|
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||||
|
dao.get_insight(&cx, &normalized)
|
||||||
|
.map_err(|e| anyhow!("failed to load insight: {:?}", e))?
|
||||||
|
.ok_or_else(|| anyhow!("no insight found for path"))?
|
||||||
|
};
|
||||||
|
let raw_history = insight
|
||||||
|
.training_messages
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!("insight has no chat history; regenerate this insight in agentic mode")
|
||||||
|
})?
|
||||||
|
.clone();
|
||||||
|
let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
|
||||||
|
.map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;
|
||||||
|
|
||||||
|
// Backend selection — same rules as non-streaming chat_turn.
|
||||||
|
let stored_backend = insight.backend.clone();
|
||||||
|
let effective_backend = req
|
||||||
|
.backend
|
||||||
|
.as_deref()
|
||||||
|
.map(|s| s.trim().to_lowercase())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(|| stored_backend.clone());
|
||||||
|
if !matches!(effective_backend.as_str(), "local" | "hybrid") {
|
||||||
|
bail!(
|
||||||
|
"unknown backend '{}'; expected 'local' or 'hybrid'",
|
||||||
|
effective_backend
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if stored_backend == "local" && effective_backend == "hybrid" {
|
||||||
|
bail!(
|
||||||
|
"switching from local to hybrid mid-chat isn't supported yet; \
|
||||||
|
regenerate the insight in hybrid mode if you want OpenRouter chat"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let is_hybrid = effective_backend == "hybrid";
|
||||||
|
|
||||||
|
let max_iterations = req
|
||||||
|
.max_iterations
|
||||||
|
.unwrap_or(DEFAULT_MAX_ITERATIONS)
|
||||||
|
.clamp(1, env_max_iterations());
|
||||||
|
|
||||||
|
let stored_model = insight.model_version.clone();
|
||||||
|
let custom_model = req
|
||||||
|
.model
|
||||||
|
.clone()
|
||||||
|
.or_else(|| Some(stored_model.clone()))
|
||||||
|
.filter(|m| !m.is_empty());
|
||||||
|
|
||||||
|
let mut ollama_client = self.ollama.clone();
|
||||||
|
let mut openrouter_client: Option<OpenRouterClient> = None;
|
||||||
|
|
||||||
|
if is_hybrid {
|
||||||
|
let arc = self.openrouter.as_ref().ok_or_else(|| {
|
||||||
|
anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
|
||||||
|
})?;
|
||||||
|
let mut c: OpenRouterClient = (**arc).clone();
|
||||||
|
if let Some(ref m) = custom_model {
|
||||||
|
c.primary_model = m.clone();
|
||||||
|
}
|
||||||
|
if req.temperature.is_some()
|
||||||
|
|| req.top_p.is_some()
|
||||||
|
|| req.top_k.is_some()
|
||||||
|
|| req.min_p.is_some()
|
||||||
|
{
|
||||||
|
c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
|
||||||
|
}
|
||||||
|
if let Some(ctx) = req.num_ctx {
|
||||||
|
c.set_num_ctx(Some(ctx));
|
||||||
|
}
|
||||||
|
openrouter_client = Some(c);
|
||||||
|
} else {
|
||||||
|
if let Some(ref m) = custom_model
|
||||||
|
&& m != &self.ollama.primary_model
|
||||||
|
{
|
||||||
|
ollama_client = OllamaClient::new(
|
||||||
|
self.ollama.primary_url.clone(),
|
||||||
|
self.ollama.fallback_url.clone(),
|
||||||
|
m.clone(),
|
||||||
|
Some(m.clone()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if req.temperature.is_some()
|
||||||
|
|| req.top_p.is_some()
|
||||||
|
|| req.top_k.is_some()
|
||||||
|
|| req.min_p.is_some()
|
||||||
|
{
|
||||||
|
ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
|
||||||
|
}
|
||||||
|
if let Some(ctx) = req.num_ctx {
|
||||||
|
ollama_client.set_num_ctx(Some(ctx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
&ollama_client
|
||||||
|
};
|
||||||
|
let model_used = chat_backend.primary_model().to_string();
|
||||||
|
|
||||||
|
// Tool set.
|
||||||
|
let local_first_user_has_image = messages
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.role == "user")
|
||||||
|
.and_then(|m| m.images.as_ref())
|
||||||
|
.map(|imgs| !imgs.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let offer_describe_tool = !is_hybrid && local_first_user_has_image;
|
||||||
|
let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);
|
||||||
|
|
||||||
|
let image_base64: Option<String> = if offer_describe_tool {
|
||||||
|
self.generator.load_image_as_base64(&normalized).ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Truncate before appending the new user turn.
|
||||||
|
let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
|
||||||
|
.saturating_sub(RESPONSE_HEADROOM_TOKENS);
|
||||||
|
let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
|
||||||
|
let truncated = apply_context_budget(&mut messages, budget_bytes);
|
||||||
|
if truncated {
|
||||||
|
let _ = tx.send(ChatStreamEvent::Truncated).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(ChatMessage::user(req.user_message.clone()));
|
||||||
|
|
||||||
|
let mut tool_calls_made = 0usize;
|
||||||
|
let mut iterations_used = 0usize;
|
||||||
|
let mut last_prompt_eval_count: Option<i32> = None;
|
||||||
|
let mut last_eval_count: Option<i32> = None;
|
||||||
|
let mut final_content = String::new();
|
||||||
|
|
||||||
|
for iteration in 0..max_iterations {
|
||||||
|
iterations_used = iteration + 1;
|
||||||
|
let _ = tx
|
||||||
|
.send(ChatStreamEvent::IterationStart {
|
||||||
|
n: iterations_used,
|
||||||
|
max: max_iterations,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut stream = chat_backend
|
||||||
|
.chat_with_tools_stream(messages.clone(), tools.clone())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut final_message: Option<ChatMessage> = None;
|
||||||
|
while let Some(ev) = stream.next().await {
|
||||||
|
let ev = ev?;
|
||||||
|
match ev {
|
||||||
|
LlmStreamEvent::TextDelta(delta) => {
|
||||||
|
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
|
||||||
|
}
|
||||||
|
LlmStreamEvent::Done {
|
||||||
|
message,
|
||||||
|
prompt_eval_count,
|
||||||
|
eval_count,
|
||||||
|
} => {
|
||||||
|
last_prompt_eval_count = prompt_eval_count;
|
||||||
|
last_eval_count = eval_count;
|
||||||
|
final_message = Some(message);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut response =
|
||||||
|
final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;
|
||||||
|
|
||||||
|
// Normalize non-object tool arguments (same as non-streaming path).
|
||||||
|
if let Some(ref mut tcs) = response.tool_calls {
|
||||||
|
for tc in tcs.iter_mut() {
|
||||||
|
if !tc.function.arguments.is_object() {
|
||||||
|
tc.function.arguments = serde_json::Value::Object(Default::default());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(response.clone());
|
||||||
|
|
||||||
|
if let Some(ref tool_calls) = response.tool_calls
|
||||||
|
&& !tool_calls.is_empty()
|
||||||
|
{
|
||||||
|
for (i, tool_call) in tool_calls.iter().enumerate() {
|
||||||
|
tool_calls_made += 1;
|
||||||
|
let call_index = tool_calls_made - 1;
|
||||||
|
let _ = tx
|
||||||
|
.send(ChatStreamEvent::ToolCall {
|
||||||
|
index: call_index,
|
||||||
|
name: tool_call.function.name.clone(),
|
||||||
|
arguments: tool_call.function.arguments.clone(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
let cx = opentelemetry::Context::new();
|
||||||
|
let result = self
|
||||||
|
.generator
|
||||||
|
.execute_tool(
|
||||||
|
&tool_call.function.name,
|
||||||
|
&tool_call.function.arguments,
|
||||||
|
&ollama_client,
|
||||||
|
&image_base64,
|
||||||
|
&normalized,
|
||||||
|
&cx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let (result_preview, result_truncated) = truncate_tool_result(&result);
|
||||||
|
let _ = tx
|
||||||
|
.send(ChatStreamEvent::ToolResult {
|
||||||
|
index: call_index,
|
||||||
|
name: tool_call.function.name.clone(),
|
||||||
|
result: result_preview,
|
||||||
|
result_truncated,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
messages.push(ChatMessage::tool_result(result));
|
||||||
|
let _ = i; // reserved for per-call ordering if needed
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final_content = response.content;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if final_content.is_empty() {
|
||||||
|
messages.push(ChatMessage::user(
|
||||||
|
"Please write your final answer now without calling any more tools.",
|
||||||
|
));
|
||||||
|
let mut stream = chat_backend
|
||||||
|
.chat_with_tools_stream(messages.clone(), vec![])
|
||||||
|
.await?;
|
||||||
|
let mut final_message: Option<ChatMessage> = None;
|
||||||
|
while let Some(ev) = stream.next().await {
|
||||||
|
let ev = ev?;
|
||||||
|
match ev {
|
||||||
|
LlmStreamEvent::TextDelta(delta) => {
|
||||||
|
let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
|
||||||
|
}
|
||||||
|
LlmStreamEvent::Done {
|
||||||
|
message,
|
||||||
|
prompt_eval_count,
|
||||||
|
eval_count,
|
||||||
|
} => {
|
||||||
|
last_prompt_eval_count = prompt_eval_count;
|
||||||
|
last_eval_count = eval_count;
|
||||||
|
final_message = Some(message);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let final_response =
|
||||||
|
final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
|
||||||
|
final_content = final_response.content.clone();
|
||||||
|
messages.push(final_response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist.
|
||||||
|
let json = serde_json::to_string(&messages)
|
||||||
|
.map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;
|
||||||
|
|
||||||
|
let mut amended_insight_id: Option<i32> = None;
|
||||||
|
if req.amend {
|
||||||
|
let title_prompt = format!(
|
||||||
|
"Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
|
||||||
|
Capture the key moment or theme. Return ONLY the title, nothing else.",
|
||||||
|
final_content
|
||||||
|
);
|
||||||
|
let title_raw = chat_backend
|
||||||
|
.generate(
|
||||||
|
&title_prompt,
|
||||||
|
Some(
|
||||||
|
"You are my long term memory assistant. Use only the information provided. Do not invent details.",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let title = title_raw.trim().trim_matches('"').to_string();
|
||||||
|
|
||||||
|
let new_row = InsertPhotoInsight {
|
||||||
|
library_id: req.library_id,
|
||||||
|
file_path: normalized.clone(),
|
||||||
|
title,
|
||||||
|
summary: final_content.clone(),
|
||||||
|
generated_at: Utc::now().timestamp(),
|
||||||
|
model_version: model_used.clone(),
|
||||||
|
is_current: true,
|
||||||
|
training_messages: Some(json),
|
||||||
|
backend: effective_backend.clone(),
|
||||||
|
};
|
||||||
|
let cx = opentelemetry::Context::new();
|
||||||
|
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||||
|
let stored = dao
|
||||||
|
.store_insight(&cx, new_row)
|
||||||
|
.map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
|
||||||
|
amended_insight_id = Some(stored.id);
|
||||||
|
} else {
|
||||||
|
let cx = opentelemetry::Context::new();
|
||||||
|
let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
|
||||||
|
dao.update_training_messages(&cx, req.library_id, &normalized, &json)
|
||||||
|
.map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tx
|
||||||
|
.send(ChatStreamEvent::Done {
|
||||||
|
tool_calls_made,
|
||||||
|
iterations_used,
|
||||||
|
truncated,
|
||||||
|
prompt_eval_count: last_prompt_eval_count,
|
||||||
|
eval_count: last_eval_count,
|
||||||
|
amended_insight_id,
|
||||||
|
backend_used: effective_backend,
|
||||||
|
model_used,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
/// `Done` or `Error`.
#[derive(Debug, Clone)]
pub enum ChatStreamEvent {
    /// Starting iteration `n` of up to `max` (1-based).
    IterationStart { n: usize, max: usize },
    /// History was trimmed to fit the context budget before the turn ran.
    /// Emitted at most once, before any tool or text events.
    Truncated,
    /// Incremental content from the assistant reply. Concatenate to
    /// reconstruct the reply body.
    ///
    /// NOTE(review): `run_streaming_turn` forwards deltas from every
    /// iteration, so a tool-dispatch iteration whose reply also carries
    /// content will produce these too — confirm whether clients should
    /// expect text only from the final reply.
    TextDelta(String),
    /// The model requested this tool call. Emitted just before execution.
    /// `index` is a monotonically-increasing counter across the turn so the
    /// client can pair `ToolCall` with its matching `ToolResult`.
    ToolCall {
        index: usize,
        name: String,
        arguments: serde_json::Value,
    },
    /// The tool finished; `result` is the (possibly truncated) output and
    /// `result_truncated` reports whether truncation was applied.
    ToolResult {
        index: usize,
        name: String,
        result: String,
        result_truncated: bool,
    },
    /// Terminal success event with counters + persistence result.
    /// `amended_insight_id` is `Some` only when the turn stored a new
    /// (amended) insight row.
    Done {
        tool_calls_made: usize,
        iterations_used: usize,
        truncated: bool,
        prompt_eval_count: Option<i32>,
        eval_count: Option<i32>,
        amended_insight_id: Option<i32>,
        backend_used: String,
        model_used: String,
    },
    /// Terminal failure event. No further events follow.
    Error(String),
}
|
||||||
|
|
||||||
/// Is this raw message visible in the rendered transcript? Must match
|
/// Is this raw message visible in the rendered transcript? Must match
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Provider-agnostic surface for LLM backends (Ollama, OpenRouter, …).
|
/// Provider-agnostic surface for LLM backends (Ollama, OpenRouter, …).
|
||||||
@@ -30,6 +31,18 @@ pub trait LlmClient: Send + Sync {
|
|||||||
tools: Vec<Tool>,
|
tools: Vec<Tool>,
|
||||||
) -> Result<(ChatMessage, Option<i32>, Option<i32>)>;
|
) -> Result<(ChatMessage, Option<i32>, Option<i32>)>;
|
||||||
|
|
||||||
|
/// Streaming variant of `chat_with_tools`. The returned stream yields
/// `TextDelta` items as content is produced, then a single terminal
/// `Done` carrying the complete assembled message (with tool_calls, if
/// any) plus token usage counts. Implementations that can't stream may
/// fall back to calling `chat_with_tools` and emitting the full reply
/// as one `Done` event.
///
/// The outer `Result` reports failure to obtain the stream at all; each
/// stream item is itself a `Result`, so errors can also surface mid-stream.
async fn chat_with_tools_stream(
    &self,
    messages: Vec<ChatMessage>,
    tools: Vec<Tool>,
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>>;
|
||||||
|
|
||||||
/// Batch embedding generation. Dimensionality is provider/model specific.
|
/// Batch embedding generation. Dimensionality is provider/model specific.
|
||||||
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
|
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
|
||||||
|
|
||||||
@@ -47,6 +60,25 @@ pub trait LlmClient: Send + Sync {
|
|||||||
fn primary_model(&self) -> &str;
|
fn primary_model(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Events emitted by streaming `chat_with_tools_stream`. A stream is a
/// sequence of zero or more `TextDelta` events followed by exactly one
/// `Done`. Callers should treat `Done` as terminal — further items (if any
/// slip through due to upstream misbehavior) are safe to ignore.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// Incremental content token(s) from the model. Concatenate in order to
    /// reconstruct the assistant's final text.
    TextDelta(String),
    /// Terminal event with the full assembled message (content + any
    /// tool_calls). `message.content` equals the concatenation of every
    /// preceding `TextDelta.0`.
    ///
    /// `prompt_eval_count` / `eval_count` are token usage counts; they are
    /// `None` when the backend did not report them.
    Done {
        message: ChatMessage,
        prompt_eval_count: Option<i32>,
        eval_count: Option<i32>,
    },
}
|
||||||
|
|
||||||
/// Tool definition sent to the model (OpenAI-compatible function schema).
|
/// Tool definition sent to the model (OpenAI-compatible function schema).
|
||||||
#[derive(Serialize, Clone, Debug)]
|
#[derive(Serialize, Clone, Debug)]
|
||||||
pub struct Tool {
|
pub struct Tool {
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ pub mod sms_client;
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
|
pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate};
|
||||||
pub use handlers::{
|
pub use handlers::{
|
||||||
chat_history_handler, chat_rewind_handler, chat_turn_handler, delete_insight_handler,
|
chat_history_handler, chat_rewind_handler, chat_stream_handler, chat_turn_handler,
|
||||||
export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler,
|
delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler,
|
||||||
get_all_insights_handler, get_available_models_handler, get_insight_handler,
|
generate_insight_handler, get_all_insights_handler, get_available_models_handler,
|
||||||
get_openrouter_models_handler, rate_insight_handler,
|
get_insight_handler, get_openrouter_models_handler, rate_insight_handler,
|
||||||
};
|
};
|
||||||
pub use insight_generator::InsightGenerator;
|
pub use insight_generator::InsightGenerator;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
|
|||||||
208
src/ai/ollama.rs
208
src/ai/ollama.rs
@@ -7,7 +7,8 @@ use std::collections::HashMap;
|
|||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use crate::ai::llm_client::LlmClient;
|
use crate::ai::llm_client::{LlmClient, LlmStreamEvent};
|
||||||
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
|
||||||
// Re-export shared types so existing `crate::ai::ollama::{...}` imports
|
// Re-export shared types so existing `crate::ai::ollama::{...}` imports
|
||||||
// continue to resolve.
|
// continue to resolve.
|
||||||
@@ -634,6 +635,174 @@ Analyze the image and use specific details from both the visual content and the
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming variant of `chat_with_tools`. Tries primary, then falls
|
||||||
|
/// back if the initial connection fails; once the stream has begun
|
||||||
|
/// emitting, mid-stream errors propagate to the caller. Emits
|
||||||
|
/// `TextDelta` events as content tokens arrive and a single terminal
|
||||||
|
/// `Done` event when the model marks the turn complete (tool_calls, if
|
||||||
|
/// any, live on the final message).
|
||||||
|
pub async fn chat_with_tools_stream(
|
||||||
|
&self,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
tools: Vec<Tool>,
|
||||||
|
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
|
||||||
|
// Attempt primary. If it can't be opened at all, try fallback.
|
||||||
|
match self
|
||||||
|
.try_chat_with_tools_stream(&self.primary_url, messages.clone(), tools.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(s) => Ok(s),
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(fallback_url) = self.fallback_url.clone() {
|
||||||
|
log::warn!(
|
||||||
|
"Streaming chat primary failed ({}); trying fallback {}",
|
||||||
|
e,
|
||||||
|
fallback_url
|
||||||
|
);
|
||||||
|
self.try_chat_with_tools_stream(&fallback_url, messages, tools)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn try_chat_with_tools_stream(
|
||||||
|
&self,
|
||||||
|
base_url: &str,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
tools: Vec<Tool>,
|
||||||
|
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
|
||||||
|
let url = format!("{}/api/chat", base_url);
|
||||||
|
let model = if base_url == self.primary_url {
|
||||||
|
&self.primary_model
|
||||||
|
} else {
|
||||||
|
self.fallback_model
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&self.primary_model)
|
||||||
|
};
|
||||||
|
let options = self.build_options();
|
||||||
|
|
||||||
|
let request_body = OllamaChatRequest {
|
||||||
|
model,
|
||||||
|
messages: &messages,
|
||||||
|
stream: true,
|
||||||
|
tools,
|
||||||
|
options,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to connect to Ollama at {}", url))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!(
|
||||||
|
"Ollama stream request failed with status {}: {}",
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama streams NDJSON: each line is a full `OllamaStreamChunk`.
|
||||||
|
// We buffer partial lines across chunks from the byte stream.
|
||||||
|
let byte_stream = response.bytes_stream();
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
let mut buf: Vec<u8> = Vec::new();
|
||||||
|
let mut accumulated = String::new();
|
||||||
|
let mut tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>> = None;
|
||||||
|
let mut role = "assistant".to_string();
|
||||||
|
let mut prompt_eval_count: Option<i32> = None;
|
||||||
|
let mut eval_count: Option<i32> = None;
|
||||||
|
let mut prompt_eval_duration: Option<u64> = None;
|
||||||
|
let mut eval_duration: Option<u64> = None;
|
||||||
|
let mut done_seen = false;
|
||||||
|
|
||||||
|
let mut byte_stream = byte_stream;
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = match chunk {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(anyhow::anyhow!("stream read failed: {}", e));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
buf.extend_from_slice(&chunk);
|
||||||
|
|
||||||
|
// Drain complete lines; hold any trailing partial.
|
||||||
|
while let Some(nl) = buf.iter().position(|b| *b == b'\n') {
|
||||||
|
let line = buf.drain(..=nl).collect::<Vec<_>>();
|
||||||
|
let line_str = match std::str::from_utf8(&line) {
|
||||||
|
Ok(s) => s.trim(),
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
if line_str.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match serde_json::from_str::<OllamaStreamChunk>(line_str) {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Accumulate content delta.
|
||||||
|
if !chunk.message.content.is_empty() {
|
||||||
|
accumulated.push_str(&chunk.message.content);
|
||||||
|
yield Ok(LlmStreamEvent::TextDelta(chunk.message.content));
|
||||||
|
}
|
||||||
|
if !chunk.message.role.is_empty() {
|
||||||
|
role = chunk.message.role;
|
||||||
|
}
|
||||||
|
// Ollama only attaches tool_calls on the final chunk.
|
||||||
|
if let Some(tcs) = chunk.message.tool_calls
|
||||||
|
&& !tcs.is_empty()
|
||||||
|
{
|
||||||
|
tool_calls = Some(tcs);
|
||||||
|
}
|
||||||
|
if chunk.done {
|
||||||
|
prompt_eval_count = chunk.prompt_eval_count;
|
||||||
|
eval_count = chunk.eval_count;
|
||||||
|
prompt_eval_duration = chunk.prompt_eval_duration;
|
||||||
|
eval_duration = chunk.eval_duration;
|
||||||
|
done_seen = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("malformed Ollama stream line: {} ({})", line_str, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if done_seen {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit the terminal Done event with the assembled message.
|
||||||
|
log_chat_metrics(
|
||||||
|
prompt_eval_count,
|
||||||
|
prompt_eval_duration,
|
||||||
|
eval_count,
|
||||||
|
eval_duration,
|
||||||
|
);
|
||||||
|
let message = ChatMessage {
|
||||||
|
role,
|
||||||
|
content: accumulated,
|
||||||
|
tool_calls,
|
||||||
|
images: None,
|
||||||
|
};
|
||||||
|
yield Ok(LlmStreamEvent::Done {
|
||||||
|
message,
|
||||||
|
prompt_eval_count,
|
||||||
|
eval_count,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
|
||||||
async fn try_chat_with_tools(
|
async fn try_chat_with_tools(
|
||||||
&self,
|
&self,
|
||||||
base_url: &str,
|
base_url: &str,
|
||||||
@@ -857,6 +1026,14 @@ impl LlmClient for OllamaClient {
|
|||||||
OllamaClient::chat_with_tools(self, messages, tools).await
|
OllamaClient::chat_with_tools(self, messages, tools).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools_stream(
|
||||||
|
&self,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
tools: Vec<Tool>,
|
||||||
|
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
|
||||||
|
OllamaClient::chat_with_tools_stream(self, messages, tools).await
|
||||||
|
}
|
||||||
|
|
||||||
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
||||||
OllamaClient::generate_embeddings(self, texts).await
|
OllamaClient::generate_embeddings(self, texts).await
|
||||||
}
|
}
|
||||||
@@ -936,6 +1113,35 @@ struct OllamaChatResponse {
|
|||||||
eval_duration: Option<u64>,
|
eval_duration: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// One chunk in the NDJSON stream from `/api/chat` with `stream: true`.
|
||||||
|
/// Early chunks carry content deltas in `message.content`; the final chunk
|
||||||
|
/// has `done: true`, optional `tool_calls`, and usage counters.
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct OllamaStreamChunk {
|
||||||
|
#[serde(default)]
|
||||||
|
message: OllamaStreamMessage,
|
||||||
|
#[serde(default)]
|
||||||
|
done: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
prompt_eval_count: Option<i32>,
|
||||||
|
#[serde(default)]
|
||||||
|
prompt_eval_duration: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
eval_count: Option<i32>,
|
||||||
|
#[serde(default)]
|
||||||
|
eval_duration: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Default)]
|
||||||
|
struct OllamaStreamMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
role: String,
|
||||||
|
#[serde(default)]
|
||||||
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OllamaResponse {
|
struct OllamaResponse {
|
||||||
response: String,
|
response: String,
|
||||||
|
|||||||
@@ -12,8 +12,9 @@ use std::sync::{Arc, Mutex};
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use crate::ai::llm_client::{
|
use crate::ai::llm_client::{
|
||||||
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
|
ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction,
|
||||||
};
|
};
|
||||||
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
|
||||||
const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||||
const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
|
const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small";
|
||||||
@@ -378,6 +379,220 @@ impl LlmClient for OpenRouterClient {
|
|||||||
Ok((chat_msg, prompt_tokens, completion_tokens))
|
Ok((chat_msg, prompt_tokens, completion_tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools_stream(
|
||||||
|
&self,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
tools: Vec<Tool>,
|
||||||
|
) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
let mut body = serde_json::Map::new();
|
||||||
|
body.insert("model".into(), Value::String(self.primary_model.clone()));
|
||||||
|
body.insert(
|
||||||
|
"messages".into(),
|
||||||
|
Value::Array(Self::messages_to_openai(&messages)),
|
||||||
|
);
|
||||||
|
body.insert("stream".into(), Value::Bool(true));
|
||||||
|
// Ask for usage data in the final chunk (OpenAI + OpenRouter
|
||||||
|
// both honor this options bag).
|
||||||
|
body.insert(
|
||||||
|
"stream_options".into(),
|
||||||
|
serde_json::json!({ "include_usage": true }),
|
||||||
|
);
|
||||||
|
if !tools.is_empty() {
|
||||||
|
body.insert(
|
||||||
|
"tools".into(),
|
||||||
|
serde_json::to_value(&tools).context("serializing tools")?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (k, v) in self.build_options() {
|
||||||
|
body.insert(k.into(), v);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.authed(self.client.post(&url))
|
||||||
|
.json(&Value::Object(body))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("POST {} failed", url))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let status = resp.status();
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
bail!("OpenRouter stream request failed: {} — {}", status, body);
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI-compat SSE stream. Each event is `data: <json>\n\n`, with
|
||||||
|
// `data: [DONE]` signalling completion. Tool calls arrive as
|
||||||
|
// `delta.tool_calls[i]` chunks that must be concatenated by index.
|
||||||
|
let byte_stream = resp.bytes_stream();
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
let mut byte_stream = byte_stream;
|
||||||
|
let mut buf: Vec<u8> = Vec::new();
|
||||||
|
let mut accumulated_content = String::new();
|
||||||
|
// tool call state: index -> (id, name, args_string)
|
||||||
|
let mut tool_state: std::collections::BTreeMap<
|
||||||
|
usize,
|
||||||
|
(Option<String>, Option<String>, String),
|
||||||
|
> = std::collections::BTreeMap::new();
|
||||||
|
let mut role = "assistant".to_string();
|
||||||
|
let mut prompt_tokens: Option<i32> = None;
|
||||||
|
let mut completion_tokens: Option<i32> = None;
|
||||||
|
let mut done_seen = false;
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = match chunk {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(anyhow!("stream read failed: {}", e));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
buf.extend_from_slice(&chunk);
|
||||||
|
|
||||||
|
// SSE frames are delimited by a blank line. Walk the buffer
|
||||||
|
// for "\n\n" markers; anything before them is a complete
|
||||||
|
// frame (possibly multi-line).
|
||||||
|
loop {
|
||||||
|
let Some(sep) = find_double_newline(&buf) else { break };
|
||||||
|
let frame = buf.drain(..sep + 2).collect::<Vec<_>>();
|
||||||
|
let frame_str = match std::str::from_utf8(&frame) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
// A frame is one or more lines; the payload is on data:
|
||||||
|
// lines. Ignore comments and other fields.
|
||||||
|
for line in frame_str.lines() {
|
||||||
|
let line = line.trim_end_matches('\r');
|
||||||
|
let payload = match line.strip_prefix("data: ") {
|
||||||
|
Some(p) => p,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
if payload == "[DONE]" {
|
||||||
|
done_seen = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let v: Value = match serde_json::from_str(payload) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!(
|
||||||
|
"malformed OpenRouter SSE frame: {} ({})",
|
||||||
|
payload,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Usage can arrive in a dedicated final frame with
|
||||||
|
// empty choices.
|
||||||
|
if let Some(usage) = v.get("usage") {
|
||||||
|
prompt_tokens = usage
|
||||||
|
.get("prompt_tokens")
|
||||||
|
.and_then(|n| n.as_i64())
|
||||||
|
.map(|n| n as i32);
|
||||||
|
completion_tokens = usage
|
||||||
|
.get("completion_tokens")
|
||||||
|
.and_then(|n| n.as_i64())
|
||||||
|
.map(|n| n as i32);
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(choices) = v.get("choices").and_then(|c| c.as_array())
|
||||||
|
else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let Some(choice) = choices.first() else { continue };
|
||||||
|
let delta = match choice.get("delta") {
|
||||||
|
Some(d) => d,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
if let Some(r) = delta.get("role").and_then(|v| v.as_str()) {
|
||||||
|
role = r.to_string();
|
||||||
|
}
|
||||||
|
if let Some(content) =
|
||||||
|
delta.get("content").and_then(|v| v.as_str())
|
||||||
|
&& !content.is_empty()
|
||||||
|
{
|
||||||
|
accumulated_content.push_str(content);
|
||||||
|
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
|
||||||
|
}
|
||||||
|
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
|
||||||
|
for tc_delta in tcs {
|
||||||
|
let idx = tc_delta
|
||||||
|
.get("index")
|
||||||
|
.and_then(|n| n.as_u64())
|
||||||
|
.unwrap_or(0) as usize;
|
||||||
|
let entry = tool_state
|
||||||
|
.entry(idx)
|
||||||
|
.or_insert((None, None, String::new()));
|
||||||
|
if let Some(id) =
|
||||||
|
tc_delta.get("id").and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
entry.0 = Some(id.to_string());
|
||||||
|
}
|
||||||
|
if let Some(func) = tc_delta.get("function") {
|
||||||
|
if let Some(name) =
|
||||||
|
func.get("name").and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
entry.1 = Some(name.to_string());
|
||||||
|
}
|
||||||
|
if let Some(args) =
|
||||||
|
func.get("arguments").and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
entry.2.push_str(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if done_seen {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if done_seen {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize tool calls: parse accumulated argument strings.
|
||||||
|
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut v = Vec::with_capacity(tool_state.len());
|
||||||
|
for (_idx, (id, name, args)) in tool_state {
|
||||||
|
let arguments: Value = if args.trim().is_empty() {
|
||||||
|
Value::Object(Default::default())
|
||||||
|
} else {
|
||||||
|
serde_json::from_str(&args).unwrap_or_else(|_| {
|
||||||
|
Value::Object(Default::default())
|
||||||
|
})
|
||||||
|
};
|
||||||
|
v.push(ToolCall {
|
||||||
|
id,
|
||||||
|
function: ToolCallFunction {
|
||||||
|
name: name.unwrap_or_default(),
|
||||||
|
arguments,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Some(v)
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = ChatMessage {
|
||||||
|
role,
|
||||||
|
content: accumulated_content,
|
||||||
|
tool_calls,
|
||||||
|
images: None,
|
||||||
|
};
|
||||||
|
yield Ok(LlmStreamEvent::Done {
|
||||||
|
message,
|
||||||
|
prompt_eval_count: prompt_tokens,
|
||||||
|
eval_count: completion_tokens,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
|
||||||
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
||||||
let url = format!("{}/embeddings", self.base_url);
|
let url = format!("{}/embeddings", self.base_url);
|
||||||
let body = json!({
|
let body = json!({
|
||||||
@@ -473,6 +688,28 @@ impl LlmClient for OpenRouterClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Locate the end of the first complete SSE frame in `buf`.
///
/// A frame ends at a blank line, written either as `\n\n` or as `\r\n\r\n`.
/// The returned index `idx` is chosen so that `buf[..idx + 2]` is exactly the
/// frame including its whole separator, which is the range the caller drains:
///
/// * for `\n\n`, `idx` is the position of the first `\n`;
/// * for `\r\n\r\n`, `idx` is the position of the second `\r`, so the
///   trailing `\n` is consumed with the frame instead of being left at the
///   head of the buffer.
///
/// Returns `None` when no complete separator is buffered yet.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    (0..buf.len().saturating_sub(1)).find_map(|i| {
        let rest = &buf[i..];
        if rest.starts_with(b"\n\n") {
            Some(i)
        } else if rest.starts_with(b"\r\n\r\n") {
            Some(i + 2)
        } else {
            None
        }
    })
}
|
||||||
|
|
||||||
/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
|
/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through.
|
||||||
fn image_to_data_url(img: &str) -> String {
|
fn image_to_data_url(img: &str) -> String {
|
||||||
if img.starts_with("data:") {
|
if img.starts_with("data:") {
|
||||||
|
|||||||
@@ -1357,6 +1357,7 @@ fn main() -> std::io::Result<()> {
|
|||||||
.service(ai::get_available_models_handler)
|
.service(ai::get_available_models_handler)
|
||||||
.service(ai::get_openrouter_models_handler)
|
.service(ai::get_openrouter_models_handler)
|
||||||
.service(ai::chat_turn_handler)
|
.service(ai::chat_turn_handler)
|
||||||
|
.service(ai::chat_stream_handler)
|
||||||
.service(ai::chat_history_handler)
|
.service(ai::chat_history_handler)
|
||||||
.service(ai::chat_rewind_handler)
|
.service(ai::chat_rewind_handler)
|
||||||
.service(ai::rate_insight_handler)
|
.service(ai::rate_insight_handler)
|
||||||
|
|||||||
Reference in New Issue
Block a user