ImageApi/src/ai/ollama.rs
Cameron 079cd4c5b9 feat(ai): streaming chat endpoint with live tool events
Add LlmClient::chat_with_tools_stream and an SSE endpoint,
POST /insights/chat/stream, that emits text deltas, tool_call /
tool_result pairs, a truncated notice, and a terminal done frame as the
agentic loop runs.

- Ollama: parses NDJSON from /api/chat stream, accumulates content
  deltas, emits Done with tool_calls from the final chunk.
- OpenRouter: parses OpenAI-compatible SSE, reassembles tool_call
  argument deltas by index, asks for stream_options.include_usage.
- InsightChatService spawns the loop on a tokio task, feeds events
  through an mpsc channel, persists training_messages at the end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-21 16:57:41 -04:00
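
The exact SSE frame schema lives in the endpoint handler, not in this file; as a rough sketch of the event order described above (field names illustrative, not verbatim), one turn with a single tool call might stream:

data: {"type": "text_delta", "content": "Let me check"}
data: {"type": "tool_call", "name": "search_photos", "arguments": {...}}
data: {"type": "tool_result", "name": "search_photos", "result": {...}}
data: {"type": "text_delta", "content": "Here's what I found"}
data: {"type": "done", "prompt_eval_count": 1024, "eval_count": 187}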


use anyhow::{Context, Result};
use async_trait::async_trait;
use chrono::NaiveDate;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::ai::llm_client::{LlmClient, LlmStreamEvent};
use futures::stream::{BoxStream, StreamExt};
// Re-export shared types so existing `crate::ai::ollama::{...}` imports
// continue to resolve.
pub use crate::ai::llm_client::{ChatMessage, ModelCapabilities, Tool};
#[allow(unused_imports)]
pub use crate::ai::llm_client::{ToolCall, ToolCallFunction, ToolFunction};
// Cache duration: 15 minutes
const CACHE_DURATION_SECS: u64 = 15 * 60;
// Cached entry with timestamp
#[derive(Clone)]
struct CachedEntry<T> {
    data: T,
    cached_at: Instant,
}

impl<T> CachedEntry<T> {
    fn new(data: T) -> Self {
        Self {
            data,
            cached_at: Instant::now(),
        }
    }

    fn is_expired(&self) -> bool {
        self.cached_at.elapsed().as_secs() > CACHE_DURATION_SECS
    }
}

// Global cache for model lists and capabilities
lazy_static::lazy_static! {
    static ref MODEL_LIST_CACHE: Arc<Mutex<HashMap<String, CachedEntry<Vec<String>>>>> =
        Arc::new(Mutex::new(HashMap::new()));
    static ref MODEL_CAPABILITIES_CACHE: Arc<Mutex<HashMap<String, CachedEntry<Vec<ModelCapabilities>>>>> =
        Arc::new(Mutex::new(HashMap::new()));
}

#[derive(Clone)]
pub struct OllamaClient {
    client: Client,
    pub primary_url: String,
    pub fallback_url: Option<String>,
    pub primary_model: String,
    pub fallback_model: Option<String>,
    num_ctx: Option<i32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    top_k: Option<i32>,
    min_p: Option<f32>,
}

impl OllamaClient {
    pub fn new(
        primary_url: String,
        fallback_url: Option<String>,
        primary_model: String,
        fallback_model: Option<String>,
    ) -> Self {
        Self {
            client: Client::builder()
                .connect_timeout(Duration::from_secs(5)) // Quick connection timeout
                .timeout(Duration::from_secs(120)) // Total request timeout for generation
                .build()
                .unwrap_or_else(|_| Client::new()),
            primary_url,
            fallback_url,
            primary_model,
            fallback_model,
            num_ctx: None,
            temperature: None,
            top_p: None,
            top_k: None,
            min_p: None,
        }
    }

    pub fn set_num_ctx(&mut self, num_ctx: Option<i32>) {
        self.num_ctx = num_ctx;
    }

    /// Set sampling parameters for generation. `None` values leave the
    /// server-side default in place.
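    ///
    /// A minimal usage sketch (illustrative values, not recommended defaults):
    /// ```ignore
    /// // temperature=0.7, top_p=0.9, top_k=40; min_p stays at the server default
    /// client.set_sampling_params(Some(0.7), Some(0.9), Some(40), None);
    /// ```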
    pub fn set_sampling_params(
        &mut self,
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<i32>,
        min_p: Option<f32>,
    ) {
        self.temperature = temperature;
        self.top_p = top_p;
        self.top_k = top_k;
        self.min_p = min_p;
    }

    /// Build an `OllamaOptions` payload from the currently configured fields.
    /// Returns `None` if no options would be set, so the `options` field is
    /// omitted from the request entirely.
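    ///
    /// For example (following the `serde` attributes on `OllamaOptions`): with
    /// only `num_ctx` set to `Some(8192)`, the request body carries
    /// `"options": {"num_ctx": 8192}`; with every field `None`, the `options`
    /// key is absent from the JSON altogether.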
    fn build_options(&self) -> Option<OllamaOptions> {
        if self.num_ctx.is_none()
            && self.temperature.is_none()
            && self.top_p.is_none()
            && self.top_k.is_none()
            && self.min_p.is_none()
        {
            None
        } else {
            Some(OllamaOptions {
                num_ctx: self.num_ctx,
                temperature: self.temperature,
                top_p: self.top_p,
                top_k: self.top_k,
                min_p: self.min_p,
            })
        }
    }

    /// Replace the HTTP client with one using a custom request timeout.
    /// Useful for slow models where the default 120s may be insufficient.
    #[allow(dead_code)]
    pub fn with_request_timeout(mut self, secs: u64) -> Self {
        self.client = Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(secs))
            .build()
            .unwrap_or_else(|_| Client::new());
        self
    }

    /// List available models on an Ollama server (cached for 15 minutes)
    pub async fn list_models(url: &str) -> Result<Vec<String>> {
        // Check cache first
        {
            let cache = MODEL_LIST_CACHE.lock().unwrap();
            if let Some(entry) = cache.get(url)
                && !entry.is_expired()
            {
                log::debug!("Returning cached model list for {}", url);
                return Ok(entry.data.clone());
            }
        }
        log::debug!("Fetching fresh model list from {}", url);
        let client = Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(10))
            .build()?;
        let response = client.get(format!("{}/api/tags", url)).send().await?;
        if !response.status().is_success() {
            return Err(anyhow::anyhow!("Failed to list models from {}", url));
        }
        let tags_response: OllamaTagsResponse = response.json().await?;
        let models: Vec<String> = tags_response.models.into_iter().map(|m| m.name).collect();
        // Store in cache
        {
            let mut cache = MODEL_LIST_CACHE.lock().unwrap();
            cache.insert(url.to_string(), CachedEntry::new(models.clone()));
        }
        Ok(models)
    }

    /// Check if a model is available on a server
    pub async fn is_model_available(url: &str, model_name: &str) -> Result<bool> {
        let models = Self::list_models(url).await?;
        Ok(models.iter().any(|m| m == model_name))
    }

    /// Clear the model list cache for a specific URL or all URLs
    #[allow(dead_code)]
    pub fn clear_model_cache(url: Option<&str>) {
        let mut cache = MODEL_LIST_CACHE.lock().unwrap();
        if let Some(url) = url {
            cache.remove(url);
            log::debug!("Cleared model list cache for {}", url);
        } else {
            cache.clear();
            log::debug!("Cleared all model list cache entries");
        }
    }

    /// Clear the model capabilities cache for a specific URL or all URLs
    #[allow(dead_code)]
    pub fn clear_capabilities_cache(url: Option<&str>) {
        let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
        if let Some(url) = url {
            cache.remove(url);
            log::debug!("Cleared model capabilities cache for {}", url);
        } else {
            cache.clear();
            log::debug!("Cleared all model capabilities cache entries");
        }
    }

    /// Check a model's capabilities (vision and tool calling) using the /api/show endpoint
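    ///
    /// The check keys off the `capabilities` array in the response; a vision +
    /// tools model might return an excerpt like (illustrative, other fields
    /// omitted): `{"capabilities": ["completion", "vision", "tools"]}`.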
    pub async fn check_model_capabilities(
        url: &str,
        model_name: &str,
    ) -> Result<ModelCapabilities> {
        let client = Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(10))
            .build()?;

        #[derive(Serialize)]
        struct ShowRequest {
            model: String,
        }

        let response = client
            .post(format!("{}/api/show", url))
            .json(&ShowRequest {
                model: model_name.to_string(),
            })
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Failed to get model details for {} from {}",
                model_name,
                url
            ));
        }
        let show_response: OllamaShowResponse = response.json().await?;
        // Check if "vision" is in the capabilities array
        let has_vision = show_response.capabilities.iter().any(|cap| cap == "vision");
        // Check if "tools" is in the capabilities array
        let has_tool_calling = show_response.capabilities.iter().any(|cap| cap == "tools");
        Ok(ModelCapabilities {
            name: model_name.to_string(),
            has_vision,
            has_tool_calling,
        })
    }

    /// List all models with their capabilities from a server (cached for 15 minutes)
    pub async fn list_models_with_capabilities(url: &str) -> Result<Vec<ModelCapabilities>> {
        // Check cache first
        {
            let cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
            if let Some(entry) = cache.get(url)
                && !entry.is_expired()
            {
                log::debug!("Returning cached model capabilities for {}", url);
                return Ok(entry.data.clone());
            }
        }
        log::debug!("Fetching fresh model capabilities from {}", url);
        let models = Self::list_models(url).await?;
        let mut capabilities = Vec::new();
        for model_name in models {
            match Self::check_model_capabilities(url, &model_name).await {
                Ok(cap) => capabilities.push(cap),
                Err(e) => {
                    log::warn!("Failed to get capabilities for model {}: {}", model_name, e);
                    // Fallback: assume no vision/tools if we can't check
                    capabilities.push(ModelCapabilities {
                        name: model_name,
                        has_vision: false,
                        has_tool_calling: false,
                    });
                }
            }
        }
        // Store in cache
        {
            let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
            cache.insert(url.to_string(), CachedEntry::new(capabilities.clone()));
        }
        Ok(capabilities)
    }

    /// Extract final answer from thinking model output.
    /// Handles <think>...</think> tags and takes everything after them.
    fn extract_final_answer(&self, response: &str) -> String {
        let response = response.trim();
        // Look for </think> tag and take everything after it
        if let Some(pos) = response.find("</think>") {
            let answer = response[pos + 8..].trim();
            if !answer.is_empty() {
                return answer.to_string();
            }
        }
        // Fallback: return the whole response trimmed
        response.to_string()
    }

    async fn try_generate(
        &self,
        url: &str,
        model: &str,
        prompt: &str,
        system: Option<&str>,
        images: Option<Vec<String>>,
    ) -> Result<String> {
        let request = OllamaRequest {
            model: model.to_string(),
            prompt: prompt.to_string(),
            stream: false,
            system: system.map(|s| s.to_string()),
            options: self.build_options(),
            images,
        };
        let response = self
            .client
            .post(format!("{}/api/generate", url))
            .json(&request)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "Ollama request failed: {} - {}",
                status,
                error_body
            ));
        }
        let result: OllamaResponse = response.json().await?;
        Ok(result.response)
    }

    pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
        self.generate_with_images(prompt, system, None).await
    }

    pub async fn generate_with_images(
        &self,
        prompt: &str,
        system: Option<&str>,
        images: Option<Vec<String>>,
    ) -> Result<String> {
        log::debug!("=== Ollama Request ===");
        log::debug!("Primary model: {}", self.primary_model);
        if let Some(sys) = system {
            log::debug!("System: {}", sys);
        }
        log::debug!("Prompt:\n{}", prompt);
        if let Some(ref imgs) = images {
            log::debug!("Images: {} image(s) included", imgs.len());
        }
        log::debug!("=====================");
        // Try primary server first with primary model
        log::info!(
            "Attempting to generate with primary server: {} (model: {})",
            self.primary_url,
            self.primary_model
        );
        let primary_result = self
            .try_generate(
                &self.primary_url,
                &self.primary_model,
                prompt,
                system,
                images.clone(),
            )
            .await;
        let raw_response = match primary_result {
            Ok(response) => {
                log::info!("Successfully generated response from primary server");
                response
            }
            Err(e) => {
                log::warn!("Primary server failed: {}", e);
                // Try fallback server if available
                if let Some(fallback_url) = &self.fallback_url {
                    // Use fallback model if specified, otherwise use primary model
                    let fallback_model =
                        self.fallback_model.as_ref().unwrap_or(&self.primary_model);
                    log::info!(
                        "Attempting to generate with fallback server: {} (model: {})",
                        fallback_url,
                        fallback_model
                    );
                    match self
                        .try_generate(fallback_url, fallback_model, prompt, system, images.clone())
                        .await
                    {
                        Ok(response) => {
                            log::info!("Successfully generated response from fallback server");
                            response
                        }
                        Err(fallback_e) => {
                            log::error!("Fallback server also failed: {}", fallback_e);
                            return Err(anyhow::anyhow!(
                                "Both primary and fallback servers failed. Primary: {}, Fallback: {}",
                                e,
                                fallback_e
                            ));
                        }
                    }
                } else {
                    log::error!("No fallback server configured");
                    return Err(e);
                }
            }
        };
        log::debug!("=== Ollama Response ===");
        log::debug!("Raw response: {}", raw_response.trim());
        log::debug!("=======================");
        // Extract final answer from thinking model output
        let cleaned = self.extract_final_answer(&raw_response);
        log::debug!("=== Cleaned Response ===");
        log::debug!("Final answer: {}", cleaned);
        log::debug!("========================");
        Ok(cleaned)
    }

    /// Generate a title for a single photo based on its generated summary
    pub async fn generate_photo_title(
        &self,
        summary: &str,
        custom_system: Option<&str>,
    ) -> Result<String> {
        let prompt = format!(
            r#"Create a short title (maximum 8 words) for the following journal entry:
{}
Capture the key moment or theme. Return ONLY the title, nothing else."#,
            summary
        );
        let system = custom_system.unwrap_or("You are my long term memory assistant. Use only the information provided. Do not invent details.");
        let title = self
            .generate_with_images(&prompt, Some(system), None)
            .await?;
        Ok(title.trim().trim_matches('"').to_string())
    }

    /// Generate a summary for a single photo based on its context
    pub async fn generate_photo_summary(
        &self,
        date: NaiveDate,
        location: Option<&str>,
        contact: Option<&str>,
        sms_summary: Option<&str>,
        custom_system: Option<&str>,
        image_base64: Option<String>,
    ) -> Result<String> {
        let location_str = location.unwrap_or("Unknown");
        let sms_str = sms_summary.unwrap_or("No messages");
        let prompt = if image_base64.is_some() {
            if let Some(contact_name) = contact {
                format!(
                    r#"Write a 1-3 paragraph description of this moment by analyzing the image and the available context:
Date: {}
Location: {}
Person/Contact: {}
Messages: {}
Analyze the image and use specific details from both the visual content and the context above. The photo is from a folder for {}, so they are likely in or related to this photo. Mention people's names (especially {}), places, or activities if they appear in either the image or the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#,
                    date.format("%B %d, %Y"),
                    location_str,
                    contact_name,
                    sms_str,
                    contact_name,
                    contact_name
                )
            } else {
                format!(
                    r#"Write a 1-3 paragraph description of this moment by analyzing the image and the available context:
Date: {}
Location: {}
Messages: {}
Analyze the image and use specific details from both the visual content and the context above. Mention people's names, places, or activities if they appear in either the image or the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#,
                    date.format("%B %d, %Y"),
                    location_str,
                    sms_str
                )
            }
        } else if let Some(contact_name) = contact {
            format!(
                r#"Write a 1-3 paragraph description of this moment based on the available information:
Date: {}
Location: {}
Person/Contact: {}
Messages: {}
Use only the specific details provided above. The photo is from a folder for {}, so they are likely related to this moment. Mention people's names (especially {}), places, or activities if they appear in the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#,
                date.format("%B %d, %Y"),
                location_str,
                contact_name,
                sms_str,
                contact_name,
                contact_name
            )
        } else {
            format!(
                r#"Write a 1-3 paragraph description of this moment based on the available information:
Date: {}
Location: {}
Messages: {}
Use only the specific details provided above. Mention people's names, places, or activities if they appear in the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#,
                date.format("%B %d, %Y"),
                location_str,
                sms_str
            )
        };
        let system = custom_system.unwrap_or("You are a memory refreshing assistant who is able to provide insights through analyzing past conversations. Use only the information provided. Do not invent details.");
        let images = image_base64.map(|img| vec![img]);
        self.generate_with_images(&prompt, Some(system), images)
            .await
    }

    /// Generate a brief visual description of a photo for use in RAG query enrichment.
    /// Returns 1-2 sentences describing people, location, and activity visible in the image.
    /// Only called when the model has vision capabilities.
    pub async fn generate_photo_description(&self, image_base64: &str) -> Result<String> {
        let prompt = "Briefly describe what you see in this image in 1-2 sentences. \
            Focus on the people, location, and activity.";
        let system = "You are a scene description assistant. Be concise and factual.";
        let images = vec![image_base64.to_string()];
        let description = self
            .generate_with_images(prompt, Some(system), Some(images))
            .await?;
        Ok(description.trim().to_string())
    }

    /// Send a chat request with tool definitions to /api/chat.
    /// Returns the assistant's response message (may contain tool_calls or final content).
    /// Uses the same primary/fallback URL routing as the other generation methods.
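    ///
    /// A sketch of the agentic loop this enables; `execute_tool` is a
    /// hypothetical dispatcher, not part of this module:
    /// ```ignore
    /// let (mut reply, _, _) = client.chat_with_tools(messages.clone(), tools.clone()).await?;
    /// while let Some(calls) = reply.tool_calls.clone() {
    ///     messages.push(reply.clone()); // keep the assistant turn in the history
    ///     for call in calls {
    ///         messages.push(ChatMessage {
    ///             role: "tool".to_string(),
    ///             content: execute_tool(&call).await?,
    ///             tool_calls: None,
    ///             images: None,
    ///         });
    ///     }
    ///     (reply, _, _) = client.chat_with_tools(messages.clone(), tools.clone()).await?;
    /// }
    /// ```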
    pub async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<(ChatMessage, Option<i32>, Option<i32>)> {
        // Try primary server first
        log::info!(
            "Attempting chat_with_tools with primary server: {} (model: {})",
            self.primary_url,
            self.primary_model
        );
        let primary_result = self
            .try_chat_with_tools(&self.primary_url, messages.clone(), tools.clone())
            .await;
        match primary_result {
            Ok(result) => {
                log::info!("Successfully got chat_with_tools response from primary server");
                Ok(result)
            }
            Err(e) => {
                log::warn!("Primary server chat_with_tools failed: {}", e);
                // Try fallback server if available
                if let Some(fallback_url) = &self.fallback_url {
                    let fallback_model =
                        self.fallback_model.as_ref().unwrap_or(&self.primary_model);
                    log::info!(
                        "Attempting chat_with_tools with fallback server: {} (model: {})",
                        fallback_url,
                        fallback_model
                    );
                    match self
                        .try_chat_with_tools(fallback_url, messages, tools)
                        .await
                    {
                        Ok(result) => {
                            log::info!(
                                "Successfully got chat_with_tools response from fallback server"
                            );
                            Ok(result)
                        }
                        Err(fallback_e) => {
                            log::error!(
                                "Fallback server chat_with_tools also failed: {}",
                                fallback_e
                            );
                            Err(anyhow::anyhow!(
                                "Both primary and fallback servers failed. Primary: {}, Fallback: {}",
                                e,
                                fallback_e
                            ))
                        }
                    }
                } else {
                    log::error!("No fallback server configured");
                    Err(e)
                }
            }
        }
    }

    /// Streaming variant of `chat_with_tools`. Tries primary, then falls
    /// back if the initial connection fails; once the stream has begun
    /// emitting, mid-stream errors propagate to the caller. Emits
    /// `TextDelta` events as content tokens arrive and a single terminal
    /// `Done` event when the model marks the turn complete (tool_calls, if
    /// any, live on the final message).
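    ///
    /// A consumption sketch (assumes `LlmStreamEvent` has only the variants
    /// used in this file):
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut events = client.chat_with_tools_stream(messages, tools).await?;
    /// while let Some(event) = events.next().await {
    ///     match event? {
    ///         LlmStreamEvent::TextDelta(delta) => print!("{}", delta),
    ///         LlmStreamEvent::Done { message, .. } => {
    ///             // `message.tool_calls`, if present, drives the next agentic turn.
    ///         }
    ///     }
    /// }
    /// ```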
    pub async fn chat_with_tools_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
        // Attempt primary. If it can't be opened at all, try fallback.
        match self
            .try_chat_with_tools_stream(&self.primary_url, messages.clone(), tools.clone())
            .await
        {
            Ok(s) => Ok(s),
            Err(e) => {
                if let Some(fallback_url) = self.fallback_url.clone() {
                    log::warn!(
                        "Streaming chat primary failed ({}); trying fallback {}",
                        e,
                        fallback_url
                    );
                    self.try_chat_with_tools_stream(&fallback_url, messages, tools)
                        .await
                } else {
                    Err(e)
                }
            }
        }
    }

    async fn try_chat_with_tools_stream(
        &self,
        base_url: &str,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
        let url = format!("{}/api/chat", base_url);
        let model = if base_url == self.primary_url {
            &self.primary_model
        } else {
            self.fallback_model
                .as_deref()
                .unwrap_or(&self.primary_model)
        };
        let options = self.build_options();
        let request_body = OllamaChatRequest {
            model,
            messages: &messages,
            stream: true,
            tools,
            options,
        };
        let response = self
            .client
            .post(&url)
            .json(&request_body)
            .send()
            .await
            .with_context(|| format!("Failed to connect to Ollama at {}", url))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Ollama stream request failed with status {}: {}",
                status,
                body
            );
        }
        // Ollama streams NDJSON: each line is a full `OllamaStreamChunk`.
        // We buffer partial lines across chunks from the byte stream.
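        // Illustrative line shapes (assumed from the parsing below, not a
        // verbatim server transcript):
        //   {"message":{"role":"assistant","content":"Hel"},"done":false}
        //   {"message":{"content":"lo"},"done":false}
        //   {"message":{"content":"","tool_calls":[...]},"done":true,"eval_count":42,...}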
        let byte_stream = response.bytes_stream();
        let stream = async_stream::stream! {
            let mut buf: Vec<u8> = Vec::new();
            let mut accumulated = String::new();
            let mut tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>> = None;
            let mut role = "assistant".to_string();
            let mut prompt_eval_count: Option<i32> = None;
            let mut eval_count: Option<i32> = None;
            let mut prompt_eval_duration: Option<u64> = None;
            let mut eval_duration: Option<u64> = None;
            let mut done_seen = false;
            let mut byte_stream = byte_stream;
            while let Some(chunk) = byte_stream.next().await {
                let chunk = match chunk {
                    Ok(b) => b,
                    Err(e) => {
                        yield Err(anyhow::anyhow!("stream read failed: {}", e));
                        return;
                    }
                };
                buf.extend_from_slice(&chunk);
                // Drain complete lines; hold any trailing partial.
                while let Some(nl) = buf.iter().position(|b| *b == b'\n') {
                    let line = buf.drain(..=nl).collect::<Vec<_>>();
                    let line_str = match std::str::from_utf8(&line) {
                        Ok(s) => s.trim(),
                        Err(_) => continue,
                    };
                    if line_str.is_empty() {
                        continue;
                    }
                    match serde_json::from_str::<OllamaStreamChunk>(line_str) {
                        Ok(chunk) => {
                            // Accumulate content delta.
                            if !chunk.message.content.is_empty() {
                                accumulated.push_str(&chunk.message.content);
                                yield Ok(LlmStreamEvent::TextDelta(chunk.message.content));
                            }
                            if !chunk.message.role.is_empty() {
                                role = chunk.message.role;
                            }
                            // Ollama only attaches tool_calls on the final chunk.
                            if let Some(tcs) = chunk.message.tool_calls
                                && !tcs.is_empty()
                            {
                                tool_calls = Some(tcs);
                            }
                            if chunk.done {
                                prompt_eval_count = chunk.prompt_eval_count;
                                eval_count = chunk.eval_count;
                                prompt_eval_duration = chunk.prompt_eval_duration;
                                eval_duration = chunk.eval_duration;
                                done_seen = true;
                                break;
                            }
                        }
                        Err(e) => {
                            log::warn!("malformed Ollama stream line: {} ({})", line_str, e);
                        }
                    }
                }
                if done_seen {
                    break;
                }
            }
            // Emit the terminal Done event with the assembled message.
            log_chat_metrics(
                prompt_eval_count,
                prompt_eval_duration,
                eval_count,
                eval_duration,
            );
            let message = ChatMessage {
                role,
                content: accumulated,
                tool_calls,
                images: None,
            };
            yield Ok(LlmStreamEvent::Done {
                message,
                prompt_eval_count,
                eval_count,
            });
        };
        Ok(Box::pin(stream))
    }

    async fn try_chat_with_tools(
        &self,
        base_url: &str,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<(ChatMessage, Option<i32>, Option<i32>)> {
        let url = format!("{}/api/chat", base_url);
        let model = if base_url == self.primary_url {
            &self.primary_model
        } else {
            self.fallback_model
                .as_deref()
                .unwrap_or(&self.primary_model)
        };
        let options = self.build_options();
        let request_body = OllamaChatRequest {
            model,
            messages: &messages,
            stream: false,
            tools,
            options,
        };
        let request_json = serde_json::to_string(&request_body)
            .unwrap_or_else(|e| format!("<serialization error: {}>", e));
        log::debug!("chat_with_tools request body: {}", request_json);
        let response = self
            .client
            .post(&url)
            .json(&request_body)
            .send()
            .await
            .with_context(|| format!("Failed to connect to Ollama at {}", url))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            log::error!(
                "chat_with_tools request body that caused {}: {}",
                status,
                request_json
            );
            anyhow::bail!(
                "Ollama chat request failed with status {}: {}",
                status,
                body
            );
        }
        let chat_response: OllamaChatResponse = response
            .json()
            .await
            .with_context(|| "Failed to parse Ollama chat response")?;
        // Log performance counters returned by Ollama. Durations are
        // reported in nanoseconds; we render ms + tokens/sec for skim-ability
        // in the server log. Missing fields are left off the line rather
        // than printed as `None`.
        log_chat_metrics(
            chat_response.prompt_eval_count,
            chat_response.prompt_eval_duration,
            chat_response.eval_count,
            chat_response.eval_duration,
        );
        Ok((
            chat_response.message,
            chat_response.prompt_eval_count,
            chat_response.eval_count,
        ))
    }

    /// Generate an embedding vector for text using nomic-embed-text:v1.5.
    /// Returns a 768-dimensional vector as Vec<f32>.
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.generate_embeddings(&[text]).await?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No embedding returned"))
    }

    /// Generate embeddings for multiple texts in a single API call (batch mode).
    /// Returns a vector of 768-dimensional vectors.
    /// This is much more efficient than calling generate_embedding multiple times.
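    ///
    /// A usage sketch (assumes a reachable server with the embedding model pulled):
    /// ```ignore
    /// let vectors = client.generate_embeddings(&["first text", "second text"]).await?;
    /// assert_eq!(vectors.len(), 2);
    /// assert_eq!(vectors[0].len(), 768); // nomic-embed-text:v1.5 output width
    /// ```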
    pub async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let embedding_model = "nomic-embed-text:v1.5";
        log::debug!("=== Ollama Batch Embedding Request ===");
        log::debug!("Model: {}", embedding_model);
        log::debug!("Batch size: {} texts", texts.len());
        log::debug!("======================================");
        // Try primary server first
        log::debug!(
            "Attempting to generate {} embeddings with primary server: {} (model: {})",
            texts.len(),
            self.primary_url,
            embedding_model
        );
        let primary_result = self
            .try_generate_embeddings(&self.primary_url, embedding_model, texts)
            .await;
        let embeddings = match primary_result {
            Ok(embeddings) => {
                log::debug!(
                    "Successfully generated {} embeddings from primary server",
                    embeddings.len()
                );
                embeddings
            }
            Err(e) => {
                log::warn!("Primary server batch embedding failed: {}", e);
                // Try fallback server if available
                if let Some(fallback_url) = &self.fallback_url {
                    log::info!(
                        "Attempting to generate {} embeddings with fallback server: {} (model: {})",
                        texts.len(),
                        fallback_url,
                        embedding_model
                    );
                    match self
                        .try_generate_embeddings(fallback_url, embedding_model, texts)
                        .await
                    {
                        Ok(embeddings) => {
                            log::info!(
                                "Successfully generated {} embeddings from fallback server",
                                embeddings.len()
                            );
                            embeddings
                        }
                        Err(fallback_e) => {
                            log::error!(
                                "Fallback server batch embedding also failed: {}",
                                fallback_e
                            );
                            return Err(anyhow::anyhow!(
                                "Both primary and fallback servers failed. Primary: {}, Fallback: {}",
                                e,
                                fallback_e
                            ));
                        }
                    }
                } else {
                    log::error!("No fallback server configured");
                    return Err(e);
                }
            }
        };
        // Validate embedding dimensions (should be 768 for nomic-embed-text:v1.5)
        for (i, embedding) in embeddings.iter().enumerate() {
            if embedding.len() != 768 {
                log::warn!(
                    "Unexpected embedding dimensions for item {}: {} (expected 768)",
                    i,
                    embedding.len()
                );
            }
        }
        Ok(embeddings)
    }

    /// Internal helper to try generating embeddings for multiple texts from a specific server
    async fn try_generate_embeddings(
        &self,
        url: &str,
        model: &str,
        texts: &[&str],
    ) -> Result<Vec<Vec<f32>>> {
        let request = OllamaBatchEmbedRequest {
            model: model.to_string(),
            input: texts.iter().map(|s| s.to_string()).collect(),
        };
        let response = self
            .client
            .post(format!("{}/api/embed", url))
            .json(&request)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "Ollama batch embedding request failed: {} - {}",
                status,
                error_body
            ));
        }
        let result: OllamaEmbedResponse = response.json().await?;
        Ok(result.embeddings)
    }
}

#[async_trait]
impl LlmClient for OllamaClient {
    async fn generate(
        &self,
        prompt: &str,
        system: Option<&str>,
        images: Option<Vec<String>>,
    ) -> Result<String> {
        self.generate_with_images(prompt, system, images).await
    }

    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<(ChatMessage, Option<i32>, Option<i32>)> {
        OllamaClient::chat_with_tools(self, messages, tools).await
    }

    async fn chat_with_tools_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<Tool>,
    ) -> Result<BoxStream<'static, Result<LlmStreamEvent>>> {
        OllamaClient::chat_with_tools_stream(self, messages, tools).await
    }

    async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        OllamaClient::generate_embeddings(self, texts).await
    }

    async fn describe_image(&self, image_base64: &str) -> Result<String> {
        self.generate_photo_description(image_base64).await
    }

    async fn list_models(&self) -> Result<Vec<ModelCapabilities>> {
        Self::list_models_with_capabilities(&self.primary_url).await
    }

    async fn model_capabilities(&self, model: &str) -> Result<ModelCapabilities> {
        Self::check_model_capabilities(&self.primary_url, model).await
    }

    fn primary_model(&self) -> &str {
        &self.primary_model
    }
}

#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    images: Option<Vec<String>>,
}

#[derive(Serialize)]
struct OllamaOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    num_ctx: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    min_p: Option<f32>,
}

#[derive(Serialize)]
struct OllamaChatRequest<'a> {
    model: &'a str,
    messages: &'a [ChatMessage],
    stream: bool,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaOptions>,
}

#[derive(Deserialize, Debug)]
struct OllamaChatResponse {
    message: ChatMessage,
    #[allow(dead_code)]
    done: bool,
    #[serde(default)]
    #[allow(dead_code)]
    done_reason: String,
    #[serde(default)]
    prompt_eval_count: Option<i32>,
    /// Nanoseconds spent evaluating the prompt (context ingestion).
    #[serde(default)]
    prompt_eval_duration: Option<u64>,
    #[serde(default)]
    eval_count: Option<i32>,
    /// Nanoseconds spent generating the response tokens.
    #[serde(default)]
    eval_duration: Option<u64>,
}

/// One chunk in the NDJSON stream from `/api/chat` with `stream: true`.
/// Early chunks carry content deltas in `message.content`; the final chunk
/// has `done: true`, optional `tool_calls`, and usage counters.
#[derive(Deserialize, Debug)]
struct OllamaStreamChunk {
    #[serde(default)]
    message: OllamaStreamMessage,
    #[serde(default)]
    done: bool,
    #[serde(default)]
    prompt_eval_count: Option<i32>,
    #[serde(default)]
    prompt_eval_duration: Option<u64>,
    #[serde(default)]
    eval_count: Option<i32>,
    #[serde(default)]
    eval_duration: Option<u64>,
}

#[derive(Deserialize, Debug, Default)]
struct OllamaStreamMessage {
    #[serde(default)]
    role: String,
    #[serde(default)]
    content: String,
    #[serde(default)]
    tool_calls: Option<Vec<crate::ai::llm_client::ToolCall>>,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
}

fn log_chat_metrics(
    prompt_eval_count: Option<i32>,
    prompt_eval_duration_ns: Option<u64>,
    eval_count: Option<i32>,
    eval_duration_ns: Option<u64>,
) {
    // Compute tokens/sec when both count and duration are present.
    fn tokens_per_sec(count: Option<i32>, duration_ns: Option<u64>) -> Option<f64> {
        match (count, duration_ns) {
            (Some(c), Some(d)) if c > 0 && d > 0 => Some((c as f64) * 1_000_000_000.0 / (d as f64)),
            _ => None,
        }
    }
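
    // Worked example: eval_count = 256 with eval_duration_ns = 2_000_000_000
    // yields eval_ms = 2000.0 and eval_tps = 128.0, which renders below as
    // "gen=256 tok (2000 ms, 128.0 tok/s)".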
    let prompt_ms = prompt_eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0);
    let eval_ms = eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0);
    let prompt_tps = tokens_per_sec(prompt_eval_count, prompt_eval_duration_ns);
    let eval_tps = tokens_per_sec(eval_count, eval_duration_ns);
    let mut parts: Vec<String> = Vec::new();
    if let Some(c) = prompt_eval_count {
        let mut s = format!("prompt={} tok", c);
        if let Some(ms) = prompt_ms {
            s.push_str(&format!(" ({:.0} ms", ms));
            if let Some(tps) = prompt_tps {
                s.push_str(&format!(", {:.1} tok/s", tps));
            }
            s.push(')');
        }
        parts.push(s);
    }
    if let Some(c) = eval_count {
        let mut s = format!("gen={} tok", c);
        if let Some(ms) = eval_ms {
            s.push_str(&format!(" ({:.0} ms", ms));
            if let Some(tps) = eval_tps {
                s.push_str(&format!(", {:.1} tok/s", tps));
            }
            s.push(')');
        }
        parts.push(s);
    }
    if !parts.is_empty() {
        log::info!("Ollama chat metrics — {}", parts.join(", "));
    }
}

#[derive(Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Deserialize)]
struct OllamaModel {
    name: String,
}

#[derive(Deserialize)]
struct OllamaShowResponse {
    #[serde(default)]
    capabilities: Vec<String>,
}

#[derive(Serialize)]
struct OllamaBatchEmbedRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Deserialize)]
struct OllamaEmbedResponse {
    embeddings: Vec<Vec<f32>>,
}

#[cfg(test)]
mod tests {
    #[test]
    fn generate_photo_description_prompt_is_concise() {
        // Verify the method exists and its prompt is sane by checking the
        // constant we'll use. This is a compile + smoke check; actual LLM
        // calls are integration-tested manually.
        let prompt = "Briefly describe what you see in this image in 1-2 sentences. \
            Focus on the people, location, and activity.";
        assert!(prompt.len() < 200, "Prompt should be concise");
    }
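
    // Additional deterministic smoke tests (added for illustration; they
    // exercise pure logic only and never contact an Ollama server).
    #[test]
    fn extract_final_answer_strips_think_block() {
        let client = super::OllamaClient::new(
            "http://localhost:11434".to_string(),
            None,
            "test-model".to_string(),
            None,
        );
        let raw = "<think>internal reasoning</think>The final answer.";
        assert_eq!(client.extract_final_answer(raw), "The final answer.");
        // Without a closing </think> tag, the whole trimmed response comes back.
        assert_eq!(client.extract_final_answer("  plain  "), "plain");
    }

    #[test]
    fn cached_entry_is_fresh_immediately() {
        let entry = super::CachedEntry::new(vec!["model-a".to_string()]);
        assert!(!entry.is_expired(), "a brand-new cache entry must not be expired");
    }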
}