Accumulate streamed tool calls across chunks in Ollama streaming

Ollama >=0.8 can stream tool_calls incrementally across NDJSON chunks;
chat_with_tools_stream did `tool_calls = Some(tcs)` per chunk, so only
the last chunk's calls survived assembly and earlier calls were silently
dropped. Append into the accumulator instead.

- ollama: append_streamed_tool_calls helper + tests covering two calls
  arriving in separate chunks and the single-chunk batch case.
- llamacpp: the SSE delta assembly was already correct (per-index
  BTreeMap, same-index argument fragments concatenate, distinct indexes
  accumulate); extracted it into apply_tool_call_deltas /
  finalize_tool_calls and added tests pinning that behavior.
- llm_client: new shared strip_think_blocks (moved from ollama's private
  extract_final_answer, which now delegates) so the tool-calling final
  content paths can reuse it; unit tests for tagged/plain/unclosed/empty
  cases.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Cameron Cordes
2026-06-09 18:29:06 -04:00
parent 8e4f91561b
commit 592dfcb42c
3 changed files with 223 additions and 66 deletions
+119 -52
View File
@@ -590,10 +590,7 @@ impl LlmClient for LlamaCppClient {
let mut byte_stream = byte_stream;
let mut buf: Vec<u8> = Vec::new();
let mut accumulated_content = String::new();
let mut tool_state: std::collections::BTreeMap<
usize,
(Option<String>, Option<String>, String),
> = std::collections::BTreeMap::new();
let mut tool_state = ToolCallAssembly::new();
let mut role = "assistant".to_string();
let mut prompt_tokens: Option<i32> = None;
let mut completion_tokens: Option<i32> = None;
@@ -670,32 +667,7 @@ impl LlmClient for LlamaCppClient {
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
}
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
for tc_delta in tcs {
let idx = tc_delta
.get("index")
.and_then(|n| n.as_u64())
.unwrap_or(0) as usize;
let entry = tool_state
.entry(idx)
.or_insert((None, None, String::new()));
if let Some(id) =
tc_delta.get("id").and_then(|v| v.as_str())
{
entry.0 = Some(id.to_string());
}
if let Some(func) = tc_delta.get("function") {
if let Some(name) =
func.get("name").and_then(|v| v.as_str())
{
entry.1 = Some(name.to_string());
}
if let Some(args) =
func.get("arguments").and_then(|v| v.as_str())
{
entry.2.push_str(args);
}
}
}
apply_tool_call_deltas(&mut tool_state, tcs);
}
}
if done_seen {
@@ -707,28 +679,7 @@ impl LlmClient for LlamaCppClient {
}
}
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
None
} else {
let mut v = Vec::with_capacity(tool_state.len());
for (_idx, (id, name, args)) in tool_state {
let arguments: Value = if args.trim().is_empty() {
Value::Object(Default::default())
} else {
serde_json::from_str(&args).unwrap_or_else(|_| {
Value::Object(Default::default())
})
};
v.push(ToolCall {
id,
function: ToolCallFunction {
name: name.unwrap_or_default(),
arguments,
},
});
}
Some(v)
};
let tool_calls = finalize_tool_calls(tool_state);
if let Some(ref frame) = last_frame {
log_timings(frame, prompt_tokens, completion_tokens);
@@ -937,6 +888,58 @@ fn extract_error_detail(parsed: &Value) -> String {
raw.chars().take(300).collect()
}
/// Per-index assembly state for streamed OpenAI-style tool-call deltas:
/// `index → (id, name, concatenated argument fragments)`. BTreeMap so the
/// finalized calls come out in index order.
type ToolCallAssembly = std::collections::BTreeMap<usize, (Option<String>, Option<String>, String)>;
/// Fold one SSE frame's `delta.tool_calls` array into the assembly state.
/// Deltas carrying the same `index` merge into one call (llama.cpp streams a
/// call's argument JSON in fragments — they concatenate); distinct indexes
/// accumulate as separate calls.
fn apply_tool_call_deltas(state: &mut ToolCallAssembly, tcs: &[Value]) {
for tc_delta in tcs {
let idx = tc_delta.get("index").and_then(|n| n.as_u64()).unwrap_or(0) as usize;
let entry = state.entry(idx).or_insert((None, None, String::new()));
if let Some(id) = tc_delta.get("id").and_then(|v| v.as_str()) {
entry.0 = Some(id.to_string());
}
if let Some(func) = tc_delta.get("function") {
if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
entry.1 = Some(name.to_string());
}
if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
entry.2.push_str(args);
}
}
}
}
/// Convert assembled tool-call state into canonical `ToolCall`s, parsing each
/// call's concatenated argument JSON (empty / malformed → `{}`). `None` when
/// no tool-call deltas arrived.
fn finalize_tool_calls(state: ToolCallAssembly) -> Option<Vec<ToolCall>> {
if state.is_empty() {
return None;
}
let mut v = Vec::with_capacity(state.len());
for (_idx, (id, name, args)) in state {
let arguments: Value = if args.trim().is_empty() {
Value::Object(Default::default())
} else {
serde_json::from_str(&args).unwrap_or_else(|_| Value::Object(Default::default()))
};
v.push(ToolCall {
id,
function: ToolCallFunction {
name: name.unwrap_or_default(),
arguments,
},
});
}
Some(v)
}
fn find_double_newline(buf: &[u8]) -> Option<usize> {
for i in 0..buf.len().saturating_sub(1) {
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
@@ -1302,4 +1305,68 @@ mod tests {
let c = LlamaCppClient::new(None, None);
assert_eq!(c.tts_model, "chatterbox");
}
#[test]
fn stream_assembly_keeps_two_tool_calls_from_separate_chunks() {
// llama.cpp emits one delta per SSE frame; two calls with distinct
// `index` values arriving in separate frames must BOTH survive.
let mut state = ToolCallAssembly::new();
apply_tool_call_deltas(
&mut state,
&[json!({
"index": 0,
"id": "call_a",
"function": { "name": "get_sms_messages", "arguments": "{\"date\":\"2019-01-01\"}" }
})],
);
apply_tool_call_deltas(
&mut state,
&[json!({
"index": 1,
"id": "call_b",
"function": { "name": "reverse_geocode", "arguments": "{\"latitude\":1.0,\"longitude\":2.0}" }
})],
);
let calls = finalize_tool_calls(state).expect("two calls assembled");
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].id.as_deref(), Some("call_a"));
assert_eq!(calls[0].function.name, "get_sms_messages");
assert_eq!(calls[0].function.arguments["date"], "2019-01-01");
assert_eq!(calls[1].id.as_deref(), Some("call_b"));
assert_eq!(calls[1].function.name, "reverse_geocode");
assert_eq!(calls[1].function.arguments["latitude"], 1.0);
}
#[test]
fn stream_assembly_concatenates_argument_fragments_for_same_index() {
// A single call's argument JSON streamed across frames concatenates
// into one parseable document.
let mut state = ToolCallAssembly::new();
apply_tool_call_deltas(
&mut state,
&[json!({
"index": 0,
"id": "call_x",
"function": { "name": "search_messages", "arguments": "{\"query\":" }
})],
);
apply_tool_call_deltas(
&mut state,
&[json!({
"index": 0,
"function": { "arguments": "\"dinner\"}" }
})],
);
let calls = finalize_tool_calls(state).expect("one call assembled");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search_messages");
assert_eq!(calls[0].function.arguments["query"], "dinner");
}
#[test]
fn stream_assembly_empty_state_finalizes_to_none() {
assert!(finalize_tool_calls(ToolCallAssembly::new()).is_none());
}
}
+52
View File
@@ -170,3 +170,55 @@ pub struct ModelCapabilities {
pub has_vision: bool,
pub has_tool_calling: bool,
}
/// Strip a leading `<think>…</think>` reasoning block from model output.
///
/// Thinking models sometimes emit chain-of-thought inside think tags before
/// the real answer. Everything after the first `</think>` is the answer;
/// when no tag is present — or the text after it is empty — the trimmed
/// input is returned unchanged. Mirrors the behavior Ollama's
/// `extract_final_answer` has applied to single-shot generation; shared here
/// so the tool-calling final-content paths (agentic generation + chat) can
/// apply the identical cleanup before parsing / persisting.
pub fn strip_think_blocks(response: &str) -> String {
let response = response.trim();
if let Some(pos) = response.find("</think>") {
let answer = response[pos + "</think>".len()..].trim();
if !answer.is_empty() {
return answer.to_string();
}
}
response.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn strip_think_blocks_removes_leading_think_block() {
let raw = "<think>\nLet me reason about this.\n</think>\n\nTitle: A Day Out\n\nThe body.";
assert_eq!(strip_think_blocks(raw), "Title: A Day Out\n\nThe body.");
}
#[test]
fn strip_think_blocks_passes_through_plain_content() {
assert_eq!(strip_think_blocks(" just an answer "), "just an answer");
}
#[test]
fn strip_think_blocks_keeps_content_when_answer_after_tag_is_empty() {
// A think block with nothing after it: better to return the trimmed
// original than an empty string (matches Ollama's fallback).
let raw = "<think>only thoughts</think>";
assert_eq!(strip_think_blocks(raw), raw);
}
#[test]
fn strip_think_blocks_handles_unclosed_tag() {
let raw = "<think>thinking forever";
assert_eq!(strip_think_blocks(raw), raw);
}
}
+52 -14
View File
@@ -360,18 +360,7 @@ impl OllamaClient {
/// Extract final answer from thinking model output
/// Handles <think>...</think> tags and takes everything after
fn extract_final_answer(&self, response: &str) -> String {
let response = response.trim();
// Look for </think> tag and take everything after it
if let Some(pos) = response.find("</think>") {
let answer = response[pos + 8..].trim();
if !answer.is_empty() {
return answer.to_string();
}
}
// Fallback: return the whole response trimmed
response.to_string()
crate::ai::llm_client::strip_think_blocks(response)
}
async fn try_generate(
@@ -846,11 +835,14 @@ Analyze the image and use specific details from both the visual content and the
if !chunk.message.role.is_empty() {
role = chunk.message.role;
}
// Ollama only attaches tool_calls on the final chunk.
// Ollama ≥0.8 can stream tool_calls incrementally
// across chunks (older servers attach them all to
// one chunk) — append rather than overwrite so
// calls from earlier chunks survive.
if let Some(tcs) = chunk.message.tool_calls
&& !tcs.is_empty()
{
tool_calls = Some(tcs);
append_streamed_tool_calls(&mut tool_calls, tcs);
}
if chunk.done {
prompt_eval_count = chunk.prompt_eval_count;
@@ -1329,8 +1321,20 @@ struct OllamaEmbedResponse {
embeddings: Vec<Vec<f32>>,
}
/// Accumulate tool calls streamed across NDJSON chunks. Ollama ≥0.8 may
/// emit each tool call on its own chunk; replacing the accumulator on every
/// chunk would keep only the last call, so extend instead.
fn append_streamed_tool_calls(
acc: &mut Option<Vec<crate::ai::llm_client::ToolCall>>,
new: Vec<crate::ai::llm_client::ToolCall>,
) {
acc.get_or_insert_with(Vec::new).extend(new);
}
#[cfg(test)]
mod tests {
use super::append_streamed_tool_calls;
use crate::ai::llm_client::{ToolCall, ToolCallFunction};
#[test]
fn generate_photo_description_prompt_is_concise() {
@@ -1341,4 +1345,38 @@ mod tests {
Focus on the people, location, and activity.";
assert!(prompt.len() < 200, "Prompt should be concise");
}
fn call(name: &str) -> ToolCall {
ToolCall {
id: None,
function: ToolCallFunction {
name: name.to_string(),
arguments: serde_json::json!({}),
},
}
}
#[test]
fn streamed_tool_calls_across_chunks_accumulate() {
// Two tool calls arriving in two separate stream chunks must BOTH
// survive assembly — the old `tool_calls = Some(tcs)` kept only the
// last chunk's calls.
let mut acc: Option<Vec<ToolCall>> = None;
append_streamed_tool_calls(&mut acc, vec![call("get_sms_messages")]);
append_streamed_tool_calls(&mut acc, vec![call("reverse_geocode")]);
let calls = acc.expect("tool calls accumulated");
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_sms_messages");
assert_eq!(calls[1].function.name, "reverse_geocode");
}
#[test]
fn streamed_tool_calls_single_chunk_batch_kept_intact() {
// Older Ollama servers attach all calls to one chunk — unchanged.
let mut acc: Option<Vec<ToolCall>> = None;
append_streamed_tool_calls(&mut acc, vec![call("a"), call("b")]);
let calls = acc.expect("tool calls accumulated");
assert_eq!(calls.len(), 2);
}
}