Accumulate streamed tool calls across chunks in Ollama streaming
Ollama >=0.8 can stream tool_calls incrementally across NDJSON chunks; chat_with_tools_stream did `tool_calls = Some(tcs)` per chunk, so only the last chunk's calls survived assembly and earlier calls were silently dropped. Append into the accumulator instead. - ollama: append_streamed_tool_calls helper + tests covering two calls arriving in separate chunks and the single-chunk batch case. - llamacpp: the SSE delta assembly was already correct (per-index BTreeMap, same-index argument fragments concatenate, distinct indexes accumulate); extracted it into apply_tool_call_deltas / finalize_tool_calls and added tests pinning that behavior. - llm_client: new shared strip_think_blocks (moved from ollama's private extract_final_answer, which now delegates) so the tool-calling final content paths can reuse it; unit tests for tagged/plain/unclosed/empty cases. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
+119
-52
@@ -590,10 +590,7 @@ impl LlmClient for LlamaCppClient {
|
||||
let mut byte_stream = byte_stream;
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut accumulated_content = String::new();
|
||||
let mut tool_state: std::collections::BTreeMap<
|
||||
usize,
|
||||
(Option<String>, Option<String>, String),
|
||||
> = std::collections::BTreeMap::new();
|
||||
let mut tool_state = ToolCallAssembly::new();
|
||||
let mut role = "assistant".to_string();
|
||||
let mut prompt_tokens: Option<i32> = None;
|
||||
let mut completion_tokens: Option<i32> = None;
|
||||
@@ -670,32 +667,7 @@ impl LlmClient for LlamaCppClient {
|
||||
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
|
||||
}
|
||||
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
for tc_delta in tcs {
|
||||
let idx = tc_delta
|
||||
.get("index")
|
||||
.and_then(|n| n.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let entry = tool_state
|
||||
.entry(idx)
|
||||
.or_insert((None, None, String::new()));
|
||||
if let Some(id) =
|
||||
tc_delta.get("id").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.0 = Some(id.to_string());
|
||||
}
|
||||
if let Some(func) = tc_delta.get("function") {
|
||||
if let Some(name) =
|
||||
func.get("name").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.1 = Some(name.to_string());
|
||||
}
|
||||
if let Some(args) =
|
||||
func.get("arguments").and_then(|v| v.as_str())
|
||||
{
|
||||
entry.2.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
apply_tool_call_deltas(&mut tool_state, tcs);
|
||||
}
|
||||
}
|
||||
if done_seen {
|
||||
@@ -707,28 +679,7 @@ impl LlmClient for LlamaCppClient {
|
||||
}
|
||||
}
|
||||
|
||||
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut v = Vec::with_capacity(tool_state.len());
|
||||
for (_idx, (id, name, args)) in tool_state {
|
||||
let arguments: Value = if args.trim().is_empty() {
|
||||
Value::Object(Default::default())
|
||||
} else {
|
||||
serde_json::from_str(&args).unwrap_or_else(|_| {
|
||||
Value::Object(Default::default())
|
||||
})
|
||||
};
|
||||
v.push(ToolCall {
|
||||
id,
|
||||
function: ToolCallFunction {
|
||||
name: name.unwrap_or_default(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
Some(v)
|
||||
};
|
||||
let tool_calls = finalize_tool_calls(tool_state);
|
||||
|
||||
if let Some(ref frame) = last_frame {
|
||||
log_timings(frame, prompt_tokens, completion_tokens);
|
||||
@@ -937,6 +888,58 @@ fn extract_error_detail(parsed: &Value) -> String {
|
||||
raw.chars().take(300).collect()
|
||||
}
|
||||
|
||||
/// Per-index assembly state for streamed OpenAI-style tool-call deltas:
|
||||
/// `index → (id, name, concatenated argument fragments)`. BTreeMap so the
|
||||
/// finalized calls come out in index order.
|
||||
type ToolCallAssembly = std::collections::BTreeMap<usize, (Option<String>, Option<String>, String)>;
|
||||
|
||||
/// Fold one SSE frame's `delta.tool_calls` array into the assembly state.
|
||||
/// Deltas carrying the same `index` merge into one call (llama.cpp streams a
|
||||
/// call's argument JSON in fragments — they concatenate); distinct indexes
|
||||
/// accumulate as separate calls.
|
||||
fn apply_tool_call_deltas(state: &mut ToolCallAssembly, tcs: &[Value]) {
|
||||
for tc_delta in tcs {
|
||||
let idx = tc_delta.get("index").and_then(|n| n.as_u64()).unwrap_or(0) as usize;
|
||||
let entry = state.entry(idx).or_insert((None, None, String::new()));
|
||||
if let Some(id) = tc_delta.get("id").and_then(|v| v.as_str()) {
|
||||
entry.0 = Some(id.to_string());
|
||||
}
|
||||
if let Some(func) = tc_delta.get("function") {
|
||||
if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
|
||||
entry.1 = Some(name.to_string());
|
||||
}
|
||||
if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
|
||||
entry.2.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert assembled tool-call state into canonical `ToolCall`s, parsing each
|
||||
/// call's concatenated argument JSON (empty / malformed → `{}`). `None` when
|
||||
/// no tool-call deltas arrived.
|
||||
fn finalize_tool_calls(state: ToolCallAssembly) -> Option<Vec<ToolCall>> {
|
||||
if state.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut v = Vec::with_capacity(state.len());
|
||||
for (_idx, (id, name, args)) in state {
|
||||
let arguments: Value = if args.trim().is_empty() {
|
||||
Value::Object(Default::default())
|
||||
} else {
|
||||
serde_json::from_str(&args).unwrap_or_else(|_| Value::Object(Default::default()))
|
||||
};
|
||||
v.push(ToolCall {
|
||||
id,
|
||||
function: ToolCallFunction {
|
||||
name: name.unwrap_or_default(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
Some(v)
|
||||
}
|
||||
|
||||
fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
||||
for i in 0..buf.len().saturating_sub(1) {
|
||||
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
|
||||
@@ -1302,4 +1305,68 @@ mod tests {
|
||||
let c = LlamaCppClient::new(None, None);
|
||||
assert_eq!(c.tts_model, "chatterbox");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_assembly_keeps_two_tool_calls_from_separate_chunks() {
|
||||
// llama.cpp emits one delta per SSE frame; two calls with distinct
|
||||
// `index` values arriving in separate frames must BOTH survive.
|
||||
let mut state = ToolCallAssembly::new();
|
||||
apply_tool_call_deltas(
|
||||
&mut state,
|
||||
&[json!({
|
||||
"index": 0,
|
||||
"id": "call_a",
|
||||
"function": { "name": "get_sms_messages", "arguments": "{\"date\":\"2019-01-01\"}" }
|
||||
})],
|
||||
);
|
||||
apply_tool_call_deltas(
|
||||
&mut state,
|
||||
&[json!({
|
||||
"index": 1,
|
||||
"id": "call_b",
|
||||
"function": { "name": "reverse_geocode", "arguments": "{\"latitude\":1.0,\"longitude\":2.0}" }
|
||||
})],
|
||||
);
|
||||
|
||||
let calls = finalize_tool_calls(state).expect("two calls assembled");
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].id.as_deref(), Some("call_a"));
|
||||
assert_eq!(calls[0].function.name, "get_sms_messages");
|
||||
assert_eq!(calls[0].function.arguments["date"], "2019-01-01");
|
||||
assert_eq!(calls[1].id.as_deref(), Some("call_b"));
|
||||
assert_eq!(calls[1].function.name, "reverse_geocode");
|
||||
assert_eq!(calls[1].function.arguments["latitude"], 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_assembly_concatenates_argument_fragments_for_same_index() {
|
||||
// A single call's argument JSON streamed across frames concatenates
|
||||
// into one parseable document.
|
||||
let mut state = ToolCallAssembly::new();
|
||||
apply_tool_call_deltas(
|
||||
&mut state,
|
||||
&[json!({
|
||||
"index": 0,
|
||||
"id": "call_x",
|
||||
"function": { "name": "search_messages", "arguments": "{\"query\":" }
|
||||
})],
|
||||
);
|
||||
apply_tool_call_deltas(
|
||||
&mut state,
|
||||
&[json!({
|
||||
"index": 0,
|
||||
"function": { "arguments": "\"dinner\"}" }
|
||||
})],
|
||||
);
|
||||
|
||||
let calls = finalize_tool_calls(state).expect("one call assembled");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].function.name, "search_messages");
|
||||
assert_eq!(calls[0].function.arguments["query"], "dinner");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_assembly_empty_state_finalizes_to_none() {
|
||||
assert!(finalize_tool_calls(ToolCallAssembly::new()).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,3 +170,55 @@ pub struct ModelCapabilities {
|
||||
pub has_vision: bool,
|
||||
pub has_tool_calling: bool,
|
||||
}
|
||||
|
||||
/// Strip a leading `<think>…</think>` reasoning block from model output.
|
||||
///
|
||||
/// Thinking models sometimes emit chain-of-thought inside think tags before
|
||||
/// the real answer. Everything after the first `</think>` is the answer;
|
||||
/// when no tag is present — or the text after it is empty — the trimmed
|
||||
/// input is returned unchanged. Mirrors the behavior Ollama's
|
||||
/// `extract_final_answer` has applied to single-shot generation; shared here
|
||||
/// so the tool-calling final-content paths (agentic generation + chat) can
|
||||
/// apply the identical cleanup before parsing / persisting.
|
||||
pub fn strip_think_blocks(response: &str) -> String {
|
||||
let response = response.trim();
|
||||
|
||||
if let Some(pos) = response.find("</think>") {
|
||||
let answer = response[pos + "</think>".len()..].trim();
|
||||
if !answer.is_empty() {
|
||||
return answer.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
response.to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn strip_think_blocks_removes_leading_think_block() {
|
||||
let raw = "<think>\nLet me reason about this.\n</think>\n\nTitle: A Day Out\n\nThe body.";
|
||||
assert_eq!(strip_think_blocks(raw), "Title: A Day Out\n\nThe body.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_think_blocks_passes_through_plain_content() {
|
||||
assert_eq!(strip_think_blocks(" just an answer "), "just an answer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_think_blocks_keeps_content_when_answer_after_tag_is_empty() {
|
||||
// A think block with nothing after it: better to return the trimmed
|
||||
// original than an empty string (matches Ollama's fallback).
|
||||
let raw = "<think>only thoughts</think>";
|
||||
assert_eq!(strip_think_blocks(raw), raw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_think_blocks_handles_unclosed_tag() {
|
||||
let raw = "<think>thinking forever";
|
||||
assert_eq!(strip_think_blocks(raw), raw);
|
||||
}
|
||||
}
|
||||
|
||||
+52
-14
@@ -360,18 +360,7 @@ impl OllamaClient {
|
||||
/// Extract final answer from thinking model output
|
||||
/// Handles <think>...</think> tags and takes everything after
|
||||
fn extract_final_answer(&self, response: &str) -> String {
|
||||
let response = response.trim();
|
||||
|
||||
// Look for </think> tag and take everything after it
|
||||
if let Some(pos) = response.find("</think>") {
|
||||
let answer = response[pos + 8..].trim();
|
||||
if !answer.is_empty() {
|
||||
return answer.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: return the whole response trimmed
|
||||
response.to_string()
|
||||
crate::ai::llm_client::strip_think_blocks(response)
|
||||
}
|
||||
|
||||
async fn try_generate(
|
||||
@@ -846,11 +835,14 @@ Analyze the image and use specific details from both the visual content and the
|
||||
if !chunk.message.role.is_empty() {
|
||||
role = chunk.message.role;
|
||||
}
|
||||
// Ollama only attaches tool_calls on the final chunk.
|
||||
// Ollama ≥0.8 can stream tool_calls incrementally
|
||||
// across chunks (older servers attach them all to
|
||||
// one chunk) — append rather than overwrite so
|
||||
// calls from earlier chunks survive.
|
||||
if let Some(tcs) = chunk.message.tool_calls
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
tool_calls = Some(tcs);
|
||||
append_streamed_tool_calls(&mut tool_calls, tcs);
|
||||
}
|
||||
if chunk.done {
|
||||
prompt_eval_count = chunk.prompt_eval_count;
|
||||
@@ -1329,8 +1321,20 @@ struct OllamaEmbedResponse {
|
||||
embeddings: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
/// Accumulate tool calls streamed across NDJSON chunks. Ollama ≥0.8 may
|
||||
/// emit each tool call on its own chunk; replacing the accumulator on every
|
||||
/// chunk would keep only the last call, so extend instead.
|
||||
fn append_streamed_tool_calls(
|
||||
acc: &mut Option<Vec<crate::ai::llm_client::ToolCall>>,
|
||||
new: Vec<crate::ai::llm_client::ToolCall>,
|
||||
) {
|
||||
acc.get_or_insert_with(Vec::new).extend(new);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::append_streamed_tool_calls;
|
||||
use crate::ai::llm_client::{ToolCall, ToolCallFunction};
|
||||
|
||||
#[test]
|
||||
fn generate_photo_description_prompt_is_concise() {
|
||||
@@ -1341,4 +1345,38 @@ mod tests {
|
||||
Focus on the people, location, and activity.";
|
||||
assert!(prompt.len() < 200, "Prompt should be concise");
|
||||
}
|
||||
|
||||
fn call(name: &str) -> ToolCall {
|
||||
ToolCall {
|
||||
id: None,
|
||||
function: ToolCallFunction {
|
||||
name: name.to_string(),
|
||||
arguments: serde_json::json!({}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streamed_tool_calls_across_chunks_accumulate() {
|
||||
// Two tool calls arriving in two separate stream chunks must BOTH
|
||||
// survive assembly — the old `tool_calls = Some(tcs)` kept only the
|
||||
// last chunk's calls.
|
||||
let mut acc: Option<Vec<ToolCall>> = None;
|
||||
append_streamed_tool_calls(&mut acc, vec![call("get_sms_messages")]);
|
||||
append_streamed_tool_calls(&mut acc, vec![call("reverse_geocode")]);
|
||||
|
||||
let calls = acc.expect("tool calls accumulated");
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].function.name, "get_sms_messages");
|
||||
assert_eq!(calls[1].function.name, "reverse_geocode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streamed_tool_calls_single_chunk_batch_kept_intact() {
|
||||
// Older Ollama servers attach all calls to one chunk — unchanged.
|
||||
let mut acc: Option<Vec<ToolCall>> = None;
|
||||
append_streamed_tool_calls(&mut acc, vec![call("a"), call("b")]);
|
||||
let calls = acc.expect("tool calls accumulated");
|
||||
assert_eq!(calls.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user