diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs
index 80dddcb..3728da7 100644
--- a/src/ai/ollama.rs
+++ b/src/ai/ollama.rs
@@ -73,6 +73,17 @@ impl OllamaClient {
         self.num_ctx = num_ctx;
     }
 
+    /// Replace the HTTP client with one using a custom request timeout.
+    /// Useful for slow models where the default 120s may be insufficient.
+    pub fn with_request_timeout(mut self, secs: u64) -> Self {
+        self.client = Client::builder()
+            .connect_timeout(Duration::from_secs(5))
+            .timeout(Duration::from_secs(secs))
+            .build()
+            .unwrap_or_else(|_| Client::new()); // NOTE(review): builder failure silently drops the custom timeout; best-effort fallback
+        self
+    }
+
     /// List available models on an Ollama server (cached for 15 minutes)
     pub async fn list_models(url: &str) -> Result<Vec<String>> {
         // Check cache first
diff --git a/src/bin/populate_knowledge.rs b/src/bin/populate_knowledge.rs
new file mode 100644
index 0000000..432084b
--- /dev/null
+++ b/src/bin/populate_knowledge.rs
@@ -0,0 +1,228 @@
+use std::path::PathBuf;
+use std::sync::{Arc, Mutex};
+
+use clap::Parser;
+use walkdir::WalkDir;
+
+use image_api::ai::{InsightGenerator, OllamaClient, SmsApiClient};
+use image_api::database::{
+    CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, KnowledgeDao, LocationHistoryDao,
+    SearchHistoryDao, SqliteCalendarEventDao, SqliteDailySummaryDao, SqliteExifDao,
+    SqliteInsightDao, SqliteKnowledgeDao, SqliteLocationHistoryDao, SqliteSearchHistoryDao,
+};
+use image_api::file_types::{IMAGE_EXTENSIONS, VIDEO_EXTENSIONS};
+use image_api::tags::{SqliteTagDao, TagDao};
+
+#[derive(Parser, Debug)]
+#[command(name = "populate_knowledge")]
+#[command(
+    about = "Batch populate the knowledge base by running the agentic insight loop over a folder"
+)]
+struct Args {
+    /// Directory to scan. Defaults to BASE_PATH from .env
+    #[arg(long)]
+    path: Option<String>, // String (not PathBuf): main() calls .as_deref().unwrap_or(&base_path)
+
+    /// Ollama model override. Defaults to OLLAMA_PRIMARY_MODEL from .env
+    #[arg(long)]
+    model: Option<String>,
+
+    /// Maximum agentic loop iterations per file
+    #[arg(long, default_value_t = 12)]
+    max_iterations: usize,
+
+    /// HTTP request timeout in seconds. Increase for large/slow models
+    #[arg(long, default_value_t = 120)]
+    timeout_secs: u64,
+
+    /// Context window size (num_ctx) passed to the model
+    #[arg(long)]
+    num_ctx: Option<u32>, // TODO(review): confirm integer width against OllamaClient::set_num_ctx
+
+    /// Re-process files that already have an insight stored
+    #[arg(long, default_value_t = false)]
+    reprocess: bool,
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    env_logger::init();
+    dotenv::dotenv().ok();
+
+    let args = Args::parse();
+
+    let base_path = dotenv::var("BASE_PATH")?;
+    let scan_path = args.path.as_deref().unwrap_or(&base_path).to_string();
+
+    // Ollama config from env with CLI overrides
+    let primary_url = std::env::var("OLLAMA_PRIMARY_URL")
+        .or_else(|_| std::env::var("OLLAMA_URL"))
+        .unwrap_or_else(|_| "http://localhost:11434".to_string());
+    let fallback_url = std::env::var("OLLAMA_FALLBACK_URL").ok();
+    let primary_model = args
+        .model
+        .clone()
+        .or_else(|| std::env::var("OLLAMA_PRIMARY_MODEL").ok())
+        .or_else(|| std::env::var("OLLAMA_MODEL").ok())
+        .unwrap_or_else(|| "nemotron-3-nano:30b".to_string());
+    let fallback_model = std::env::var("OLLAMA_FALLBACK_MODEL").ok();
+
+    let mut ollama = OllamaClient::new(
+        primary_url.clone(),
+        fallback_url,
+        primary_model.clone(),
+        fallback_model,
+    )
+    .with_request_timeout(args.timeout_secs);
+
+    if let Some(ctx) = args.num_ctx {
+        ollama.set_num_ctx(Some(ctx));
+    }
+
+    let sms_api_url =
+        std::env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string());
+    let sms_api_token = std::env::var("SMS_API_TOKEN").ok();
+    let sms_client = SmsApiClient::new(sms_api_url, sms_api_token);
+
+    // Wire up all DAOs
+    // TODO(review): trait-object types reconstructed from the `use image_api::database` list;
+    // confirm whether InsightGenerator::new requires `+ Send` bounds on these boxes.
+    let insight_dao: Arc<Mutex<Box<dyn InsightDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteInsightDao::new())));
+    let exif_dao: Arc<Mutex<Box<dyn ExifDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteExifDao::new())));
+    let daily_summary_dao: Arc<Mutex<Box<dyn DailySummaryDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteDailySummaryDao::new())));
+    let calendar_dao: Arc<Mutex<Box<dyn CalendarEventDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteCalendarEventDao::new())));
+    let location_dao: Arc<Mutex<Box<dyn LocationHistoryDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteLocationHistoryDao::new())));
+    let search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteSearchHistoryDao::new())));
+    let tag_dao: Arc<Mutex<Box<dyn TagDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteTagDao::default())));
+    let knowledge_dao: Arc<Mutex<Box<dyn KnowledgeDao>>> =
+        Arc::new(Mutex::new(Box::new(SqliteKnowledgeDao::new())));
+
+    let generator = InsightGenerator::new(
+        ollama,
+        sms_client,
+        insight_dao.clone(),
+        exif_dao,
+        daily_summary_dao,
+        calendar_dao,
+        location_dao,
+        search_dao,
+        tag_dao,
+        knowledge_dao,
+        base_path.clone(),
+    );
+
+    println!("Knowledge Base Population");
+    println!("=========================");
+    println!("Scan path: {}", scan_path);
+    println!("Model: {}", primary_model);
+    println!("Max iterations: {}", args.max_iterations);
+    println!("Timeout: {}s", args.timeout_secs);
+    if let Some(ctx) = args.num_ctx {
+        println!("Num ctx: {}", ctx);
+    }
+    println!(
+        "Mode: {}",
+        if args.reprocess {
+            "reprocess all"
+        } else {
+            "skip existing"
+        }
+    );
+    println!();
+
+    // Collect all image and video files
+    let all_extensions: Vec<&str> = IMAGE_EXTENSIONS
+        .iter()
+        .chain(VIDEO_EXTENSIONS.iter())
+        .copied()
+        .collect();
+
+    println!("Scanning {}...", scan_path);
+    let files: Vec<PathBuf> = WalkDir::new(&scan_path)
+        .into_iter()
+        .filter_map(|e| e.ok())
+        .filter(|e| e.file_type().is_file())
+        .filter(|e| {
+            e.path()
+                .extension()
+                .and_then(|ext| ext.to_str())
+                .map(|ext| all_extensions.contains(&ext.to_lowercase().as_str()))
+                .unwrap_or(false)
+        })
+        .map(|e| e.path().to_path_buf())
+        .collect();
+
+    let total = files.len();
+    println!("Found {} files\n", total);
+
+    if total == 0 {
+        println!("Nothing to process.");
+        return Ok(());
+    }
+
+    let cx = opentelemetry::Context::new();
+    let mut processed = 0usize;
+    let mut skipped = 0usize;
+    let mut errors = 0usize;
+
+    for (i, path) in files.iter().enumerate() {
+        let relative = match path.strip_prefix(&base_path) {
+            Ok(p) => p.to_string_lossy().replace('\\', "/"),
+            Err(_) => path.to_string_lossy().replace('\\', "/"),
+        };
+
+        let prefix = format!("[{}/{}]", i + 1, total);
+
+        // Check for existing insight unless --reprocess
+        if !args.reprocess {
+            let has_insight = insight_dao
+                .lock()
+                .unwrap()
+                .get_insight(&cx, &relative)
+                .unwrap_or(None)
+                .is_some();
+
+            if has_insight {
+                println!("{} skip {}", prefix, relative);
+                skipped += 1;
+                continue;
+            }
+        }
+
+        println!("{} start {}", prefix, relative);
+
+        match generator
+            .generate_agentic_insight_for_photo(
+                &relative,
+                args.model.clone(),
+                None,
+                args.num_ctx,
+                args.max_iterations,
+            )
+            .await
+        {
+            Ok(_) => {
+                println!("{} done {}", prefix, relative);
+                processed += 1;
+            }
+            Err(e) => {
+                eprintln!("{} error {} — {:?}", prefix, relative, e);
+                errors += 1;
+            }
+        }
+    }
+
+    println!();
+    println!("=========================");
+    println!("Complete");
+    println!("  Processed: {}", processed);
+    println!("  Skipped:   {}", skipped);
+    println!("  Errors:    {}", errors);
+
+    Ok(())
+}