diff --git a/src/ai/llamacpp.rs b/src/ai/llamacpp.rs index d56b645..6227e2f 100644 --- a/src/ai/llamacpp.rs +++ b/src/ai/llamacpp.rs @@ -590,10 +590,7 @@ impl LlmClient for LlamaCppClient { let mut byte_stream = byte_stream; let mut buf: Vec = Vec::new(); let mut accumulated_content = String::new(); - let mut tool_state: std::collections::BTreeMap< - usize, - (Option, Option, String), - > = std::collections::BTreeMap::new(); + let mut tool_state = ToolCallAssembly::new(); let mut role = "assistant".to_string(); let mut prompt_tokens: Option = None; let mut completion_tokens: Option = None; @@ -670,32 +667,7 @@ impl LlmClient for LlamaCppClient { yield Ok(LlmStreamEvent::TextDelta(content.to_string())); } if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) { - for tc_delta in tcs { - let idx = tc_delta - .get("index") - .and_then(|n| n.as_u64()) - .unwrap_or(0) as usize; - let entry = tool_state - .entry(idx) - .or_insert((None, None, String::new())); - if let Some(id) = - tc_delta.get("id").and_then(|v| v.as_str()) - { - entry.0 = Some(id.to_string()); - } - if let Some(func) = tc_delta.get("function") { - if let Some(name) = - func.get("name").and_then(|v| v.as_str()) - { - entry.1 = Some(name.to_string()); - } - if let Some(args) = - func.get("arguments").and_then(|v| v.as_str()) - { - entry.2.push_str(args); - } - } - } + apply_tool_call_deltas(&mut tool_state, tcs); } } if done_seen { @@ -707,28 +679,7 @@ impl LlmClient for LlamaCppClient { } } - let tool_calls: Option> = if tool_state.is_empty() { - None - } else { - let mut v = Vec::with_capacity(tool_state.len()); - for (_idx, (id, name, args)) in tool_state { - let arguments: Value = if args.trim().is_empty() { - Value::Object(Default::default()) - } else { - serde_json::from_str(&args).unwrap_or_else(|_| { - Value::Object(Default::default()) - }) - }; - v.push(ToolCall { - id, - function: ToolCallFunction { - name: name.unwrap_or_default(), - arguments, - }, - }); - } - Some(v) - }; + let tool_calls = finalize_tool_calls(tool_state); if let Some(ref frame) = last_frame { log_timings(frame, prompt_tokens, completion_tokens); @@ -937,6 +888,58 @@ fn extract_error_detail(parsed: &Value) -> String { raw.chars().take(300).collect() } +/// Per-index assembly state for streamed OpenAI-style tool-call deltas: +/// `index → (id, name, concatenated argument fragments)`. BTreeMap so the +/// finalized calls come out in index order. +type ToolCallAssembly = std::collections::BTreeMap, Option, String)>; + +/// Fold one SSE frame's `delta.tool_calls` array into the assembly state. +/// Deltas carrying the same `index` merge into one call (llama.cpp streams a +/// call's argument JSON in fragments — they concatenate); distinct indexes +/// accumulate as separate calls. +fn apply_tool_call_deltas(state: &mut ToolCallAssembly, tcs: &[Value]) { + for tc_delta in tcs { + let idx = tc_delta.get("index").and_then(|n| n.as_u64()).unwrap_or(0) as usize; + let entry = state.entry(idx).or_insert((None, None, String::new())); + if let Some(id) = tc_delta.get("id").and_then(|v| v.as_str()) { + entry.0 = Some(id.to_string()); + } + if let Some(func) = tc_delta.get("function") { + if let Some(name) = func.get("name").and_then(|v| v.as_str()) { + entry.1 = Some(name.to_string()); + } + if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) { + entry.2.push_str(args); + } + } + } +} + +/// Convert assembled tool-call state into canonical `ToolCall`s, parsing each +/// call's concatenated argument JSON (empty / malformed → `{}`). `None` when +/// no tool-call deltas arrived. +fn finalize_tool_calls(state: ToolCallAssembly) -> Option> { + if state.is_empty() { + return None; + } + let mut v = Vec::with_capacity(state.len()); + for (_idx, (id, name, args)) in state { + let arguments: Value = if args.trim().is_empty() { + Value::Object(Default::default()) + } else { + serde_json::from_str(&args).unwrap_or_else(|_| Value::Object(Default::default())) + }; + v.push(ToolCall { + id, + function: ToolCallFunction { + name: name.unwrap_or_default(), + arguments, + }, + }); + } + Some(v) +} + fn find_double_newline(buf: &[u8]) -> Option { for i in 0..buf.len().saturating_sub(1) { if buf[i] == b'\n' && buf[i + 1] == b'\n' { @@ -1302,4 +1305,68 @@ mod tests { let c = LlamaCppClient::new(None, None); assert_eq!(c.tts_model, "chatterbox"); } + + #[test] + fn stream_assembly_keeps_two_tool_calls_from_separate_chunks() { + // llama.cpp emits one delta per SSE frame; two calls with distinct + // `index` values arriving in separate frames must BOTH survive. + let mut state = ToolCallAssembly::new(); + apply_tool_call_deltas( + &mut state, + &[json!({ + "index": 0, + "id": "call_a", + "function": { "name": "get_sms_messages", "arguments": "{\"date\":\"2019-01-01\"}" } + })], + ); + apply_tool_call_deltas( + &mut state, + &[json!({ + "index": 1, + "id": "call_b", + "function": { "name": "reverse_geocode", "arguments": "{\"latitude\":1.0,\"longitude\":2.0}" } + })], + ); + + let calls = finalize_tool_calls(state).expect("two calls assembled"); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].id.as_deref(), Some("call_a")); + assert_eq!(calls[0].function.name, "get_sms_messages"); + assert_eq!(calls[0].function.arguments["date"], "2019-01-01"); + assert_eq!(calls[1].id.as_deref(), Some("call_b")); + assert_eq!(calls[1].function.name, "reverse_geocode"); + assert_eq!(calls[1].function.arguments["latitude"], 1.0); + } + + #[test] + fn stream_assembly_concatenates_argument_fragments_for_same_index() { + // A single call's argument JSON streamed across frames concatenates + // into one parseable document. + let mut state = ToolCallAssembly::new(); + apply_tool_call_deltas( + &mut state, + &[json!({ + "index": 0, + "id": "call_x", + "function": { "name": "search_messages", "arguments": "{\"query\":" } + })], + ); + apply_tool_call_deltas( + &mut state, + &[json!({ + "index": 0, + "function": { "arguments": "\"dinner\"}" } + })], + ); + + let calls = finalize_tool_calls(state).expect("one call assembled"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "search_messages"); + assert_eq!(calls[0].function.arguments["query"], "dinner"); + } + + #[test] + fn stream_assembly_empty_state_finalizes_to_none() { + assert!(finalize_tool_calls(ToolCallAssembly::new()).is_none()); + } } diff --git a/src/ai/llm_client.rs b/src/ai/llm_client.rs index 8d68978..a50a6d8 100644 --- a/src/ai/llm_client.rs +++ b/src/ai/llm_client.rs @@ -170,3 +170,55 @@ pub struct ModelCapabilities { pub has_vision: bool, pub has_tool_calling: bool, } + +/// Strip a leading `` reasoning block from model output. +/// +/// Thinking models sometimes emit chain-of-thought inside think tags before +/// the real answer. Everything after the first `` is the answer; +/// when no tag is present — or the text after it is empty — the trimmed +/// input is returned unchanged. Mirrors the behavior Ollama's +/// `extract_final_answer` has applied to single-shot generation; shared here +/// so the tool-calling final-content paths (agentic generation + chat) can +/// apply the identical cleanup before parsing / persisting. +pub fn strip_think_blocks(response: &str) -> String { + let response = response.trim(); + + if let Some(pos) = response.find("") { + let answer = response[pos + "".len()..].trim(); + if !answer.is_empty() { + return answer.to_string(); + } + } + + response.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn strip_think_blocks_removes_leading_think_block() { + let raw = "\nLet me reason about this.\n\n\nTitle: A Day Out\n\nThe body."; + assert_eq!(strip_think_blocks(raw), "Title: A Day Out\n\nThe body."); + } + + #[test] + fn strip_think_blocks_passes_through_plain_content() { + assert_eq!(strip_think_blocks(" just an answer "), "just an answer"); + } + + #[test] + fn strip_think_blocks_keeps_content_when_answer_after_tag_is_empty() { + // A think block with nothing after it: better to return the trimmed + // original than an empty string (matches Ollama's fallback). + let raw = "only thoughts"; + assert_eq!(strip_think_blocks(raw), raw); + } + + #[test] + fn strip_think_blocks_handles_unclosed_tag() { + let raw = "thinking forever"; + assert_eq!(strip_think_blocks(raw), raw); + } +} diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs index 680668f..75c8a02 100644 --- a/src/ai/ollama.rs +++ b/src/ai/ollama.rs @@ -360,18 +360,7 @@ impl OllamaClient { /// Extract final answer from thinking model output /// Handles ... tags and takes everything after fn extract_final_answer(&self, response: &str) -> String { - let response = response.trim(); - - // Look for tag and take everything after it - if let Some(pos) = response.find("") { - let answer = response[pos + 8..].trim(); - if !answer.is_empty() { - return answer.to_string(); - } - } - - // Fallback: return the whole response trimmed - response.to_string() + crate::ai::llm_client::strip_think_blocks(response) } async fn try_generate( @@ -846,11 +835,14 @@ Analyze the image and use specific details from both the visual content and the if !chunk.message.role.is_empty() { role = chunk.message.role; } - // Ollama only attaches tool_calls on the final chunk. + // Ollama ≥0.8 can stream tool_calls incrementally + // across chunks (older servers attach them all to + // one chunk) — append rather than overwrite so + // calls from earlier chunks survive. if let Some(tcs) = chunk.message.tool_calls && !tcs.is_empty() { - tool_calls = Some(tcs); + append_streamed_tool_calls(&mut tool_calls, tcs); } if chunk.done { prompt_eval_count = chunk.prompt_eval_count; @@ -1329,8 +1321,20 @@ struct OllamaEmbedResponse { embeddings: Vec>, } +/// Accumulate tool calls streamed across NDJSON chunks. Ollama ≥0.8 may +/// emit each tool call on its own chunk; replacing the accumulator on every +/// chunk would keep only the last call, so extend instead. +fn append_streamed_tool_calls( + acc: &mut Option>, + new: Vec, +) { + acc.get_or_insert_with(Vec::new).extend(new); +} + #[cfg(test)] mod tests { + use super::append_streamed_tool_calls; + use crate::ai::llm_client::{ToolCall, ToolCallFunction}; #[test] fn generate_photo_description_prompt_is_concise() { @@ -1341,4 +1345,38 @@ mod tests { Focus on the people, location, and activity."; assert!(prompt.len() < 200, "Prompt should be concise"); } + + fn call(name: &str) -> ToolCall { + ToolCall { + id: None, + function: ToolCallFunction { + name: name.to_string(), + arguments: serde_json::json!({}), + }, + } + } + + #[test] + fn streamed_tool_calls_across_chunks_accumulate() { + // Two tool calls arriving in two separate stream chunks must BOTH + // survive assembly — the old `tool_calls = Some(tcs)` kept only the + // last chunk's calls. + let mut acc: Option> = None; + append_streamed_tool_calls(&mut acc, vec![call("get_sms_messages")]); + append_streamed_tool_calls(&mut acc, vec![call("reverse_geocode")]); + + let calls = acc.expect("tool calls accumulated"); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "get_sms_messages"); + assert_eq!(calls[1].function.name, "reverse_geocode"); + } + + #[test] + fn streamed_tool_calls_single_chunk_batch_kept_intact() { + // Older Ollama servers attach all calls to one chunk — unchanged. + let mut acc: Option> = None; + append_streamed_tool_calls(&mut acc, vec![call("a"), call("b")]); + let calls = acc.expect("tool calls accumulated"); + assert_eq!(calls.len(), 2); + } }