Accumulate streamed tool calls across chunks in Ollama streaming
Ollama >=0.8 can stream tool_calls incrementally across NDJSON chunks. chat_with_tools_stream did `tool_calls = Some(tcs)` on every chunk, so only the last chunk's calls survived assembly and earlier calls were silently dropped. Append into the accumulator instead.

- ollama: add the append_streamed_tool_calls helper, plus tests covering two calls arriving in separate chunks and the single-chunk batch case.
- llamacpp: the SSE delta assembly was already correct (per-index BTreeMap; argument fragments for the same index concatenate; distinct indexes accumulate). Extracted it into apply_tool_call_deltas / finalize_tool_calls and added tests pinning that behavior.
- llm_client: new shared strip_think_blocks (moved from ollama's private extract_final_answer, which now delegates to it) so the tool-calling final-content paths can reuse it; unit tests for the tagged, plain, unclosed, and empty-answer cases.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
+119
-52
@@ -590,10 +590,7 @@ impl LlmClient for LlamaCppClient {
|
|||||||
let mut byte_stream = byte_stream;
|
let mut byte_stream = byte_stream;
|
||||||
let mut buf: Vec<u8> = Vec::new();
|
let mut buf: Vec<u8> = Vec::new();
|
||||||
let mut accumulated_content = String::new();
|
let mut accumulated_content = String::new();
|
||||||
let mut tool_state: std::collections::BTreeMap<
|
let mut tool_state = ToolCallAssembly::new();
|
||||||
usize,
|
|
||||||
(Option<String>, Option<String>, String),
|
|
||||||
> = std::collections::BTreeMap::new();
|
|
||||||
let mut role = "assistant".to_string();
|
let mut role = "assistant".to_string();
|
||||||
let mut prompt_tokens: Option<i32> = None;
|
let mut prompt_tokens: Option<i32> = None;
|
||||||
let mut completion_tokens: Option<i32> = None;
|
let mut completion_tokens: Option<i32> = None;
|
||||||
@@ -670,32 +667,7 @@ impl LlmClient for LlamaCppClient {
|
|||||||
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
|
yield Ok(LlmStreamEvent::TextDelta(content.to_string()));
|
||||||
}
|
}
|
||||||
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
|
if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) {
|
||||||
for tc_delta in tcs {
|
apply_tool_call_deltas(&mut tool_state, tcs);
|
||||||
let idx = tc_delta
|
|
||||||
.get("index")
|
|
||||||
.and_then(|n| n.as_u64())
|
|
||||||
.unwrap_or(0) as usize;
|
|
||||||
let entry = tool_state
|
|
||||||
.entry(idx)
|
|
||||||
.or_insert((None, None, String::new()));
|
|
||||||
if let Some(id) =
|
|
||||||
tc_delta.get("id").and_then(|v| v.as_str())
|
|
||||||
{
|
|
||||||
entry.0 = Some(id.to_string());
|
|
||||||
}
|
|
||||||
if let Some(func) = tc_delta.get("function") {
|
|
||||||
if let Some(name) =
|
|
||||||
func.get("name").and_then(|v| v.as_str())
|
|
||||||
{
|
|
||||||
entry.1 = Some(name.to_string());
|
|
||||||
}
|
|
||||||
if let Some(args) =
|
|
||||||
func.get("arguments").and_then(|v| v.as_str())
|
|
||||||
{
|
|
||||||
entry.2.push_str(args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if done_seen {
|
if done_seen {
|
||||||
@@ -707,28 +679,7 @@ impl LlmClient for LlamaCppClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_calls: Option<Vec<ToolCall>> = if tool_state.is_empty() {
|
let tool_calls = finalize_tool_calls(tool_state);
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mut v = Vec::with_capacity(tool_state.len());
|
|
||||||
for (_idx, (id, name, args)) in tool_state {
|
|
||||||
let arguments: Value = if args.trim().is_empty() {
|
|
||||||
Value::Object(Default::default())
|
|
||||||
} else {
|
|
||||||
serde_json::from_str(&args).unwrap_or_else(|_| {
|
|
||||||
Value::Object(Default::default())
|
|
||||||
})
|
|
||||||
};
|
|
||||||
v.push(ToolCall {
|
|
||||||
id,
|
|
||||||
function: ToolCallFunction {
|
|
||||||
name: name.unwrap_or_default(),
|
|
||||||
arguments,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Some(v)
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref frame) = last_frame {
|
if let Some(ref frame) = last_frame {
|
||||||
log_timings(frame, prompt_tokens, completion_tokens);
|
log_timings(frame, prompt_tokens, completion_tokens);
|
||||||
@@ -937,6 +888,58 @@ fn extract_error_detail(parsed: &Value) -> String {
|
|||||||
raw.chars().take(300).collect()
|
raw.chars().take(300).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Working state used while reassembling streamed OpenAI-style tool-call
/// deltas.
///
/// Keys are the delta `index`; each value is the tuple
/// `(id, function name, argument fragments concatenated so far)`. A
/// `BTreeMap` is used so that iterating the finished state yields the calls
/// in ascending index order.
type ToolCallAssembly = std::collections::BTreeMap<
    usize,
    (Option<String>, Option<String>, String),
>;
|
||||||
|
|
||||||
|
/// Fold one SSE frame's `delta.tool_calls` array into the assembly state.
|
||||||
|
/// Deltas carrying the same `index` merge into one call (llama.cpp streams a
|
||||||
|
/// call's argument JSON in fragments — they concatenate); distinct indexes
|
||||||
|
/// accumulate as separate calls.
|
||||||
|
fn apply_tool_call_deltas(state: &mut ToolCallAssembly, tcs: &[Value]) {
|
||||||
|
for tc_delta in tcs {
|
||||||
|
let idx = tc_delta.get("index").and_then(|n| n.as_u64()).unwrap_or(0) as usize;
|
||||||
|
let entry = state.entry(idx).or_insert((None, None, String::new()));
|
||||||
|
if let Some(id) = tc_delta.get("id").and_then(|v| v.as_str()) {
|
||||||
|
entry.0 = Some(id.to_string());
|
||||||
|
}
|
||||||
|
if let Some(func) = tc_delta.get("function") {
|
||||||
|
if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
|
||||||
|
entry.1 = Some(name.to_string());
|
||||||
|
}
|
||||||
|
if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
|
||||||
|
entry.2.push_str(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert assembled tool-call state into canonical `ToolCall`s, parsing each
|
||||||
|
/// call's concatenated argument JSON (empty / malformed → `{}`). `None` when
|
||||||
|
/// no tool-call deltas arrived.
|
||||||
|
fn finalize_tool_calls(state: ToolCallAssembly) -> Option<Vec<ToolCall>> {
|
||||||
|
if state.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mut v = Vec::with_capacity(state.len());
|
||||||
|
for (_idx, (id, name, args)) in state {
|
||||||
|
let arguments: Value = if args.trim().is_empty() {
|
||||||
|
Value::Object(Default::default())
|
||||||
|
} else {
|
||||||
|
serde_json::from_str(&args).unwrap_or_else(|_| Value::Object(Default::default()))
|
||||||
|
};
|
||||||
|
v.push(ToolCall {
|
||||||
|
id,
|
||||||
|
function: ToolCallFunction {
|
||||||
|
name: name.unwrap_or_default(),
|
||||||
|
arguments,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Some(v)
|
||||||
|
}
|
||||||
|
|
||||||
fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
fn find_double_newline(buf: &[u8]) -> Option<usize> {
|
||||||
for i in 0..buf.len().saturating_sub(1) {
|
for i in 0..buf.len().saturating_sub(1) {
|
||||||
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
|
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
|
||||||
@@ -1302,4 +1305,68 @@ mod tests {
|
|||||||
let c = LlamaCppClient::new(None, None);
|
let c = LlamaCppClient::new(None, None);
|
||||||
assert_eq!(c.tts_model, "chatterbox");
|
assert_eq!(c.tts_model, "chatterbox");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn stream_assembly_keeps_two_tool_calls_from_separate_chunks() {
    // llama.cpp emits one delta per SSE frame. Feed two frames whose deltas
    // carry different `index` values and check that neither call is lost.
    let frames = [
        json!({
            "index": 0,
            "id": "call_a",
            "function": { "name": "get_sms_messages", "arguments": "{\"date\":\"2019-01-01\"}" }
        }),
        json!({
            "index": 1,
            "id": "call_b",
            "function": { "name": "reverse_geocode", "arguments": "{\"latitude\":1.0,\"longitude\":2.0}" }
        }),
    ];

    let mut assembly = ToolCallAssembly::new();
    for frame in &frames {
        apply_tool_call_deltas(&mut assembly, std::slice::from_ref(frame));
    }

    let assembled = finalize_tool_calls(assembly).expect("two calls assembled");
    assert_eq!(assembled.len(), 2);

    let (first, second) = (&assembled[0], &assembled[1]);
    assert_eq!(first.id.as_deref(), Some("call_a"));
    assert_eq!(first.function.name, "get_sms_messages");
    assert_eq!(first.function.arguments["date"], "2019-01-01");
    assert_eq!(second.id.as_deref(), Some("call_b"));
    assert_eq!(second.function.name, "reverse_geocode");
    assert_eq!(second.function.arguments["latitude"], 1.0);
}
|
||||||
|
|
||||||
|
#[test]
fn stream_assembly_concatenates_argument_fragments_for_same_index() {
    // One call whose argument JSON is split over two frames: the fragments
    // must be joined into a single document that parses.
    let opening = json!({
        "index": 0,
        "id": "call_x",
        "function": { "name": "search_messages", "arguments": "{\"query\":" }
    });
    let continuation = json!({
        "index": 0,
        "function": { "arguments": "\"dinner\"}" }
    });

    let mut assembly = ToolCallAssembly::new();
    apply_tool_call_deltas(&mut assembly, &[opening]);
    apply_tool_call_deltas(&mut assembly, &[continuation]);

    let assembled = finalize_tool_calls(assembly).expect("one call assembled");
    assert_eq!(assembled.len(), 1);

    let only = &assembled[0];
    assert_eq!(only.function.name, "search_messages");
    assert_eq!(only.function.arguments["query"], "dinner");
}
|
||||||
|
|
||||||
|
#[test]
fn stream_assembly_empty_state_finalizes_to_none() {
    // With no deltas folded in, finalizing must report "no tool calls".
    let untouched = ToolCallAssembly::new();
    assert!(finalize_tool_calls(untouched).is_none());
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -170,3 +170,55 @@ pub struct ModelCapabilities {
|
|||||||
pub has_vision: bool,
|
pub has_vision: bool,
|
||||||
pub has_tool_calling: bool,
|
pub has_tool_calling: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove model "thinking" output that precedes the real answer.
///
/// The input is trimmed first. If it contains a `</think>` closing tag, the
/// text following the first such tag is trimmed and returned, provided it is
/// not empty. In every other case — no closing tag at all, or nothing but
/// whitespace after it — the trimmed input is returned as-is, think block
/// included.
///
/// This is the cleanup Ollama's `extract_final_answer` applies to
/// single-shot generation; it is shared here so the tool-calling
/// final-content paths can apply the identical cleanup before parsing or
/// persisting.
pub fn strip_think_blocks(response: &str) -> String {
    const CLOSE_TAG: &str = "</think>";

    let trimmed = response.trim();
    let answer = trimmed
        .split_once(CLOSE_TAG)
        .map(|(_, tail)| tail.trim())
        .filter(|tail| !tail.is_empty());

    // Fall back to the whole trimmed input rather than returning nothing.
    answer.unwrap_or(trimmed).to_string()
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_think_blocks_removes_leading_think_block() {
        let input = "<think>\nLet me reason about this.\n</think>\n\nTitle: A Day Out\n\nThe body.";
        let cleaned = strip_think_blocks(input);
        assert_eq!(cleaned, "Title: A Day Out\n\nThe body.");
    }

    #[test]
    fn strip_think_blocks_passes_through_plain_content() {
        // No tag present: only the surrounding whitespace is removed.
        let cleaned = strip_think_blocks("  just an answer  ");
        assert_eq!(cleaned, "just an answer");
    }

    #[test]
    fn strip_think_blocks_keeps_content_when_answer_after_tag_is_empty() {
        // Nothing follows the think block, so the trimmed original is
        // preferred over an empty string (matches Ollama's fallback).
        let input = "<think>only thoughts</think>";
        let cleaned = strip_think_blocks(input);
        assert_eq!(cleaned, input);
    }

    #[test]
    fn strip_think_blocks_handles_unclosed_tag() {
        // An opening tag with no closing tag leaves the text untouched.
        let input = "<think>thinking forever";
        let cleaned = strip_think_blocks(input);
        assert_eq!(cleaned, input);
    }
}
|
||||||
|
|||||||
+52
-14
@@ -360,18 +360,7 @@ impl OllamaClient {
|
|||||||
/// Extract final answer from thinking model output
|
/// Extract final answer from thinking model output
|
||||||
/// Handles <think>...</think> tags and takes everything after
|
/// Handles <think>...</think> tags and takes everything after
|
||||||
fn extract_final_answer(&self, response: &str) -> String {
|
fn extract_final_answer(&self, response: &str) -> String {
|
||||||
let response = response.trim();
|
crate::ai::llm_client::strip_think_blocks(response)
|
||||||
|
|
||||||
// Look for </think> tag and take everything after it
|
|
||||||
if let Some(pos) = response.find("</think>") {
|
|
||||||
let answer = response[pos + 8..].trim();
|
|
||||||
if !answer.is_empty() {
|
|
||||||
return answer.to_string();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: return the whole response trimmed
|
|
||||||
response.to_string()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_generate(
|
async fn try_generate(
|
||||||
@@ -846,11 +835,14 @@ Analyze the image and use specific details from both the visual content and the
|
|||||||
if !chunk.message.role.is_empty() {
|
if !chunk.message.role.is_empty() {
|
||||||
role = chunk.message.role;
|
role = chunk.message.role;
|
||||||
}
|
}
|
||||||
// Ollama only attaches tool_calls on the final chunk.
|
// Ollama ≥0.8 can stream tool_calls incrementally
|
||||||
|
// across chunks (older servers attach them all to
|
||||||
|
// one chunk) — append rather than overwrite so
|
||||||
|
// calls from earlier chunks survive.
|
||||||
if let Some(tcs) = chunk.message.tool_calls
|
if let Some(tcs) = chunk.message.tool_calls
|
||||||
&& !tcs.is_empty()
|
&& !tcs.is_empty()
|
||||||
{
|
{
|
||||||
tool_calls = Some(tcs);
|
append_streamed_tool_calls(&mut tool_calls, tcs);
|
||||||
}
|
}
|
||||||
if chunk.done {
|
if chunk.done {
|
||||||
prompt_eval_count = chunk.prompt_eval_count;
|
prompt_eval_count = chunk.prompt_eval_count;
|
||||||
@@ -1329,8 +1321,20 @@ struct OllamaEmbedResponse {
|
|||||||
embeddings: Vec<Vec<f32>>,
|
embeddings: Vec<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Accumulate tool calls streamed across NDJSON chunks. Ollama ≥0.8 may
|
||||||
|
/// emit each tool call on its own chunk; replacing the accumulator on every
|
||||||
|
/// chunk would keep only the last call, so extend instead.
|
||||||
|
fn append_streamed_tool_calls(
|
||||||
|
acc: &mut Option<Vec<crate::ai::llm_client::ToolCall>>,
|
||||||
|
new: Vec<crate::ai::llm_client::ToolCall>,
|
||||||
|
) {
|
||||||
|
acc.get_or_insert_with(Vec::new).extend(new);
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use super::append_streamed_tool_calls;
|
||||||
|
use crate::ai::llm_client::{ToolCall, ToolCallFunction};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn generate_photo_description_prompt_is_concise() {
|
fn generate_photo_description_prompt_is_concise() {
|
||||||
@@ -1341,4 +1345,38 @@ mod tests {
|
|||||||
Focus on the people, location, and activity.";
|
Focus on the people, location, and activity.";
|
||||||
assert!(prompt.len() < 200, "Prompt should be concise");
|
assert!(prompt.len() < 200, "Prompt should be concise");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn call(name: &str) -> ToolCall {
|
||||||
|
ToolCall {
|
||||||
|
id: None,
|
||||||
|
function: ToolCallFunction {
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: serde_json::json!({}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn streamed_tool_calls_across_chunks_accumulate() {
    // Simulate two stream chunks that each carry one tool call. Both calls
    // must be present afterwards, in arrival order — the previous
    // `tool_calls = Some(tcs)` assignment kept only the last chunk's calls.
    let mut accumulated: Option<Vec<ToolCall>> = None;
    for chunk_call in ["get_sms_messages", "reverse_geocode"] {
        append_streamed_tool_calls(&mut accumulated, vec![call(chunk_call)]);
    }

    let collected = accumulated.expect("tool calls accumulated");
    let names: Vec<&str> = collected.iter().map(|c| c.function.name.as_str()).collect();
    assert_eq!(names, ["get_sms_messages", "reverse_geocode"]);
}
|
||||||
|
|
||||||
|
#[test]
fn streamed_tool_calls_single_chunk_batch_kept_intact() {
    // Older Ollama servers put every call on a single chunk; that batch
    // must come through whole.
    let batch = vec![call("a"), call("b")];
    let mut accumulated: Option<Vec<ToolCall>> = None;
    append_streamed_tool_calls(&mut accumulated, batch);

    let collected = accumulated.expect("tool calls accumulated");
    assert_eq!(collected.len(), 2);
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user