From b7582e69a0edee300904b46bedeedb7857bd3e1c Mon Sep 17 00:00:00 2001
From: Cameron
Date: Wed, 14 Jan 2026 13:12:09 -0500
Subject: [PATCH] Add model capability caching and clear functions

Cache the results of OllamaClient::list_models and
OllamaClient::list_models_with_capabilities per server URL for 15
minutes, and add clear_model_cache / clear_capabilities_cache to drop
the cached entries for one URL or for all URLs.
---
Review notes (this section is ignored by `git am`):
- Line breaks and generic parameters were reconstructed from a flattened
  copy of this patch. The element type name `ModelCapabilities` and the
  indentation of unchanged context lines are assumptions; confirm them
  against src/ai/ollama.rs before applying.
- The author's e-mail address was missing from the flattened copy.
- Changed relative to the flattened copy: `is_expired` now compares
  `Duration`s instead of truncated whole seconds, and the redundant `Arc`
  around the `lazy_static` mutexes was removed. The blob ids on the
  `index` line were not regenerated.

 src/ai/ollama.rs | 106 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 102 insertions(+), 4 deletions(-)

diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs
index a9c9d35..16b4bd3 100644
--- a/src/ai/ollama.rs
+++ b/src/ai/ollama.rs
@@ -2,7 +2,41 @@ use anyhow::Result;
 use chrono::NaiveDate;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
-use std::time::Duration;
+use std::collections::HashMap;
+use std::sync::Mutex;
+use std::time::{Duration, Instant};
+
+// How long a cached entry stays valid: 15 minutes
+const CACHE_DURATION_SECS: u64 = 15 * 60;
+
+// A cached value together with the instant at which it was stored
+#[derive(Clone)]
+struct CachedEntry<T> {
+    data: T,
+    cached_at: Instant,
+}
+
+impl<T> CachedEntry<T> {
+    fn new(data: T) -> Self {
+        Self {
+            data,
+            cached_at: Instant::now(),
+        }
+    }
+
+    fn is_expired(&self) -> bool {
+        self.cached_at.elapsed() > Duration::from_secs(CACHE_DURATION_SECS)
+    }
+}
+
+// Process-wide caches for model lists and capabilities, keyed by server URL
+lazy_static::lazy_static! {
+    static ref MODEL_LIST_CACHE: Mutex<HashMap<String, CachedEntry<Vec<String>>>> =
+        Mutex::new(HashMap::new());
+
+    static ref MODEL_CAPABILITIES_CACHE: Mutex<HashMap<String, CachedEntry<Vec<ModelCapabilities>>>> =
+        Mutex::new(HashMap::new());
+}
 
 #[derive(Clone)]
 pub struct OllamaClient {
@@ -39,8 +73,21 @@ impl OllamaClient {
         self.num_ctx = num_ctx;
     }
 
-    /// List available models on an Ollama server
+    /// List available models on an Ollama server (cached for 15 minutes)
     pub async fn list_models(url: &str) -> Result<Vec<String>> {
+        // Check cache first; the block drops the lock guard before any `.await`
+        {
+            let cache = MODEL_LIST_CACHE.lock().unwrap();
+            if let Some(entry) = cache.get(url) {
+                if !entry.is_expired() {
+                    log::debug!("Returning cached model list for {}", url);
+                    return Ok(entry.data.clone());
+                }
+            }
+        }
+
+        log::debug!("Fetching fresh model list from {}", url);
+
         let client = Client::builder()
             .connect_timeout(Duration::from_secs(5))
             .timeout(Duration::from_secs(10))
@@ -53,7 +100,15 @@ impl OllamaClient {
         }
 
         let tags_response: OllamaTagsResponse = response.json().await?;
-        Ok(tags_response.models.into_iter().map(|m| m.name).collect())
+        let models: Vec<String> = tags_response.models.into_iter().map(|m| m.name).collect();
+
+        // Store in cache (replaces any expired entry for this URL)
+        {
+            let mut cache = MODEL_LIST_CACHE.lock().unwrap();
+            cache.insert(url.to_string(), CachedEntry::new(models.clone()));
+        }
+
+        Ok(models)
     }
 
     /// Check if a model is available on a server
@@ -62,6 +117,30 @@ impl OllamaClient {
         Ok(models.iter().any(|m| m == model_name))
     }
 
+    /// Clear the model list cache for a specific URL or all URLs
+    pub fn clear_model_cache(url: Option<&str>) {
+        let mut cache = MODEL_LIST_CACHE.lock().unwrap();
+        if let Some(url) = url {
+            cache.remove(url);
+            log::debug!("Cleared model list cache for {}", url);
+        } else {
+            cache.clear();
+            log::debug!("Cleared all model list cache entries");
+        }
+    }
+
+    /// Clear the model capabilities cache for a specific URL or all URLs
+    pub fn clear_capabilities_cache(url: Option<&str>) {
+        let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
+        if let Some(url) = url {
+            cache.remove(url);
+            log::debug!("Cleared model capabilities cache for {}", url);
+        } else {
+            cache.clear();
+            log::debug!("Cleared all model capabilities cache entries");
+        }
+    }
+
     /// Check if a model has vision capabilities using the /api/show endpoint
     pub async fn check_model_capabilities(
         url: &str,
@@ -104,8 +183,21 @@ impl OllamaClient {
         })
     }
 
-    /// List all models with their capabilities from a server
+    /// List all models with their capabilities from a server (cached for 15 minutes)
     pub async fn list_models_with_capabilities(url: &str) -> Result<Vec<ModelCapabilities>> {
+        // Check cache first; the block drops the lock guard before any `.await`
+        {
+            let cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
+            if let Some(entry) = cache.get(url) {
+                if !entry.is_expired() {
+                    log::debug!("Returning cached model capabilities for {}", url);
+                    return Ok(entry.data.clone());
+                }
+            }
+        }
+
+        log::debug!("Fetching fresh model capabilities from {}", url);
+
         let models = Self::list_models(url).await?;
         let mut capabilities = Vec::new();
 
@@ -123,6 +215,12 @@ impl OllamaClient {
             }
         }
 
+        // Store in cache (replaces any expired entry for this URL)
+        {
+            let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
+            cache.insert(url.to_string(), CachedEntry::new(capabilities.clone()));
+        }
+
         Ok(capabilities)
     }
 