Add model capability caching and clear functions
This commit is contained in:
106
src/ai/ollama.rs
106
src/ai/ollama.rs
@@ -2,7 +2,41 @@ use anyhow::Result;
|
||||
use chrono::NaiveDate;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// Cache duration: 15 minutes
|
||||
const CACHE_DURATION_SECS: u64 = 15 * 60;
|
||||
|
||||
// Cached entry with timestamp
|
||||
#[derive(Clone)]
|
||||
struct CachedEntry<T> {
|
||||
data: T,
|
||||
cached_at: Instant,
|
||||
}
|
||||
|
||||
impl<T> CachedEntry<T> {
|
||||
fn new(data: T) -> Self {
|
||||
Self {
|
||||
data,
|
||||
cached_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self) -> bool {
|
||||
self.cached_at.elapsed().as_secs() > CACHE_DURATION_SECS
|
||||
}
|
||||
}
|
||||
|
||||
// Global cache for model lists and capabilities
|
||||
lazy_static::lazy_static! {
|
||||
static ref MODEL_LIST_CACHE: Arc<Mutex<HashMap<String, CachedEntry<Vec<String>>>>> =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
static ref MODEL_CAPABILITIES_CACHE: Arc<Mutex<HashMap<String, CachedEntry<Vec<ModelCapabilities>>>>> =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OllamaClient {
|
||||
@@ -39,8 +73,21 @@ impl OllamaClient {
|
||||
self.num_ctx = num_ctx;
|
||||
}
|
||||
|
||||
/// List available models on an Ollama server
|
||||
/// List available models on an Ollama server (cached for 15 minutes)
|
||||
pub async fn list_models(url: &str) -> Result<Vec<String>> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = MODEL_LIST_CACHE.lock().unwrap();
|
||||
if let Some(entry) = cache.get(url) {
|
||||
if !entry.is_expired() {
|
||||
log::debug!("Returning cached model list for {}", url);
|
||||
return Ok(entry.data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("Fetching fresh model list from {}", url);
|
||||
|
||||
let client = Client::builder()
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.timeout(Duration::from_secs(10))
|
||||
@@ -53,7 +100,15 @@ impl OllamaClient {
|
||||
}
|
||||
|
||||
let tags_response: OllamaTagsResponse = response.json().await?;
|
||||
Ok(tags_response.models.into_iter().map(|m| m.name).collect())
|
||||
let models: Vec<String> = tags_response.models.into_iter().map(|m| m.name).collect();
|
||||
|
||||
// Store in cache
|
||||
{
|
||||
let mut cache = MODEL_LIST_CACHE.lock().unwrap();
|
||||
cache.insert(url.to_string(), CachedEntry::new(models.clone()));
|
||||
}
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
/// Check if a model is available on a server
|
||||
@@ -62,6 +117,30 @@ impl OllamaClient {
|
||||
Ok(models.iter().any(|m| m == model_name))
|
||||
}
|
||||
|
||||
/// Clear the model list cache for a specific URL or all URLs
|
||||
pub fn clear_model_cache(url: Option<&str>) {
|
||||
let mut cache = MODEL_LIST_CACHE.lock().unwrap();
|
||||
if let Some(url) = url {
|
||||
cache.remove(url);
|
||||
log::debug!("Cleared model list cache for {}", url);
|
||||
} else {
|
||||
cache.clear();
|
||||
log::debug!("Cleared all model list cache entries");
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the model capabilities cache for a specific URL or all URLs
|
||||
pub fn clear_capabilities_cache(url: Option<&str>) {
|
||||
let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
|
||||
if let Some(url) = url {
|
||||
cache.remove(url);
|
||||
log::debug!("Cleared model capabilities cache for {}", url);
|
||||
} else {
|
||||
cache.clear();
|
||||
log::debug!("Cleared all model capabilities cache entries");
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a model has vision capabilities using the /api/show endpoint
|
||||
pub async fn check_model_capabilities(
|
||||
url: &str,
|
||||
@@ -104,8 +183,21 @@ impl OllamaClient {
|
||||
})
|
||||
}
|
||||
|
||||
/// List all models with their capabilities from a server
|
||||
/// List all models with their capabilities from a server (cached for 15 minutes)
|
||||
pub async fn list_models_with_capabilities(url: &str) -> Result<Vec<ModelCapabilities>> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
|
||||
if let Some(entry) = cache.get(url) {
|
||||
if !entry.is_expired() {
|
||||
log::debug!("Returning cached model capabilities for {}", url);
|
||||
return Ok(entry.data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("Fetching fresh model capabilities from {}", url);
|
||||
|
||||
let models = Self::list_models(url).await?;
|
||||
let mut capabilities = Vec::new();
|
||||
|
||||
@@ -123,6 +215,12 @@ impl OllamaClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
{
|
||||
let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap();
|
||||
cache.insert(url.to_string(), CachedEntry::new(capabilities.clone()));
|
||||
}
|
||||
|
||||
Ok(capabilities)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user