- Drop the redundant `use anyhow::Context` inside `has_any_faces` (it is
already imported at module level).
- Drop the dead `.unwrap_or("?")` on bound faces — the vec is filtered to
`is_some()`, so the fallback can never fire.
- Reorder the `face_dao` constructor parameter and initializer to match the
struct declaration (between `tag_dao` and `knowledge_dao`). Update both
`state.rs` call sites and `populate_knowledge.rs` to match.
- Hold the `face_dao` lock once across the library-resolver loop instead of
reacquiring it on every iteration.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
361 lines
12 KiB
Rust
361 lines
12 KiB
Rust
use std::path::{Path, PathBuf};
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use clap::Parser;
|
|
use log::warn;
|
|
use walkdir::WalkDir;
|
|
|
|
use image_api::ai::apollo_client::ApolloClient;
|
|
use image_api::ai::{InsightGenerator, OllamaClient, SmsApiClient};
|
|
use image_api::bin_progress;
|
|
use image_api::database::{
|
|
CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, KnowledgeDao, LocationHistoryDao,
|
|
SearchHistoryDao, SqliteCalendarEventDao, SqliteDailySummaryDao, SqliteExifDao,
|
|
SqliteInsightDao, SqliteKnowledgeDao, SqliteLocationHistoryDao, SqliteSearchHistoryDao,
|
|
connect,
|
|
};
|
|
use image_api::faces::{FaceDao, SqliteFaceDao};
|
|
use image_api::file_types::{IMAGE_EXTENSIONS, VIDEO_EXTENSIONS};
|
|
use image_api::libraries::{self, Library};
|
|
use image_api::tags::{SqliteTagDao, TagDao};
|
|
|
|
#[derive(Parser, Debug)]
|
|
#[command(name = "populate_knowledge")]
|
|
#[command(
|
|
about = "Batch populate the knowledge base by running the agentic insight loop over a folder"
|
|
)]
|
|
struct Args {
|
|
/// Restrict to a single library by numeric id or name. Defaults to all
|
|
/// configured libraries.
|
|
#[arg(long)]
|
|
library: Option<String>,
|
|
|
|
/// Optional subdirectory to scan instead of full library roots. Must be
|
|
/// an absolute path under one of the selected libraries.
|
|
#[arg(long)]
|
|
path: Option<String>,
|
|
|
|
/// Ollama model override. Defaults to OLLAMA_PRIMARY_MODEL from .env
|
|
#[arg(long)]
|
|
model: Option<String>,
|
|
|
|
/// Maximum agentic loop iterations per file
|
|
#[arg(long, default_value_t = 12)]
|
|
max_iterations: usize,
|
|
|
|
/// HTTP request timeout in seconds. Increase for large/slow models
|
|
#[arg(long, default_value_t = 120)]
|
|
timeout_secs: u64,
|
|
|
|
/// Context window size (num_ctx) passed to the model
|
|
#[arg(long)]
|
|
num_ctx: Option<i32>,
|
|
|
|
/// Sampling temperature (e.g. 0.8). Omit for model default
|
|
#[arg(long)]
|
|
temperature: Option<f32>,
|
|
|
|
/// Top-p (nucleus) sampling (e.g. 0.9). Omit for model default
|
|
#[arg(long)]
|
|
top_p: Option<f32>,
|
|
|
|
/// Top-k sampling (e.g. 40). Omit for model default
|
|
#[arg(long)]
|
|
top_k: Option<i32>,
|
|
|
|
/// Min-p sampling (e.g. 0.05). Omit for model default
|
|
#[arg(long)]
|
|
min_p: Option<f32>,
|
|
|
|
/// Re-process files that already have an insight stored
|
|
#[arg(long, default_value_t = false)]
|
|
reprocess: bool,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
env_logger::init();
|
|
dotenv::dotenv().ok();
|
|
|
|
let args = Args::parse();
|
|
|
|
// Load libraries from the DB. Patch the placeholder row from BASE_PATH
|
|
// first when present so a fresh install still gets a valid root.
|
|
let env_base_path = dotenv::var("BASE_PATH").ok();
|
|
let mut seed_conn = connect();
|
|
if let Some(base) = env_base_path.as_deref() {
|
|
libraries::seed_or_patch_from_env(&mut seed_conn, base);
|
|
}
|
|
let all_libs = libraries::load_all(&mut seed_conn);
|
|
drop(seed_conn);
|
|
if all_libs.is_empty() {
|
|
anyhow::bail!("No libraries configured");
|
|
}
|
|
|
|
// Resolve --library to a concrete subset.
|
|
let selected_libs: Vec<Library> = match args.library.as_deref() {
|
|
None => all_libs.clone(),
|
|
Some(raw) => {
|
|
let raw = raw.trim();
|
|
let matched = if let Ok(id) = raw.parse::<i32>() {
|
|
all_libs.iter().find(|l| l.id == id).cloned()
|
|
} else {
|
|
all_libs.iter().find(|l| l.name == raw).cloned()
|
|
};
|
|
match matched {
|
|
Some(lib) => vec![lib],
|
|
None => anyhow::bail!("Unknown library: {}", raw),
|
|
}
|
|
}
|
|
};
|
|
|
|
// Resolve --path to (target_library, walk_root). When provided, the path
|
|
// must live under exactly one of the selected libraries.
|
|
let scan_targets: Vec<(Library, PathBuf)> = match args.path.as_deref() {
|
|
None => selected_libs
|
|
.iter()
|
|
.map(|lib| (lib.clone(), PathBuf::from(&lib.root_path)))
|
|
.collect(),
|
|
Some(raw) => {
|
|
let abs = PathBuf::from(raw);
|
|
let matched = selected_libs
|
|
.iter()
|
|
.find(|lib| abs.starts_with(&lib.root_path))
|
|
.cloned();
|
|
match matched {
|
|
Some(lib) => vec![(lib, abs)],
|
|
None => anyhow::bail!("--path {} is not under any selected library root", raw),
|
|
}
|
|
}
|
|
};
|
|
|
|
// Ollama config from env with CLI overrides.
|
|
let primary_url = std::env::var("OLLAMA_PRIMARY_URL")
|
|
.or_else(|_| std::env::var("OLLAMA_URL"))
|
|
.unwrap_or_else(|_| "http://localhost:11434".to_string());
|
|
let fallback_url = std::env::var("OLLAMA_FALLBACK_URL").ok();
|
|
let primary_model = args
|
|
.model
|
|
.clone()
|
|
.or_else(|| std::env::var("OLLAMA_PRIMARY_MODEL").ok())
|
|
.or_else(|| std::env::var("OLLAMA_MODEL").ok())
|
|
.unwrap_or_else(|| "nemotron-3-nano:30b".to_string());
|
|
let fallback_model = std::env::var("OLLAMA_FALLBACK_MODEL").ok();
|
|
|
|
let mut ollama = OllamaClient::new(
|
|
primary_url.clone(),
|
|
fallback_url,
|
|
primary_model.clone(),
|
|
fallback_model,
|
|
)
|
|
.with_request_timeout(args.timeout_secs);
|
|
|
|
if let Some(ctx) = args.num_ctx {
|
|
ollama.set_num_ctx(Some(ctx));
|
|
}
|
|
if args.temperature.is_some()
|
|
|| args.top_p.is_some()
|
|
|| args.top_k.is_some()
|
|
|| args.min_p.is_some()
|
|
{
|
|
ollama.set_sampling_params(args.temperature, args.top_p, args.top_k, args.min_p);
|
|
}
|
|
|
|
let sms_api_url =
|
|
std::env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string());
|
|
let sms_api_token = std::env::var("SMS_API_TOKEN").ok();
|
|
let sms_client = SmsApiClient::new(sms_api_url, sms_api_token);
|
|
let apollo_client = ApolloClient::new(std::env::var("APOLLO_API_BASE_URL").ok());
|
|
|
|
let insight_dao: Arc<Mutex<Box<dyn InsightDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteInsightDao::new())));
|
|
let exif_dao: Arc<Mutex<Box<dyn ExifDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteExifDao::new())));
|
|
let daily_summary_dao: Arc<Mutex<Box<dyn DailySummaryDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteDailySummaryDao::new())));
|
|
let calendar_dao: Arc<Mutex<Box<dyn CalendarEventDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteCalendarEventDao::new())));
|
|
let location_dao: Arc<Mutex<Box<dyn LocationHistoryDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteLocationHistoryDao::new())));
|
|
let search_dao: Arc<Mutex<Box<dyn SearchHistoryDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteSearchHistoryDao::new())));
|
|
let tag_dao: Arc<Mutex<Box<dyn TagDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteTagDao::default())));
|
|
let knowledge_dao: Arc<Mutex<Box<dyn KnowledgeDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteKnowledgeDao::new())));
|
|
let face_dao: Arc<Mutex<Box<dyn FaceDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteFaceDao::new())));
|
|
|
|
// Pass the full library set so `resolve_full_path` probes every root,
|
|
// even when --library restricts the walk. A rel_path shared across
|
|
// libraries will resolve against the first existing match.
|
|
let generator = InsightGenerator::new(
|
|
ollama,
|
|
None,
|
|
sms_client,
|
|
apollo_client,
|
|
insight_dao.clone(),
|
|
exif_dao,
|
|
daily_summary_dao,
|
|
calendar_dao,
|
|
location_dao,
|
|
search_dao,
|
|
tag_dao,
|
|
face_dao,
|
|
knowledge_dao,
|
|
all_libs.clone(),
|
|
);
|
|
|
|
println!("Knowledge Base Population");
|
|
println!("=========================");
|
|
for (lib, root) in &scan_targets {
|
|
println!("Library: {} (id={})", lib.name, lib.id);
|
|
println!("Scan root: {}", root.display());
|
|
}
|
|
println!("Model: {}", primary_model);
|
|
println!("Max iterations: {}", args.max_iterations);
|
|
println!("Timeout: {}s", args.timeout_secs);
|
|
if let Some(ctx) = args.num_ctx {
|
|
println!("Num ctx: {}", ctx);
|
|
}
|
|
if let Some(t) = args.temperature {
|
|
println!("Temperature: {}", t);
|
|
}
|
|
if let Some(p) = args.top_p {
|
|
println!("Top P: {}", p);
|
|
}
|
|
if let Some(k) = args.top_k {
|
|
println!("Top K: {}", k);
|
|
}
|
|
if let Some(m) = args.min_p {
|
|
println!("Min P: {}", m);
|
|
}
|
|
println!(
|
|
"Mode: {}",
|
|
if args.reprocess {
|
|
"reprocess all"
|
|
} else {
|
|
"skip existing"
|
|
}
|
|
);
|
|
println!();
|
|
|
|
let all_extensions: Vec<&str> = IMAGE_EXTENSIONS
|
|
.iter()
|
|
.chain(VIDEO_EXTENSIONS.iter())
|
|
.copied()
|
|
.collect();
|
|
|
|
// Collect (library, abs_path, rel_path) for every media file across all
|
|
// scan targets so the progress counter spans the full job.
|
|
let mut files: Vec<(Library, PathBuf, String)> = Vec::new();
|
|
for (lib, walk_root) in &scan_targets {
|
|
let lib_root = Path::new(&lib.root_path);
|
|
let scan_pb = bin_progress::spinner(format!("scanning {}", walk_root.display()));
|
|
let count_before = files.len();
|
|
for entry in WalkDir::new(walk_root).into_iter().filter_map(|e| e.ok()) {
|
|
if !entry.file_type().is_file() {
|
|
continue;
|
|
}
|
|
let abs_path = entry.path().to_path_buf();
|
|
let ext_ok = abs_path
|
|
.extension()
|
|
.and_then(|ext| ext.to_str())
|
|
.map(|ext| all_extensions.contains(&ext.to_lowercase().as_str()))
|
|
.unwrap_or(false);
|
|
if !ext_ok {
|
|
continue;
|
|
}
|
|
let rel = match abs_path.strip_prefix(lib_root) {
|
|
Ok(p) => p.to_string_lossy().replace('\\', "/"),
|
|
Err(_) => {
|
|
warn!(
|
|
"{} is not under library root {}; skipping",
|
|
abs_path.display(),
|
|
lib_root.display()
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
files.push((lib.clone(), abs_path, rel));
|
|
scan_pb.inc(1);
|
|
}
|
|
let added = files.len() - count_before;
|
|
scan_pb.finish_with_message(format!(
|
|
"scanned {} ({} media files)",
|
|
walk_root.display(),
|
|
added
|
|
));
|
|
}
|
|
|
|
let total = files.len();
|
|
println!("\nTotal files to consider: {}\n", total);
|
|
|
|
if total == 0 {
|
|
println!("Nothing to process.");
|
|
return Ok(());
|
|
}
|
|
|
|
let cx = opentelemetry::Context::new();
|
|
let mut processed = 0usize;
|
|
let mut skipped = 0usize;
|
|
let mut errors = 0usize;
|
|
|
|
let pb = bin_progress::determinate(total as u64, "");
|
|
|
|
for (lib, _abs_path, relative) in files.iter() {
|
|
pb.set_message(format!("{}: {}", lib.name, relative));
|
|
|
|
if !args.reprocess {
|
|
let has_insight = insight_dao
|
|
.lock()
|
|
.unwrap()
|
|
.get_insight(&cx, relative)
|
|
.unwrap_or(None)
|
|
.is_some();
|
|
|
|
if has_insight {
|
|
skipped += 1;
|
|
pb.inc(1);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
match generator
|
|
.generate_agentic_insight_for_photo(
|
|
relative,
|
|
args.model.clone(),
|
|
None,
|
|
args.num_ctx,
|
|
args.temperature,
|
|
args.top_p,
|
|
args.top_k,
|
|
args.min_p,
|
|
args.max_iterations,
|
|
None,
|
|
Vec::new(),
|
|
Vec::new(),
|
|
)
|
|
.await
|
|
{
|
|
Ok(_) => processed += 1,
|
|
Err(e) => {
|
|
pb.println(format!("error {}: {} — {:?}", lib.name, relative, e));
|
|
errors += 1;
|
|
}
|
|
}
|
|
pb.inc(1);
|
|
}
|
|
|
|
pb.finish_and_clear();
|
|
|
|
println!();
|
|
println!("=========================");
|
|
println!("Complete");
|
|
println!(" Processed: {}", processed);
|
|
println!(" Skipped: {}", skipped);
|
|
println!(" Errors: {}", errors);
|
|
|
|
Ok(())
|
|
}
|