diff --git a/src/ai/mod.rs b/src/ai/mod.rs index d6fda90..cea8103 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -8,6 +8,7 @@ pub mod llm_client; pub mod ollama; pub mod openrouter; pub mod sms_client; +pub mod tag_client; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library #[allow(unused_imports)] diff --git a/src/ai/tag_client.rs b/src/ai/tag_client.rs new file mode 100644 index 0000000..2785aeb --- /dev/null +++ b/src/ai/tag_client.rs @@ -0,0 +1,319 @@ +//! Thin async HTTP client for Apollo's `/api/internal/tags/*` endpoints. +//! +//! Apollo hosts the RAM++ auto-tag inference service alongside insightface. +//! This client is the ImageApi side — shove image bytes through `/auto` and +//! get back a list of `(name, confidence)` predictions over RAM++'s +//! ~4585-tag vocabulary. +//! +//! Mirrors `face_client.rs` shape: optional base URL (None = disabled), one +//! reqwest client with a generous timeout because GPU inference under a +//! backlog can queue server-side (Apollo's threadpool is bounded to 1 +//! worker on CUDA). +//! +//! Configured via `APOLLO_TAG_API_BASE_URL`, falling back to +//! `APOLLO_API_BASE_URL` when the dedicated var is unset (single-Apollo +//! deploys are the common case). Both unset → `is_enabled()` returns false +//! and the probe binary / future backlog drain no-op. +//! +//! Wire format: multipart/form-data with `file=` and `meta=`. +//! `meta` carries `{content_hash, library_id, rel_path, threshold?}` — +//! Apollo logs the path/lib for traceability and reads `threshold` to +//! override the engine default for that call (the probe binary uses this +//! to sweep without restarting Apollo). +//! +//! Error mapping (reflected in [`TagDetectError`]): +//! - 422 `decode_failed` → permanent: ImageApi marks `status='failed'` and +//! doesn't retry until a manual rerun. +//! - 200 with `tags:[]` → `status='no_tags'` marker (success-with-zero). +//! 
- 503 `cuda_oom` / `engine_unavailable` → defer-and-retry: no marker
+//!   written.
+//! - Any other 5xx / network error → defer.
+
+use anyhow::{Context, Result};
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+
+#[derive(Debug, Clone, Serialize)] // serialized to JSON as the multipart `meta` field
+pub struct TagMeta {
+    pub content_hash: String,
+    pub library_id: i32,
+    pub rel_path: String,
+    /// Per-call threshold override. Apollo's engine default (0.68 for
+    /// ram_plus_swin_large_14m) is used when unset. The probe binary
+    /// uses this to sweep without restarting Apollo.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub threshold: Option<f32>,
+}
+
+#[derive(Debug, Clone, Deserialize)] // one `(name, confidence)` prediction from `/auto`
+pub struct TagPrediction {
+    pub name: String,
+    pub confidence: f32,
+}
+
+#[derive(Debug, Clone, Deserialize)] // body of a successful `/auto` call; `tags` may be empty
+pub struct TagResponse {
+    pub model_version: String,
+    pub duration_ms: i64,
+    pub threshold: f32,
+    pub tags: Vec<TagPrediction>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[allow(dead_code)] // Reported by Apollo; load_error consumed by future health probe
+pub struct TagHealth {
+    pub loaded: bool,
+    pub device: String,
+    pub model_version: String,
+    pub image_size: i32,
+    pub threshold: f32,
+    #[serde(default)]
+    pub load_error: Option<String>, // NOTE(review): inner type restored as String — confirm against Apollo
+}
+
+/// Distinguishes permanent failures (don't retry) from transient ones
+/// (defer and retry on next scan tick). Mirrors `FaceDetectError` so the
+/// future backlog drain can use the same marker-row decision tree.
+#[derive(Debug)]
+pub enum TagDetectError {
+    /// Apollo refused the bytes for a reason that won't change on retry
+    /// (decode failure, zero-dim image). Mark `status='failed'`.
+    Permanent(anyhow::Error),
+    /// Apollo couldn't process this turn but might next time (CUDA OOM,
+    /// engine not loaded yet, network hiccup). Don't mark anything.
+    Transient(anyhow::Error),
+    /// Feature is disabled (no APOLLO_TAG_API_BASE_URL / APOLLO_API_BASE_URL).
+    /// Caller should silently no-op.
+    Disabled,
+}
+
+impl std::fmt::Display for TagDetectError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TagDetectError::Permanent(e) => write!(f, "permanent: {e}"),
+            TagDetectError::Transient(e) => write!(f, "transient: {e}"),
+            TagDetectError::Disabled => write!(f, "tag client disabled"),
+        }
+    }
+}
+
+impl std::error::Error for TagDetectError {}
+
+#[derive(Clone)]
+pub struct TagClient {
+    client: Client,
+    /// `None` → disabled. Trailing slash trimmed at construction so url
+    /// building doesn't double up.
+    base_url: Option<String>,
+}
+
+impl TagClient {
+    pub fn new(base_url: Option<String>) -> Self {
+        // 60 s timeout: GPU inference is fast (~50–150 ms on RTX-class
+        // hardware) but Apollo's 1-worker threadpool means a backlog drain
+        // queues server-side. 60 s is enough headroom for a small queue
+        // depth without surfacing a false transient.
+        let timeout_secs = std::env::var("TAG_DETECT_TIMEOUT_SEC")
+            .ok()
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap_or(60);
+        let client = Client::builder()
+            .timeout(Duration::from_secs(timeout_secs))
+            .build()
+            .expect("reqwest client build");
+        Self {
+            client,
+            base_url: base_url.map(|u| u.trim_end_matches('/').to_string()),
+        }
+    }
+
+    /// Construct a client from the standard env vars. APOLLO_TAG_API_BASE_URL
+    /// wins; falls back to APOLLO_API_BASE_URL. Both unset → disabled.
+    pub fn from_env() -> Self {
+        let base = std::env::var("APOLLO_TAG_API_BASE_URL")
+            .ok()
+            .filter(|s| !s.trim().is_empty())
+            .or_else(|| {
+                std::env::var("APOLLO_API_BASE_URL")
+                    .ok()
+                    .filter(|s| !s.trim().is_empty())
+            });
+        Self::new(base)
+    }
+
+    pub fn is_enabled(&self) -> bool {
+        self.base_url.is_some() // a configured base URL is the only enable switch
+    }
+
+    /// Run RAM++ auto-tagging over `bytes`. Empty `tags[]` is the no-tags
+    /// signal — caller writes a marker row in the persistence phase.
+    pub async fn auto_tag(
+        &self,
+        bytes: Vec<u8>,
+        meta: TagMeta,
+    ) -> std::result::Result<TagResponse, TagDetectError> {
+        let Some(base) = self.base_url.as_deref() else {
+            return Err(TagDetectError::Disabled);
+        };
+        let url = format!("{}/api/internal/tags/auto", base);
+        self.post_multipart(&url, bytes, &meta).await
+    }
+
+    /// Engine reachability + device/model report.
+    #[allow(dead_code)] // consumed by future startup probe
+    pub async fn health(&self) -> Result<TagHealth> {
+        let base = self.base_url.as_deref().context("tag client disabled")?;
+        let url = format!("{}/api/internal/tags/health", base);
+        let resp = self.client.get(&url).send().await?.error_for_status()?;
+        let body: TagHealth = resp.json().await?;
+        Ok(body)
+    }
+
+    async fn post_multipart(
+        &self,
+        url: &str,
+        bytes: Vec<u8>,
+        meta: &TagMeta,
+    ) -> std::result::Result<TagResponse, TagDetectError> {
+        let meta_json = serde_json::to_string(meta)
+            .map_err(|e| TagDetectError::Permanent(anyhow::anyhow!("meta serialize: {e}")))?;
+        let form = reqwest::multipart::Form::new()
+            .text("meta", meta_json)
+            .part(
+                "file", // a MIME failure is surfaced below instead of uploading an empty part
+                reqwest::multipart::Part::bytes(bytes)
+                    .file_name(meta.rel_path.clone())
+                    .mime_str("application/octet-stream")
+                    .map_err(|e| TagDetectError::Transient(anyhow::anyhow!("file mime: {e}")))?,
+            );
+
+        let resp = match self.client.post(url).multipart(form).send().await {
+            Ok(r) => r,
+            Err(e) if e.is_timeout() || e.is_connect() => {
+                return Err(TagDetectError::Transient(anyhow::anyhow!(
+                    "tag client network: {e}"
+                )));
+            }
+            Err(e) => {
+                return Err(TagDetectError::Transient(anyhow::anyhow!(
+                    "tag client request: {e}"
+                )));
+            }
+        };
+
+        let status = resp.status(); // 2xx → decode the body; anything else → classify
+        if status.is_success() {
+            let body: TagResponse = resp.json().await.map_err(|e| {
+                TagDetectError::Transient(anyhow::anyhow!("tag response decode: {e}"))
+            })?;
+            return Ok(body);
+        }
+
+        let body_text = resp.text().await.unwrap_or_default();
+        Err(classify_error_response(status.as_u16(), &body_text))
+    }
+}
+
+/// Pulled out as a pure function so the marker-row contract is
unit-testable
+/// without spinning up an HTTP server. Behavior matches face_client::classify
+/// so the future backlog drain can share the same retry policy.
+fn classify_error_response(status: u16, body_text: &str) -> TagDetectError {
+    let detail_code = serde_json::from_str::<serde_json::Value>(body_text)
+        .ok()
+        .and_then(|v| { // `detail` is either a plain string or an object carrying `code`
+            v.get("detail")
+                .and_then(|d| d.as_str().map(str::to_string))
+                .or_else(|| {
+                    v.get("detail")
+                        .and_then(|d| d.get("code"))
+                        .and_then(|c| c.as_str())
+                        .map(str::to_string)
+                })
+        })
+        .unwrap_or_default(); // non-JSON or missing `detail` → empty code
+
+    if status == 422 {
+        return TagDetectError::Permanent(anyhow::anyhow!(
+            "tag detect 422 {}: {}",
+            detail_code,
+            body_text
+        ));
+    }
+    if status == 503 {
+        return TagDetectError::Transient(anyhow::anyhow!(
+            "tag detect 503 {}: {}",
+            detail_code,
+            body_text
+        ));
+    }
+    // 408 / 413 / 429 are operator-fixable infra issues — defer so the
+    // next pass retries naturally once the proxy is fixed (see
+    // face_client::classify_error_response for the cautionary tale).
+    if matches!(status, 408 | 413 | 429) {
+        return TagDetectError::Transient(anyhow::anyhow!(
+            "tag detect {} {}: {}",
+            status,
+            detail_code,
+            body_text
+        ));
+    }
+    if (400..500).contains(&status) {
+        TagDetectError::Permanent(anyhow::anyhow!(
+            "tag detect {} {}: {}",
+            status,
+            detail_code,
+            body_text
+        ))
+    } else {
+        TagDetectError::Transient(anyhow::anyhow!(
+            "tag detect {} {}: {}",
+            status,
+            detail_code,
+            body_text
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn is_permanent(e: &TagDetectError) -> bool {
+        matches!(e, TagDetectError::Permanent(_))
+    }
+    fn is_transient(e: &TagDetectError) -> bool {
+        matches!(e, TagDetectError::Transient(_))
+    }
+
+    #[test]
+    fn classify_422_decode_failed_is_permanent() {
+        let e = classify_error_response(422, r#"{"detail":"decode_failed: bad bytes"}"#);
+        assert!(is_permanent(&e));
+        assert!(format!("{e}").contains("decode_failed"));
+    }
+
+    #[test]
+    fn classify_503_cuda_oom_is_transient() {
+        let e = classify_error_response(
+            503,
+            r#"{"detail":{"code":"cuda_oom","error":"out of memory"}}"#,
+        );
+        assert!(is_transient(&e));
+        assert!(format!("{e}").contains("cuda_oom"));
+    }
+
+    #[test]
+    fn classify_5xx_is_transient_other_4xx_is_permanent() {
+        assert!(is_transient(&classify_error_response(500, "")));
+        assert!(is_permanent(&classify_error_response(400, "{}")));
+        assert!(is_permanent(&classify_error_response(404, "{}")));
+    }
+
+    #[test]
+    fn classify_infra_4xx_is_transient() {
+        assert!(is_transient(&classify_error_response(408, "")));
+        assert!(is_transient(&classify_error_response(413, "")));
+        assert!(is_transient(&classify_error_response(429, "{}")));
+    }
+}
diff --git a/src/bin/probe_auto_tags.rs b/src/bin/probe_auto_tags.rs
new file mode 100644
index 0000000..6531104
--- /dev/null
+++ b/src/bin/probe_auto_tags.rs
@@ -0,0 +1,250 @@
+//! Probe binary for RAM++ auto-tagging.
+//!
+//! No DB writes. Walks a library's `image_exif` rows, sends a sample
+//! 
through Apollo's `/api/internal/tags/auto`, and prints `(path, tags)`
+//! to stdout so the operator can eyeball whether the model's vocabulary
+//! and threshold defaults are appropriate for this library before
+//! committing to the persistence phase (new table, per-tick drain, UI).
+//!
+//! Usage:
+//!   cargo run --release --bin probe_auto_tags -- \
+//!     --library 1 --limit 50 --threshold 0.7
+//!
+//! Env: standard ImageApi `.env`. Requires either
+//! `APOLLO_TAG_API_BASE_URL` or `APOLLO_API_BASE_URL` to be set
+//! (otherwise the client is disabled and the probe bails).
+
+use std::path::{Path, PathBuf};
+use std::sync::{Arc, Mutex};
+use std::time::Instant;
+
+use clap::Parser;
+use log::{info, warn};
+
+use image_api::ai::tag_client::{TagClient, TagDetectError, TagMeta};
+use image_api::database::{ExifDao, SqliteExifDao, connect};
+use image_api::exif;
+use image_api::file_types;
+use image_api::libraries::{self, Library};
+
+#[derive(Parser, Debug)]
+#[command(name = "probe_auto_tags")]
+#[command(about = "Print RAM++ auto-tags for a sample of image_exif rows")]
+struct Args {
+    /// Library id to sample from.
+    #[arg(long)]
+    library: i32,
+
+    /// Max files to probe. The binary scans more rows internally because
+    /// non-image rows (videos, junk) are skipped client-side.
+    #[arg(long, default_value_t = 25)]
+    limit: usize,
+
+    /// Per-call threshold sent to Apollo. Overrides the engine default.
+    /// Lower = more tags per photo, more noise. 0.5–0.75 is the useful
+    /// sweep range for ram_plus_swin_large_14m.
+    #[arg(long, default_value_t = 0.65)]
+    threshold: f32,
+
+    /// Offset into the library's rel_path listing (sorted by id ASC).
+    /// Bump on re-runs to sample a different slice.
+    #[arg(long, default_value_t = 0)]
+    offset: i64,
+
+    /// How many DB rows to scan before giving up on hitting the limit.
+    /// Useful when a library is mostly videos.
+    #[arg(long, default_value_t = 2000)]
+    max_scan: i64,
+}
+
+/// Mirror of `face_watch::read_image_bytes_for_detect` — it's pub(crate)
+/// so we can't import it across the bin boundary. The probe is throwaway
+/// scope; inlining is cleaner than changing the visibility.
+fn read_image_bytes(path: &Path) -> std::io::Result<Vec<u8>> {
+    if file_types::needs_ffmpeg_thumbnail(path)
+        && let Some(preview) = exif::extract_embedded_jpeg_preview(path)
+    {
+        return Ok(preview);
+    }
+    std::fs::read(path)
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    dotenv::dotenv().ok(); // load .env first so a RUST_LOG set there reaches env_logger
+    env_logger::init();
+
+    let args = Args::parse();
+
+    let client = TagClient::from_env();
+    if !client.is_enabled() {
+        anyhow::bail!(
+            "TagClient disabled: set APOLLO_TAG_API_BASE_URL or APOLLO_API_BASE_URL in .env"
+        );
+    }
+
+    // Quick health probe so a misconfig shows up before grinding through a
+    // thousand rows. A failure is only logged; the run continues.
+    match client.health().await {
+        Ok(h) => info!(
+            "tag engine: loaded={} device={} model={} threshold_default={}",
+            h.loaded, h.device, h.model_version, h.threshold
+        ),
+        Err(e) => warn!("health probe failed (continuing): {e}"),
+    }
+
+    let mut seed_conn = connect();
+    if let Some(base) = dotenv::var("BASE_PATH").ok().as_deref() {
+        libraries::seed_or_patch_from_env(&mut seed_conn, base);
+    }
+    let libs = libraries::load_all(&mut seed_conn);
+    drop(seed_conn);
+    let lib: Library = libs
+        .into_iter()
+        .find(|l| l.id == args.library)
+        .ok_or_else(|| anyhow::anyhow!("library id {} not found", args.library))?;
+    info!("probing library #{} ({}) at {}", lib.id, lib.name, lib.root_path);
+
+    let dao: Arc<Mutex<Box<dyn ExifDao>>> = Arc::new(Mutex::new(Box::new(SqliteExifDao::new())));
+    let ctx = opentelemetry::Context::new();
+
+    // Paginate through (id, rel_path) for this library, filter to images
+    // on disk, take `limit`. Page size is tuned so we don't slam the DB
+    // when a library is video-heavy.
+    const PAGE: i64 = 500;
+    let mut offset = args.offset;
+    let mut scanned: i64 = 0;
+    let mut probed = 0usize;
+    let mut ok_count = 0usize;
+    let mut empty_count = 0usize;
+    let mut perm_fail = 0usize;
+    let mut transient_fail = 0usize;
+    let started = Instant::now();
+    let root = PathBuf::from(&lib.root_path);
+
+    'outer: while probed < args.limit { // stop before another page once the limit is met
+        if scanned >= args.max_scan {
+            warn!(
+                "scan cap ({}) reached before hitting limit ({}); bump --max-scan to scan deeper",
+                args.max_scan, args.limit
+            );
+            break;
+        }
+        let rows = {
+            let mut guard = dao.lock().expect("dao lock");
+            guard
+                .list_rel_paths_for_library_page(&ctx, lib.id, PAGE, offset)
+                .map_err(|e| anyhow::anyhow!("list rel_paths: {:?}", e))?
+        };
+        if rows.is_empty() {
+            info!("no more rows after offset {}", offset);
+            break;
+        }
+        offset += rows.len() as i64;
+        scanned += rows.len() as i64;
+
+        for (_id, rel_path) in rows {
+            if probed >= args.limit {
+                break 'outer;
+            }
+            let abs = root.join(&rel_path);
+            // Skip non-images and videos at the path level — same logic
+            // the face backlog drain uses, just inlined.
+            if !file_types::is_image_file(&abs) {
+                continue;
+            }
+            if !abs.exists() {
+                continue;
+            }
+            let bytes = match read_image_bytes(&abs) {
+                Ok(b) => b,
+                Err(e) => {
+                    warn!("read {rel_path}: {e}");
+                    continue;
+                }
+            };
+            // The probe doesn't need a real content_hash — Apollo only
+            // logs it. Pass an empty marker so we don't trip on no-hash
+            // image_exif rows.
+            let meta = TagMeta {
+                content_hash: String::new(),
+                library_id: lib.id,
+                rel_path: rel_path.clone(),
+                threshold: Some(args.threshold),
+            };
+
+            let call_start = Instant::now();
+            match client.auto_tag(bytes, meta).await {
+                Ok(resp) => {
+                    probed += 1;
+                    if resp.tags.is_empty() {
+                        empty_count += 1;
+                        println!(
+                            "[{:>3}] (no tags) {}ms {}",
+                            probed, resp.duration_ms, rel_path
+                        );
+                    } else {
+                        ok_count += 1;
+                        let preview = resp
+                            .tags
+                            .iter()
+                            .map(|t| format!("{}({:.2})", t.name, t.confidence))
+                            .collect::<Vec<_>>()
+                            .join(", ");
+                        println!(
+                            "[{:>3}] {} tags {}ms {}\n {}",
+                            probed,
+                            resp.tags.len(),
+                            resp.duration_ms,
+                            rel_path,
+                            preview
+                        );
+                    }
+                }
+                Err(TagDetectError::Permanent(e)) => {
+                    probed += 1;
+                    perm_fail += 1;
+                    println!(
+                        "[{:>3}] PERMANENT FAIL ({:>4}ms) {}\n {}",
+                        probed,
+                        call_start.elapsed().as_millis(),
+                        rel_path,
+                        e
+                    );
+                }
+                Err(TagDetectError::Transient(e)) => {
+                    probed += 1;
+                    transient_fail += 1;
+                    println!(
+                        "[{:>3}] TRANSIENT FAIL ({:>4}ms) {}\n {}",
+                        probed,
+                        call_start.elapsed().as_millis(),
+                        rel_path,
+                        e
+                    );
+                }
+                Err(TagDetectError::Disabled) => {
+                    anyhow::bail!("tag client became disabled mid-run; impossible");
+                }
+            }
+        }
+    }
+
+    let elapsed = started.elapsed();
+    println!();
+    println!("── summary ───────────────────────────────────────");
+    println!("scanned rows : {scanned}");
+    println!("probed files : {probed}");
+    println!(" with tags : {ok_count}");
+    println!(" empty (no tags) : {empty_count}");
+    println!(" permanent failures : {perm_fail}");
+    println!(" transient failures : {transient_fail}");
+    println!("elapsed : {:.1}s", elapsed.as_secs_f32());
+    if probed > 0 {
+        println!(
+            "throughput : {:.2} photos/s",
+            probed as f32 / elapsed.as_secs_f32().max(0.001)
+        );
+    }
+    Ok(())
+}