Add LlmClient::chat_with_tools_stream and SSE endpoint POST /insights/chat/stream
that emits text deltas, tool_call / tool_result pairs, a truncated notice, and a
terminal done frame as the agentic loop runs.

- Ollama: parses NDJSON from the /api/chat stream, accumulates content deltas,
  emits Done with tool_calls from the final chunk.
- OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call argument
  deltas by index, asks for stream_options.include_usage.
- InsightChatService spawns the loop on a tokio task, feeds events through an
  mpsc channel, persists training_messages at the end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
use anyhow::{Result, anyhow, bail};
use chrono::Utc;
use opentelemetry::KeyValue;
use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as TokioMutex;

use crate::ai::insight_generator::InsightGenerator;
use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent};
use crate::ai::ollama::OllamaClient;
use crate::ai::openrouter::OpenRouterClient;
use crate::database::InsightDao;
use crate::database::models::InsertPhotoInsight;
use crate::otel::global_tracer;
use crate::utils::normalize_path;
use futures::stream::{BoxStream, StreamExt};

const DEFAULT_MAX_ITERATIONS: usize = 6;
const DEFAULT_NUM_CTX: i32 = 8192;
/// Headroom reserved for the model's response, deducted from the context
/// budget when deciding whether to truncate the replayed history.
const RESPONSE_HEADROOM_TOKENS: usize = 2048;
/// Cheap byte-to-token approximation used by the truncation pass. The real
/// tokenization is model-specific; this avoids carrying tiktoken just for a
/// soft bound.
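/// With the defaults this yields a replay budget of
/// (8192 - 2048) * 4 = 24576 bytes (~24 KiB) of serialized JSON.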
const BYTES_PER_TOKEN: usize = 4;

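/// Per-(library_id, file_path) locks that serialise chat turns touching the
/// same insight. Callers hold the outer map lock only long enough to fetch or
/// create the inner entry, then await the inner lock — a sketch of the
/// pattern as used by `chat_turn` below:
///
/// ```ignore
/// let entry = {
///     let mut locks = chat_locks.lock().await;
///     locks
///         .entry((library_id, normalized_path))
///         .or_insert_with(|| Arc::new(TokioMutex::new(())))
///         .clone()
/// };
/// let _guard = entry.lock().await; // held for the rest of the turn
/// ```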
pub type ChatLockMap = Arc<TokioMutex<HashMap<(i32, String), Arc<TokioMutex<()>>>>>;

#[derive(Debug)]
pub struct ChatTurnRequest {
    pub library_id: i32,
    pub file_path: String,
    pub user_message: String,
    /// Override the model id. Local mode: an Ollama model name. Hybrid:
    /// an OpenRouter id. None defers to the stored insight's `model_version`.
    pub model: Option<String>,
    /// Override the backend used for this turn. None defers to the stored
    /// insight's `backend`. Switching `local -> hybrid` is rejected in v1.
    pub backend: Option<String>,
    pub num_ctx: Option<i32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    pub min_p: Option<f32>,
    pub max_iterations: Option<usize>,
    /// When true, write a new insight row (regenerating title) instead of
    /// updating training_messages on the existing row.
    pub amend: bool,
}

#[derive(Debug)]
pub struct ChatTurnResult {
    pub assistant_message: String,
    pub tool_calls_made: usize,
    pub iterations_used: usize,
    pub truncated: bool,
    pub prompt_eval_count: Option<i32>,
    pub eval_count: Option<i32>,
    /// Set when `amend=true` and the new insight row was inserted.
    pub amended_insight_id: Option<i32>,
    /// Backend used for this turn — useful when the client overrode the
    /// stored value.
    pub backend_used: String,
    /// Model identifier the chat backend ran with.
    pub model_used: String,
}

#[derive(Clone)]
pub struct InsightChatService {
    generator: Arc<InsightGenerator>,
    ollama: OllamaClient,
    openrouter: Option<Arc<OpenRouterClient>>,
    insight_dao: Arc<Mutex<Box<dyn InsightDao>>>,
    chat_locks: ChatLockMap,
}

impl InsightChatService {
    pub fn new(
        generator: Arc<InsightGenerator>,
        ollama: OllamaClient,
        openrouter: Option<Arc<OpenRouterClient>>,
        insight_dao: Arc<Mutex<Box<dyn InsightDao>>>,
        chat_locks: ChatLockMap,
    ) -> Self {
        Self {
            generator,
            ollama,
            openrouter,
            insight_dao,
            chat_locks,
        }
    }

    /// Load the rendered transcript for chat-UI display. Filters internal
    /// scaffolding (system message, tool turns, tool-dispatch-only assistant
    /// messages) and drops base64 images from user turns to keep payloads
    /// small. The first remaining user message is flagged `is_initial`.
    pub fn load_history(&self, file_path: &str) -> Result<HistoryView> {
        let normalized = normalize_path(file_path);
        let cx = opentelemetry::Context::new();
        let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
        let insight = dao
            .get_insight(&cx, &normalized)
            .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
            .ok_or_else(|| anyhow!("no insight found for path"))?;

        let raw = insight
            .training_messages
            .as_ref()
            .ok_or_else(|| anyhow!("insight has no chat history (pre-agentic insight)"))?;
        let messages: Vec<ChatMessage> = serde_json::from_str(raw)
            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;

        let mut rendered = Vec::new();
        let mut user_turns_seen = 0usize;
        let mut assistant_turns_seen = 0usize;

        // Accumulate tool invocations seen since the last user turn. An
        // invocation is: one assistant tool_call message (which may hold
        // multiple calls) + the N following tool-role messages (one per call,
        // in order). They attach to the next assistant-with-content, which
        // is the "final" reply for the current turn.
        //
        // Wire shape from the model:
        //   assistant { tool_calls: [A, B], content: "" }
        //   tool      { content: "result of A" }
        //   tool      { content: "result of B" }
        //   assistant { content: "here's the answer" }   ← rendered as final
        let mut pending_tools: Vec<ToolInvocation> = Vec::new();
        // Queue of (name, arguments) awaiting a tool_result to pair with.
        let mut pending_calls: std::collections::VecDeque<(String, serde_json::Value)> =
            std::collections::VecDeque::new();

        for msg in &messages {
            match msg.role.as_str() {
                "system" => continue,
                "tool" => {
                    if let Some((name, arguments)) = pending_calls.pop_front() {
                        let (result, result_truncated) = truncate_tool_result(&msg.content);
                        pending_tools.push(ToolInvocation {
                            name,
                            arguments,
                            result,
                            result_truncated,
                        });
                    }
                    // If there's no pending call, the tool message is an
                    // orphan (shouldn't happen in practice) — skip silently.
                }
                "assistant" => {
                    let has_tool_calls = msg
                        .tool_calls
                        .as_ref()
                        .map(|c| !c.is_empty())
                        .unwrap_or(false);
                    if has_tool_calls && msg.content.trim().is_empty() {
                        // Tool-dispatch turn: enqueue calls, wait for tool
                        // results on subsequent messages.
                        if let Some(ref tcs) = msg.tool_calls {
                            for tc in tcs {
                                pending_calls.push_back((
                                    tc.function.name.clone(),
                                    tc.function.arguments.clone(),
                                ));
                            }
                        }
                        continue;
                    }
                    // Final assistant reply for this turn — drain accumulated
                    // tools into it.
                    assistant_turns_seen += 1;
                    let tools = std::mem::take(&mut pending_tools);
                    pending_calls.clear(); // any leftover unpaired calls are dropped
                    rendered.push(RenderedMessage {
                        role: "assistant".to_string(),
                        content: msg.content.clone(),
                        is_initial: false,
                        tools,
                    });
                }
                "user" => {
                    let is_initial = user_turns_seen == 0;
                    user_turns_seen += 1;
                    // New user turn resets any in-flight tool state.
                    pending_tools.clear();
                    pending_calls.clear();
                    rendered.push(RenderedMessage {
                        role: "user".to_string(),
                        content: msg.content.clone(),
                        is_initial,
                        tools: Vec::new(),
                    });
                }
                _ => continue,
            }
        }

        Ok(HistoryView {
            messages: rendered,
            turn_count: assistant_turns_seen,
            model_version: insight.model_version,
            backend: insight.backend,
        })
    }

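    /// Run a single non-streaming chat turn: serialise on the per-file lock,
    /// replay the stored history, drive the agentic tool loop, and persist
    /// the updated transcript. An illustrative call (field values below are
    /// made up for the example; the `None`s defer to stored defaults):
    ///
    /// ```ignore
    /// let result = svc
    ///     .chat_turn(ChatTurnRequest {
    ///         library_id: 1,
    ///         file_path: "/photos/2024/beach.jpg".into(),
    ///         user_message: "Who else was in this photo?".into(),
    ///         model: None,
    ///         backend: None,
    ///         num_ctx: None,
    ///         temperature: None,
    ///         top_p: None,
    ///         top_k: None,
    ///         min_p: None,
    ///         max_iterations: None,
    ///         amend: false,
    ///     })
    ///     .await?;
    /// println!("{}", result.assistant_message);
    /// ```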
    pub async fn chat_turn(&self, req: ChatTurnRequest) -> Result<ChatTurnResult> {
        let tracer = global_tracer();
        let parent_cx = opentelemetry::Context::new();
        let mut span = tracer.start_with_context("ai.insight.chat_turn", &parent_cx);
        span.set_attribute(KeyValue::new("file_path", req.file_path.clone()));
        span.set_attribute(KeyValue::new("library_id", req.library_id as i64));
        span.set_attribute(KeyValue::new("amend", req.amend));

        if req.user_message.trim().is_empty() {
            bail!("user_message must not be empty");
        }
        if req.user_message.len() > 8192 {
            bail!("user_message exceeds 8192 chars");
        }

        let normalized = normalize_path(&req.file_path);

        // 1. Acquire the per-(library, file) async mutex. Two concurrent
        //    chat turns on the same insight would race on the JSON blob —
        //    the lock serialises them.
        let lock_key = (req.library_id, normalized.clone());
        let entry_lock = {
            let mut locks = self.chat_locks.lock().await;
            locks
                .entry(lock_key.clone())
                .or_insert_with(|| Arc::new(TokioMutex::new(())))
                .clone()
        };
        let _guard = entry_lock.lock().await;

        // 2. Load the current insight + history.
        let insight = {
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            dao.get_insight(&cx, &normalized)
                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
                .ok_or_else(|| anyhow!("no insight found for path"))?
        };
        let raw_history = insight
            .training_messages
            .as_ref()
            .ok_or_else(|| {
                anyhow!("insight has no chat history; regenerate this insight in agentic mode")
            })?
            .clone();
        let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;

        // 3. Resolve effective backend. Reject the unsupported switch.
        let stored_backend = insight.backend.clone();
        let effective_backend = req
            .backend
            .as_deref()
            .map(|s| s.trim().to_lowercase())
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| stored_backend.clone());
        if !matches!(effective_backend.as_str(), "local" | "hybrid") {
            bail!(
                "unknown backend '{}'; expected 'local' or 'hybrid'",
                effective_backend
            );
        }
        if stored_backend == "local" && effective_backend == "hybrid" {
            bail!(
                "switching from local to hybrid mid-chat isn't supported yet; \
                 regenerate the insight in hybrid mode if you want OpenRouter chat"
            );
        }
        let is_hybrid = effective_backend == "hybrid";
        span.set_attribute(KeyValue::new("backend", effective_backend.clone()));

        // 4. Build the chat backend client. Ollama in local mode, a freshly
        //    cloned OpenRouter client in hybrid mode (clone so per-request
        //    sampling/model overrides don't leak into shared state).
        let max_iterations = req
            .max_iterations
            .unwrap_or(DEFAULT_MAX_ITERATIONS)
            .clamp(1, env_max_iterations());
        span.set_attribute(KeyValue::new("max_iterations", max_iterations as i64));

        let stored_model = insight.model_version.clone();
        let custom_model = req
            .model
            .clone()
            .or_else(|| Some(stored_model.clone()))
            .filter(|m| !m.is_empty());

        let mut ollama_client = self.ollama.clone();
        let mut openrouter_client: Option<OpenRouterClient> = None;

        if is_hybrid {
            let arc = self.openrouter.as_ref().ok_or_else(|| {
                anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
            })?;
            let mut c: OpenRouterClient = (**arc).clone();
            if let Some(ref m) = custom_model {
                c.primary_model = m.clone();
            }
            if req.temperature.is_some()
                || req.top_p.is_some()
                || req.top_k.is_some()
                || req.min_p.is_some()
            {
                c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
            }
            if let Some(ctx) = req.num_ctx {
                c.set_num_ctx(Some(ctx));
            }
            openrouter_client = Some(c);
        } else {
            // Local-mode model swap. Build a new client when the chat model
            // differs from the configured one (mirrors the agentic pattern).
            if let Some(ref m) = custom_model
                && m != &self.ollama.primary_model
            {
                ollama_client = OllamaClient::new(
                    self.ollama.primary_url.clone(),
                    self.ollama.fallback_url.clone(),
                    m.clone(),
                    Some(m.clone()),
                );
            }
            if req.temperature.is_some()
                || req.top_p.is_some()
                || req.top_k.is_some()
                || req.min_p.is_some()
            {
                ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
            }
            if let Some(ctx) = req.num_ctx {
                ollama_client.set_num_ctx(Some(ctx));
            }
        }

        let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
            c
        } else {
            &ollama_client
        };
        let model_used = chat_backend.primary_model().to_string();
        span.set_attribute(KeyValue::new("model", model_used.clone()));

        // 5. Decide vision + tool set. In hybrid we always omit
        //    `describe_photo` (matches the original generation flow). In
        //    local we trust the stored history's first-user shape: if it
        //    carries `images`, the original model was vision-capable, and
        //    we keep `describe_photo` available.
        let local_first_user_has_image = messages
            .iter()
            .find(|m| m.role == "user")
            .and_then(|m| m.images.as_ref())
            .map(|imgs| !imgs.is_empty())
            .unwrap_or(false);
        let offer_describe_tool = !is_hybrid && local_first_user_has_image;
        let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);

        // Image base64 only needed when describe_photo is on the menu. Load
        // lazily to avoid disk IO when the loop never invokes it.
        let image_base64: Option<String> = if offer_describe_tool {
            self.generator.load_image_as_base64(&normalized).ok()
        } else {
            None
        };

        // 6. Apply truncation budget. Drops oldest tool_call+tool pairs
        //    (preserves system + first user including any images).
        let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
            .saturating_sub(RESPONSE_HEADROOM_TOKENS);
        let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
        let truncated = apply_context_budget(&mut messages, budget_bytes);
        if truncated {
            span.set_attribute(KeyValue::new("history_truncated", true));
        }

        // 7. Append the new user turn.
        messages.push(ChatMessage::user(req.user_message.clone()));

        let insight_cx = parent_cx.with_span(span);

        // 8. Agentic loop — same shape as insight_generator's, but capped
        //    tighter and dispatching tools through the shared executor.
        let loop_span = tracer.start_with_context("ai.chat.loop", &insight_cx);
        let loop_cx = insight_cx.with_span(loop_span);
        let mut tool_calls_made = 0usize;
        let mut iterations_used = 0usize;
        let mut last_prompt_eval_count: Option<i32> = None;
        let mut last_eval_count: Option<i32> = None;
        let mut final_content = String::new();

        for iteration in 0..max_iterations {
            iterations_used = iteration + 1;
            log::info!("Chat iteration {}/{}", iterations_used, max_iterations);

            let (response, prompt_tokens, eval_tokens) = chat_backend
                .chat_with_tools(messages.clone(), tools.clone())
                .await?;
            last_prompt_eval_count = prompt_tokens;
            last_eval_count = eval_tokens;

            // Ollama rejects non-object tool-call arguments on replay.
            let mut response = response;
            if let Some(ref mut tcs) = response.tool_calls {
                for tc in tcs.iter_mut() {
                    if !tc.function.arguments.is_object() {
                        tc.function.arguments = serde_json::Value::Object(Default::default());
                    }
                }
            }

            messages.push(response.clone());

            if let Some(ref tool_calls) = response.tool_calls
                && !tool_calls.is_empty()
            {
                for tool_call in tool_calls {
                    tool_calls_made += 1;
                    log::info!(
                        "Chat tool call [{}]: {} {:?}",
                        iteration,
                        tool_call.function.name,
                        tool_call.function.arguments
                    );
                    let result = self
                        .generator
                        .execute_tool(
                            &tool_call.function.name,
                            &tool_call.function.arguments,
                            &ollama_client,
                            &image_base64,
                            &normalized,
                            &loop_cx,
                        )
                        .await;
                    messages.push(ChatMessage::tool_result(result));
                }
                continue;
            }

            final_content = response.content;
            break;
        }

        if final_content.is_empty() {
            // The model never produced a final answer; ask once more without
            // tools to force a textual reply.
            log::info!(
                "Chat loop exhausted after {} iterations, requesting final answer",
                iterations_used
            );
            messages.push(ChatMessage::user(
                "Please write your final answer now without calling any more tools.",
            ));
            let (final_response, prompt_tokens, eval_tokens) = chat_backend
                .chat_with_tools(messages.clone(), vec![])
                .await?;
            last_prompt_eval_count = prompt_tokens;
            last_eval_count = eval_tokens;
            final_content = final_response.content.clone();
            messages.push(final_response);
        }

        loop_cx.span().set_status(Status::Ok);

        // 9. Persist. Append mode rewrites the JSON blob in place; amend
        //    mode regenerates the title and inserts a new insight row,
        //    relying on store_insight to flip prior rows' is_current=false.
        let json = serde_json::to_string(&messages)
            .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;

        let mut amended_insight_id: Option<i32> = None;
        if req.amend {
            let title_prompt = format!(
                "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
                 Capture the key moment or theme. Return ONLY the title, nothing else.",
                final_content
            );
            let title_raw = chat_backend
                .generate(
                    &title_prompt,
                    Some(
                        "You are my long term memory assistant. Use only the information provided. Do not invent details.",
                    ),
                    None,
                )
                .await?;
            let title = title_raw.trim().trim_matches('"').to_string();

            let new_row = InsertPhotoInsight {
                library_id: req.library_id,
                file_path: normalized.clone(),
                title,
                summary: final_content.clone(),
                generated_at: Utc::now().timestamp(),
                model_version: model_used.clone(),
                is_current: true,
                training_messages: Some(json),
                backend: effective_backend.clone(),
            };
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            let stored = dao
                .store_insight(&cx, new_row)
                .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
            amended_insight_id = Some(stored.id);
        } else {
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            dao.update_training_messages(&cx, req.library_id, &normalized, &json)
                .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
        }

        Ok(ChatTurnResult {
            assistant_message: final_content,
            tool_calls_made,
            iterations_used,
            truncated,
            prompt_eval_count: last_prompt_eval_count,
            eval_count: last_eval_count,
            amended_insight_id,
            backend_used: effective_backend,
            model_used,
        })
    }

    /// Truncate the stored conversation so the rendered message at
    /// `discard_from_rendered_index` (and everything after it — including
    /// the tool-call scaffolding that produced a discarded assistant reply)
    /// is removed. The initial user turn cannot be discarded; attempting to
    /// do so returns an error.
    ///
    /// Holds the per-file chat mutex so it serialises with `chat_turn`.
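    ///
    /// For example, with a rendered view of `[user, assistant, user, assistant]`,
    /// index 2 discards the second user turn and the final reply, leaving the
    /// first question-and-answer pair in place.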
    pub async fn rewind_history(
        &self,
        library_id: i32,
        file_path: &str,
        discard_from_rendered_index: usize,
    ) -> Result<()> {
        if discard_from_rendered_index == 0 {
            bail!("cannot discard the initial user message");
        }
        let normalized = normalize_path(file_path);

        let lock_key = (library_id, normalized.clone());
        let entry_lock = {
            let mut locks = self.chat_locks.lock().await;
            locks
                .entry(lock_key.clone())
                .or_insert_with(|| Arc::new(TokioMutex::new(())))
                .clone()
        };
        let _guard = entry_lock.lock().await;

        let insight = {
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            dao.get_insight(&cx, &normalized)
                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
                .ok_or_else(|| anyhow!("no insight found for path"))?
        };
        let raw_history = insight
            .training_messages
            .as_ref()
            .ok_or_else(|| anyhow!("insight has no chat history"))?;
        let messages: Vec<ChatMessage> = serde_json::from_str(raw_history)
            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;

        let cut_at = find_raw_cut(&messages, discard_from_rendered_index)
            .ok_or_else(|| anyhow!("discard_from_rendered_index out of range"))?;

        let truncated = &messages[..cut_at];
        let json = serde_json::to_string(truncated)
            .map_err(|e| anyhow!("failed to serialize truncated history: {}", e))?;

        let cx = opentelemetry::Context::new();
        let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
        dao.update_training_messages(&cx, library_id, &normalized, &json)
            .map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?;
        Ok(())
    }

    /// Streaming variant of `chat_turn`. Emits user-facing events as the
    /// conversation progresses: iteration starts, tool dispatch + result,
    /// text deltas from the final assistant reply, and a terminal `Done`
    /// frame. Persistence happens inside the stream after the loop ends.
    ///
    /// The stream takes ownership of the service via `Arc<Self>` (passed by
    /// the caller) so it can live past the handler's await boundary.
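    ///
    /// A minimal consumption sketch (assumes `svc: Arc<InsightChatService>`
    /// and a populated `req`):
    ///
    /// ```ignore
    /// let mut stream = svc.chat_turn_stream(req);
    /// while let Some(ev) = stream.next().await {
    ///     match ev {
    ///         ChatStreamEvent::TextDelta(t) => print!("{t}"),
    ///         ChatStreamEvent::Done { .. } | ChatStreamEvent::Error(_) => break,
    ///         _ => {} // IterationStart / ToolCall / ToolResult / Truncated
    ///     }
    /// }
    /// ```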
    pub fn chat_turn_stream(
        self: Arc<Self>,
        req: ChatTurnRequest,
    ) -> BoxStream<'static, ChatStreamEvent> {
        let svc = self;
        let s = async_stream::stream! {
            match svc.chat_turn_stream_inner(req, |ev| Ok(ev)).await {
                Ok(mut rx) => {
                    while let Some(ev) = rx.recv().await {
                        yield ev;
                    }
                }
                Err(e) => {
                    yield ChatStreamEvent::Error(format!("{}", e));
                }
            }
        };
        Box::pin(s)
    }

    /// Internal: drives the streaming loop on a background task, returning
    /// a receiver the caller drains. Keeping the work on a spawned task
    /// decouples the HTTP request lifetime from the chat execution, which
    /// matters because the chat may run longer than any single network hop
    /// and we want clean cancellation semantics via the channel close.
    async fn chat_turn_stream_inner<F>(
        self: Arc<Self>,
        req: ChatTurnRequest,
        _ev_mapper: F,
    ) -> Result<tokio::sync::mpsc::Receiver<ChatStreamEvent>>
    where
        F: Fn(ChatStreamEvent) -> Result<ChatStreamEvent> + Send + 'static,
    {
        let (tx, rx) = tokio::sync::mpsc::channel::<ChatStreamEvent>(64);
        let svc = self.clone();
        tokio::spawn(async move {
            let result = svc.run_streaming_turn(req, tx.clone()).await;
            if let Err(e) = result {
                let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await;
            }
        });
        Ok(rx)
    }

    async fn run_streaming_turn(
        self: Arc<Self>,
        req: ChatTurnRequest,
        tx: tokio::sync::mpsc::Sender<ChatStreamEvent>,
    ) -> Result<()> {
        if req.user_message.trim().is_empty() {
            bail!("user_message must not be empty");
        }
        if req.user_message.len() > 8192 {
            bail!("user_message exceeds 8192 chars");
        }
        let normalized = normalize_path(&req.file_path);

        let lock_key = (req.library_id, normalized.clone());
        let entry_lock = {
            let mut locks = self.chat_locks.lock().await;
            locks
                .entry(lock_key.clone())
                .or_insert_with(|| Arc::new(TokioMutex::new(())))
                .clone()
        };
        let _guard = entry_lock.lock().await;

        let insight = {
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            dao.get_insight(&cx, &normalized)
                .map_err(|e| anyhow!("failed to load insight: {:?}", e))?
                .ok_or_else(|| anyhow!("no insight found for path"))?
        };
        let raw_history = insight
            .training_messages
            .as_ref()
            .ok_or_else(|| {
                anyhow!("insight has no chat history; regenerate this insight in agentic mode")
            })?
            .clone();
        let mut messages: Vec<ChatMessage> = serde_json::from_str(&raw_history)
            .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?;

        // Backend selection — same rules as non-streaming chat_turn.
        let stored_backend = insight.backend.clone();
        let effective_backend = req
            .backend
            .as_deref()
            .map(|s| s.trim().to_lowercase())
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| stored_backend.clone());
        if !matches!(effective_backend.as_str(), "local" | "hybrid") {
            bail!(
                "unknown backend '{}'; expected 'local' or 'hybrid'",
                effective_backend
            );
        }
        if stored_backend == "local" && effective_backend == "hybrid" {
            bail!(
                "switching from local to hybrid mid-chat isn't supported yet; \
                 regenerate the insight in hybrid mode if you want OpenRouter chat"
            );
        }
        let is_hybrid = effective_backend == "hybrid";

        let max_iterations = req
            .max_iterations
            .unwrap_or(DEFAULT_MAX_ITERATIONS)
            .clamp(1, env_max_iterations());

        let stored_model = insight.model_version.clone();
        let custom_model = req
            .model
            .clone()
            .or_else(|| Some(stored_model.clone()))
            .filter(|m| !m.is_empty());

        let mut ollama_client = self.ollama.clone();
        let mut openrouter_client: Option<OpenRouterClient> = None;

        if is_hybrid {
            let arc = self.openrouter.as_ref().ok_or_else(|| {
                anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured")
            })?;
            let mut c: OpenRouterClient = (**arc).clone();
            if let Some(ref m) = custom_model {
                c.primary_model = m.clone();
            }
            if req.temperature.is_some()
                || req.top_p.is_some()
                || req.top_k.is_some()
                || req.min_p.is_some()
            {
                c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
            }
            if let Some(ctx) = req.num_ctx {
                c.set_num_ctx(Some(ctx));
            }
            openrouter_client = Some(c);
        } else {
            if let Some(ref m) = custom_model
                && m != &self.ollama.primary_model
            {
                ollama_client = OllamaClient::new(
                    self.ollama.primary_url.clone(),
                    self.ollama.fallback_url.clone(),
                    m.clone(),
                    Some(m.clone()),
                );
            }
            if req.temperature.is_some()
                || req.top_p.is_some()
                || req.top_k.is_some()
                || req.min_p.is_some()
            {
                ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p);
            }
            if let Some(ctx) = req.num_ctx {
                ollama_client.set_num_ctx(Some(ctx));
            }
        }

        let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client {
            c
        } else {
            &ollama_client
        };
        let model_used = chat_backend.primary_model().to_string();

        // Tool set.
        let local_first_user_has_image = messages
            .iter()
            .find(|m| m.role == "user")
            .and_then(|m| m.images.as_ref())
            .map(|imgs| !imgs.is_empty())
            .unwrap_or(false);
        let offer_describe_tool = !is_hybrid && local_first_user_has_image;
        let tools = InsightGenerator::build_tool_definitions(offer_describe_tool);

        let image_base64: Option<String> = if offer_describe_tool {
            self.generator.load_image_as_base64(&normalized).ok()
        } else {
            None
        };

        // Truncate before appending the new user turn.
        let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize)
            .saturating_sub(RESPONSE_HEADROOM_TOKENS);
        let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN);
        let truncated = apply_context_budget(&mut messages, budget_bytes);
        if truncated {
            let _ = tx.send(ChatStreamEvent::Truncated).await;
        }

        messages.push(ChatMessage::user(req.user_message.clone()));

        let mut tool_calls_made = 0usize;
        let mut iterations_used = 0usize;
        let mut last_prompt_eval_count: Option<i32> = None;
        let mut last_eval_count: Option<i32> = None;
        let mut final_content = String::new();

        for iteration in 0..max_iterations {
            iterations_used = iteration + 1;
            let _ = tx
                .send(ChatStreamEvent::IterationStart {
                    n: iterations_used,
                    max: max_iterations,
                })
                .await;

            let mut stream = chat_backend
                .chat_with_tools_stream(messages.clone(), tools.clone())
                .await?;

            let mut final_message: Option<ChatMessage> = None;
            while let Some(ev) = stream.next().await {
                let ev = ev?;
                match ev {
                    LlmStreamEvent::TextDelta(delta) => {
                        let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
                    }
                    LlmStreamEvent::Done {
                        message,
                        prompt_eval_count,
                        eval_count,
                    } => {
                        last_prompt_eval_count = prompt_eval_count;
                        last_eval_count = eval_count;
                        final_message = Some(message);
                        break;
                    }
                }
            }
            let mut response =
                final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?;

            // Normalize non-object tool arguments (same as non-streaming path).
            if let Some(ref mut tcs) = response.tool_calls {
                for tc in tcs.iter_mut() {
                    if !tc.function.arguments.is_object() {
                        tc.function.arguments = serde_json::Value::Object(Default::default());
                    }
                }
            }

            messages.push(response.clone());

            if let Some(ref tool_calls) = response.tool_calls
                && !tool_calls.is_empty()
            {
                for tool_call in tool_calls {
                    tool_calls_made += 1;
                    let call_index = tool_calls_made - 1;
                    let _ = tx
                        .send(ChatStreamEvent::ToolCall {
                            index: call_index,
                            name: tool_call.function.name.clone(),
                            arguments: tool_call.function.arguments.clone(),
                        })
                        .await;
                    let cx = opentelemetry::Context::new();
                    let result = self
                        .generator
                        .execute_tool(
                            &tool_call.function.name,
                            &tool_call.function.arguments,
                            &ollama_client,
                            &image_base64,
                            &normalized,
                            &cx,
                        )
                        .await;
                    let (result_preview, result_truncated) = truncate_tool_result(&result);
                    let _ = tx
                        .send(ChatStreamEvent::ToolResult {
                            index: call_index,
                            name: tool_call.function.name.clone(),
                            result: result_preview,
                            result_truncated,
                        })
                        .await;
                    messages.push(ChatMessage::tool_result(result));
                }
                continue;
            }

            final_content = response.content;
            break;
        }

        if final_content.is_empty() {
            messages.push(ChatMessage::user(
                "Please write your final answer now without calling any more tools.",
            ));
            let mut stream = chat_backend
                .chat_with_tools_stream(messages.clone(), vec![])
                .await?;
            let mut final_message: Option<ChatMessage> = None;
            while let Some(ev) = stream.next().await {
                let ev = ev?;
                match ev {
                    LlmStreamEvent::TextDelta(delta) => {
                        let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await;
                    }
                    LlmStreamEvent::Done {
                        message,
                        prompt_eval_count,
                        eval_count,
                    } => {
                        last_prompt_eval_count = prompt_eval_count;
                        last_eval_count = eval_count;
                        final_message = Some(message);
                        break;
                    }
                }
            }
            let final_response =
                final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?;
            final_content = final_response.content.clone();
            messages.push(final_response);
        }

        // Persist.
        let json = serde_json::to_string(&messages)
            .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?;

        let mut amended_insight_id: Option<i32> = None;
        if req.amend {
            let title_prompt = format!(
                "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\
                 Capture the key moment or theme. Return ONLY the title, nothing else.",
                final_content
            );
            let title_raw = chat_backend
                .generate(
                    &title_prompt,
                    Some(
                        "You are my long term memory assistant. Use only the information provided. Do not invent details.",
                    ),
                    None,
                )
                .await?;
            let title = title_raw.trim().trim_matches('"').to_string();

            let new_row = InsertPhotoInsight {
                library_id: req.library_id,
                file_path: normalized.clone(),
                title,
                summary: final_content.clone(),
                generated_at: Utc::now().timestamp(),
                model_version: model_used.clone(),
                is_current: true,
                training_messages: Some(json),
                backend: effective_backend.clone(),
            };
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            let stored = dao
                .store_insight(&cx, new_row)
                .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?;
            amended_insight_id = Some(stored.id);
        } else {
            let cx = opentelemetry::Context::new();
            let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao");
            dao.update_training_messages(&cx, req.library_id, &normalized, &json)
                .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?;
        }

        let _ = tx
            .send(ChatStreamEvent::Done {
                tool_calls_made,
                iterations_used,
                truncated,
                prompt_eval_count: last_prompt_eval_count,
                eval_count: last_eval_count,
                amended_insight_id,
                backend_used: effective_backend,
                model_used,
            })
            .await;

        Ok(())
    }
}

/// Events emitted by `chat_turn_stream`. One stream per turn; ends after
/// `Done` or `Error`.
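///
/// Typical ordering for a turn with one tool call (a sketch of the loop
/// above, not a guaranteed schema):
/// `IterationStart → ToolCall → ToolResult → IterationStart → TextDelta* → Done`,
/// with `Truncated` appearing at most once before any of these.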
#[derive(Debug, Clone)]
pub enum ChatStreamEvent {
    /// Starting iteration `n` of up to `max` (1-based).
    IterationStart { n: usize, max: usize },
    /// History was trimmed to fit the context budget before the turn ran.
    /// Emitted at most once, before any tool or text events.
    Truncated,
    /// Incremental content from the final assistant reply. Concatenate to
    /// reconstruct the reply body. Tool-dispatch turns don't produce these.
    TextDelta(String),
    /// The model requested this tool call. Emitted just before execution.
    /// `index` is a monotonically-increasing counter across the turn so the
    /// client can pair `ToolCall` with its matching `ToolResult`.
    ToolCall {
        index: usize,
        name: String,
        arguments: serde_json::Value,
    },
    /// The tool finished; `result` is the (possibly truncated) output.
    ToolResult {
        index: usize,
        name: String,
        result: String,
        result_truncated: bool,
    },
    /// Terminal success event with counters + persistence result.
    Done {
        tool_calls_made: usize,
        iterations_used: usize,
        truncated: bool,
        prompt_eval_count: Option<i32>,
        eval_count: Option<i32>,
        amended_insight_id: Option<i32>,
        backend_used: String,
        model_used: String,
    },
    /// Terminal failure event. No further events follow.
    Error(String),
}

/// Is this raw message visible in the rendered transcript? Must match
/// `load_history`'s filter exactly — `find_raw_cut` depends on it to map
/// rendered indices back to raw positions.
fn is_rendered(m: &ChatMessage) -> bool {
    match m.role.as_str() {
        "user" => true,
        "assistant" => {
            let has_tool_calls = m
                .tool_calls
                .as_ref()
                .map(|c| !c.is_empty())
                .unwrap_or(false);
            !(has_tool_calls && m.content.trim().is_empty())
        }
        _ => false,
    }
}

/// Given a rendered index to start discarding from, find the raw index at
/// which to truncate. The cut position is the raw length after all prior
/// rendered messages — which also strips any tool-call scaffolding that
/// immediately precedes the discarded rendered message. Returns `None` if
/// `discard_from_rendered_index` is past the end of the rendered view.
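///
/// For example, with raw messages `[system, user, assistant(tool_call), tool,
/// assistant]` the rendered view is `[user, assistant]`, and
/// `find_raw_cut(msgs, 1)` returns `2`: the tool-call scaffolding is cut
/// together with the discarded assistant reply.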
pub(crate) fn find_raw_cut(
    messages: &[ChatMessage],
    discard_from_rendered_index: usize,
) -> Option<usize> {
    let mut rendered_count = 0usize;
    let mut last_kept_raw_end = 0usize;
    for (i, m) in messages.iter().enumerate() {
        if !is_rendered(m) {
            continue;
        }
        if rendered_count == discard_from_rendered_index {
            return Some(last_kept_raw_end);
        }
        rendered_count += 1;
        last_kept_raw_end = i + 1;
    }
    // The index is at or past the end of the rendered view, so there is
    // nothing to cut — surfaced as None rather than silent success.
    None
}

/// Read AGENTIC_CHAT_MAX_ITERATIONS on every call. It's cheap, keeps the code
/// free of static globals, and lets the operator change the cap via the
/// environment without a restart in test harnesses (a long-running server may
/// still see a cached value through its Default-built config).
fn env_max_iterations() -> usize {
    std::env::var("AGENTIC_CHAT_MAX_ITERATIONS")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(DEFAULT_MAX_ITERATIONS)
        .max(1)
}

/// View returned to clients for chat-UI rendering.
#[derive(Debug)]
pub struct HistoryView {
    pub messages: Vec<RenderedMessage>,
    pub turn_count: usize,
    pub model_version: String,
    pub backend: String,
}

#[derive(Debug)]
pub struct RenderedMessage {
    pub role: String,
    pub content: String,
    pub is_initial: bool,
    /// Tools invoked during this turn (only populated for assistant replies).
    /// Empty for user messages and for assistant replies that didn't involve
    /// tool calls.
    pub tools: Vec<ToolInvocation>,
}

#[derive(Debug, Clone)]
pub struct ToolInvocation {
    pub name: String,
    pub arguments: serde_json::Value,
    pub result: String,
    /// True when `result` was trimmed for payload size. Full value remains
    /// available in the raw training_messages blob.
    pub result_truncated: bool,
}

/// Soft cap for tool-result bodies returned via the history API. Keeps
/// payloads small for the mobile client — verbose SMS / geocoding responses
/// don't need to ship in full for inspection.
const TOOL_RESULT_PREVIEW_MAX: usize = 2000;

fn truncate_tool_result(s: &str) -> (String, bool) {
    if s.len() <= TOOL_RESULT_PREVIEW_MAX {
        (s.to_string(), false)
    } else {
        // Cut on a char boundary.
        let mut cut = TOOL_RESULT_PREVIEW_MAX;
        while !s.is_char_boundary(cut) && cut > 0 {
            cut -= 1;
        }
        (s[..cut].to_string(), true)
    }
}

/// Trim history to fit within `budget_bytes` of serialized JSON. Preserves
/// the system message and the first user message (with its base64 images
/// intact, since dropping those would invalidate the model's prior visual
/// reasoning). Drops the oldest assistant-tool_call + corresponding
/// tool-result pair on each pass until the budget is met or only the
/// preserved prefix remains.
///
/// Returns true when at least one message was dropped.
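///
/// Sketch of the effect (roles only): `[system, user, asst(tool), tool,
/// asst(tool), tool, asst]` becomes `[system, user, asst(tool), tool, asst]`
/// after one pass, then `[system, user, asst]` if the budget is still blown.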
pub(crate) fn apply_context_budget(messages: &mut Vec<ChatMessage>, budget_bytes: usize) -> bool {
    if budget_bytes == 0 {
        return false;
    }
    if estimate_bytes(messages) <= budget_bytes {
        return false;
    }

    // Find the index past the protected prefix: system messages + the first
    // user message. Everything after is droppable in pairs.
    let first_user_idx = messages.iter().position(|m| m.role == "user");
    let preserve_through = match first_user_idx {
        Some(i) => i, // keep [0..=i]
        None => return false,
    };

    let mut dropped_any = false;
    loop {
        if estimate_bytes(messages) <= budget_bytes {
            break;
        }
        // Find the oldest assistant-with-tool_calls strictly after the
        // preserved prefix. Drop it together with the following tool turn(s)
        // until we hit the next assistant or user turn.
        let drop_start = (preserve_through + 1..messages.len()).find(|&i| {
            let m = &messages[i];
            m.role == "assistant"
                && m.tool_calls
                    .as_ref()
                    .map(|c| !c.is_empty())
                    .unwrap_or(false)
        });
        let Some(start) = drop_start else { break };
        // Determine end: drop the assistant turn plus any contiguous tool
        // result turns that follow.
        let mut end = start + 1;
        while end < messages.len() && messages[end].role == "tool" {
            end += 1;
        }
        // `end` never exceeds messages.len() (the while-guard bounds it), so
        // the drain below is always in range. User turns are never dropped:
        // only assistant-tool_call + tool spans are.
        messages.drain(start..end);
        dropped_any = true;
    }

    dropped_any
}

fn estimate_bytes(messages: &[ChatMessage]) -> usize {
    serde_json::to_string(messages)
        .map(|s| s.len())
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::llm_client::{ToolCall, ToolCallFunction};

    fn assistant_with_tool_call(name: &str) -> ChatMessage {
        ChatMessage {
            role: "assistant".to_string(),
            content: String::new(),
            tool_calls: Some(vec![ToolCall {
                id: None,
                function: ToolCallFunction {
                    name: name.to_string(),
                    arguments: serde_json::Value::Object(Default::default()),
                },
            }]),
            images: None,
        }
    }

    fn assistant_text(text: &str) -> ChatMessage {
        ChatMessage {
            role: "assistant".to_string(),
            content: text.to_string(),
            tool_calls: None,
            images: None,
        }
    }

    #[test]
    fn truncation_preserves_system_and_first_user() {
        let mut msgs = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("first user with lots of context".repeat(50)),
            assistant_with_tool_call("get_x"),
            ChatMessage::tool_result("x result ".repeat(200)),
            assistant_with_tool_call("get_y"),
            ChatMessage::tool_result("y result ".repeat(200)),
            assistant_text("final answer"),
        ];
        let original_len = msgs.len();
        let dropped = apply_context_budget(&mut msgs, 500);
        assert!(dropped, "should drop something at this small budget");
        assert!(msgs.len() < original_len);
        // First two messages preserved.
        assert_eq!(msgs[0].role, "system");
        assert_eq!(msgs[1].role, "user");
    }

    #[test]
    fn truncation_no_op_when_under_budget() {
        let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")];
        let dropped = apply_context_budget(&mut msgs, 1_000_000);
        assert!(!dropped);
        assert_eq!(msgs.len(), 2);
    }

    #[test]
    fn truncation_returns_false_with_no_droppable_pairs() {
        // Only system + user, no tool-call turns to drop.
        let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")];
        let dropped = apply_context_budget(&mut msgs, 1);
        assert!(!dropped);
    }

    #[test]
    fn rewind_strips_assistant_and_tool_scaffolding() {
        // Rendered: [user1, asst1, user2, asst2] → cut at rendered index 3
        // (the final asst2) should drop the tool-call scaffolding + asst2,
        // leaving raw up through user2.
        let msgs = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("q1"),
            assistant_text("a1"),
            ChatMessage::user("q2"),
            assistant_with_tool_call("lookup"),
            ChatMessage::tool_result("data"),
            assistant_text("a2 final"),
        ];
        let cut = find_raw_cut(&msgs, 3).expect("cut found");
        // raw[0..cut] should end at user("q2") — indices 0..=3.
        assert_eq!(cut, 4);
        assert_eq!(msgs[cut - 1].role, "user");
        assert_eq!(msgs[cut - 1].content, "q2");
    }

    #[test]
    fn rewind_at_second_rendered_cuts_after_first_user() {
        // Rendered index 1 = the first assistant reply → dropping it should
        // leave just the initial user message.
        let msgs = vec![
            ChatMessage::system("s"),
            ChatMessage::user("q1"),
            assistant_with_tool_call("tool"),
            ChatMessage::tool_result("r"),
            assistant_text("a1"),
        ];
        let cut = find_raw_cut(&msgs, 1).expect("cut found");
        assert_eq!(cut, 2); // sys + user("q1")
    }

    #[test]
    fn rewind_beyond_range_returns_none() {
        let msgs = vec![ChatMessage::user("q1"), assistant_text("a1")];
        assert!(find_raw_cut(&msgs, 5).is_none());
    }
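
    // Added coverage (sketch): truncate_tool_result should pass short strings
    // through unchanged and bound + flag long ones.
    #[test]
    fn tool_result_preview_bounds_long_output() {
        let (short, flagged) = truncate_tool_result("ok");
        assert_eq!(short, "ok");
        assert!(!flagged);

        let long = "x".repeat(TOOL_RESULT_PREVIEW_MAX + 10);
        let (preview, truncated) = truncate_tool_result(&long);
        assert!(truncated);
        assert!(preview.len() <= TOOL_RESULT_PREVIEW_MAX);
    }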
}