Compare commits

1 Commits

Author SHA1 Message Date
Cameron Cordes
dbbd4470a5 auto-tag: Apollo tag client + probe binary
Adds ai::tag_client mirroring face_client for Apollo's RAM++ endpoint
(APOLLO_TAG_API_BASE_URL falling back to APOLLO_API_BASE_URL), and a
throwaway probe_auto_tags binary that walks image_exif and prints tags
without writing the DB. Lets us eyeball RAM++ output quality + threshold
before committing to a schema and per-tick drain.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 20:01:55 -04:00
3 changed files with 570 additions and 0 deletions

View File

@@ -8,6 +8,7 @@ pub mod llm_client;
pub mod ollama;
pub mod openrouter;
pub mod sms_client;
pub mod tag_client;
// strip_summary_boilerplate is used by binaries (test_daily_summary), not the library
#[allow(unused_imports)]

319
src/ai/tag_client.rs Normal file
View File

@@ -0,0 +1,319 @@
//! Thin async HTTP client for Apollo's `/api/internal/tags/*` endpoints.
//!
//! Apollo hosts the RAM++ auto-tag inference service alongside insightface.
//! This client is the ImageApi side — shove image bytes through `/auto` and
//! get back a list of `(name, confidence)` predictions over RAM++'s
//! ~4585-tag vocabulary.
//!
//! Mirrors `face_client.rs` shape: optional base URL (None = disabled), one
//! reqwest client with a generous timeout because GPU inference under a
//! backlog can queue server-side (Apollo's threadpool is bounded to 1
//! worker on CUDA).
//!
//! Configured via `APOLLO_TAG_API_BASE_URL`, falling back to
//! `APOLLO_API_BASE_URL` when the dedicated var is unset (single-Apollo
//! deploys are the common case). Both unset → `is_enabled()` returns false
//! and the probe binary / future backlog drain no-op.
//!
//! Wire format: multipart/form-data with `file=<bytes>` and `meta=<json>`.
//! `meta` carries `{content_hash, library_id, rel_path, threshold?}` —
//! Apollo logs the path/lib for traceability and reads `threshold` to
//! override the engine default for that call (the probe binary uses this
//! to sweep without restarting Apollo).
//!
//! Error mapping (reflected in [`TagDetectError`]):
//! - 422 `decode_failed` → permanent: ImageApi marks `status='failed'` and
//! doesn't retry until a manual rerun.
//! - 200 with `tags:[]` → `status='no_tags'` marker (success-with-zero).
//! - 503 `cuda_oom` / `engine_unavailable` → defer-and-retry: no marker
//! written.
//! - Any other 5xx / network error → defer.
use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Clone, Serialize)]
pub struct TagMeta {
pub content_hash: String,
pub library_id: i32,
pub rel_path: String,
/// Per-call threshold override. Apollo's engine default (0.68 for
/// ram_plus_swin_large_14m) is used when unset. The probe binary
/// uses this to sweep without restarting Apollo.
#[serde(skip_serializing_if = "Option::is_none")]
pub threshold: Option<f32>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TagPrediction {
pub name: String,
pub confidence: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TagResponse {
pub model_version: String,
pub duration_ms: i64,
pub threshold: f32,
pub tags: Vec<TagPrediction>,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)] // Reported by Apollo; load_error consumed by future health probe
pub struct TagHealth {
pub loaded: bool,
pub device: String,
pub model_version: String,
pub image_size: i32,
pub threshold: f32,
#[serde(default)]
pub load_error: Option<String>,
}
/// Distinguishes permanent failures (don't retry) from transient ones
/// (defer and retry on next scan tick). Mirrors `FaceDetectError` so the
/// future backlog drain can use the same marker-row decision tree.
#[derive(Debug)]
pub enum TagDetectError {
/// Apollo refused the bytes for a reason that won't change on retry
/// (decode failure, zero-dim image). Mark `status='failed'`.
Permanent(anyhow::Error),
/// Apollo couldn't process this turn but might next time (CUDA OOM,
/// engine not loaded yet, network hiccup). Don't mark anything.
Transient(anyhow::Error),
/// Feature is disabled (no APOLLO_TAG_API_BASE_URL / APOLLO_API_BASE_URL).
/// Caller should silently no-op.
Disabled,
}
impl std::fmt::Display for TagDetectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TagDetectError::Permanent(e) => write!(f, "permanent: {e}"),
TagDetectError::Transient(e) => write!(f, "transient: {e}"),
TagDetectError::Disabled => write!(f, "tag client disabled"),
}
}
}
impl std::error::Error for TagDetectError {}
#[derive(Clone)]
pub struct TagClient {
client: Client,
/// `None` → disabled. Trailing slash trimmed at construction so url
/// building doesn't double up.
base_url: Option<String>,
}
impl TagClient {
pub fn new(base_url: Option<String>) -> Self {
// 60 s timeout: GPU inference is fast (~50150 ms on RTX-class
// hardware) but Apollo's 1-worker threadpool means a backlog drain
// queues server-side. 60 s is enough headroom for a small queue
// depth without surfacing a false transient.
let timeout_secs = std::env::var("TAG_DETECT_TIMEOUT_SEC")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(60);
let client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.expect("reqwest client build");
Self {
client,
base_url: base_url.map(|u| u.trim_end_matches('/').to_string()),
}
}
/// Construct a client from the standard env vars. APOLLO_TAG_API_BASE_URL
/// wins; falls back to APOLLO_API_BASE_URL. Both unset → disabled.
pub fn from_env() -> Self {
let base = std::env::var("APOLLO_TAG_API_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
.or_else(|| {
std::env::var("APOLLO_API_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
});
Self::new(base)
}
pub fn is_enabled(&self) -> bool {
self.base_url.is_some()
}
/// Run RAM++ auto-tagging over `bytes`. Empty `tags[]` is the no-tags
/// signal — caller writes a marker row in the persistence phase.
pub async fn auto_tag(
&self,
bytes: Vec<u8>,
meta: TagMeta,
) -> std::result::Result<TagResponse, TagDetectError> {
let Some(base) = self.base_url.as_deref() else {
return Err(TagDetectError::Disabled);
};
let url = format!("{}/api/internal/tags/auto", base);
self.post_multipart(&url, bytes, &meta).await
}
/// Engine reachability + device/model report.
#[allow(dead_code)] // consumed by future startup probe
pub async fn health(&self) -> Result<TagHealth> {
let base = self.base_url.as_deref().context("tag client disabled")?;
let url = format!("{}/api/internal/tags/health", base);
let resp = self.client.get(&url).send().await?.error_for_status()?;
let body: TagHealth = resp.json().await?;
Ok(body)
}
async fn post_multipart(
&self,
url: &str,
bytes: Vec<u8>,
meta: &TagMeta,
) -> std::result::Result<TagResponse, TagDetectError> {
let meta_json = serde_json::to_string(meta)
.map_err(|e| TagDetectError::Permanent(anyhow::anyhow!("meta serialize: {e}")))?;
let form = reqwest::multipart::Form::new()
.text("meta", meta_json)
.part(
"file",
reqwest::multipart::Part::bytes(bytes)
.file_name(meta.rel_path.clone())
.mime_str("application/octet-stream")
.unwrap_or_else(|_| reqwest::multipart::Part::bytes(Vec::new())),
);
let resp = match self.client.post(url).multipart(form).send().await {
Ok(r) => r,
Err(e) if e.is_timeout() || e.is_connect() => {
return Err(TagDetectError::Transient(anyhow::anyhow!(
"tag client network: {e}"
)));
}
Err(e) => {
return Err(TagDetectError::Transient(anyhow::anyhow!(
"tag client request: {e}"
)));
}
};
let status = resp.status();
if status.is_success() {
let body: TagResponse = resp.json().await.map_err(|e| {
TagDetectError::Transient(anyhow::anyhow!("tag response decode: {e}"))
})?;
return Ok(body);
}
let body_text = resp.text().await.unwrap_or_default();
Err(classify_error_response(status.as_u16(), &body_text))
}
}
/// Pulled out as a pure function so the marker-row contract is unit-testable
/// without spinning up an HTTP server. Behavior matches face_client::classify
/// so the future backlog drain can share the same retry policy.
fn classify_error_response(status: u16, body_text: &str) -> TagDetectError {
let detail_code = serde_json::from_str::<serde_json::Value>(body_text)
.ok()
.and_then(|v| {
v.get("detail")
.and_then(|d| d.as_str().map(str::to_string))
.or_else(|| {
v.get("detail")
.and_then(|d| d.get("code"))
.and_then(|c| c.as_str())
.map(str::to_string)
})
})
.unwrap_or_default();
if status == 422 {
return TagDetectError::Permanent(anyhow::anyhow!(
"tag detect 422 {}: {}",
detail_code,
body_text
));
}
if status == 503 {
return TagDetectError::Transient(anyhow::anyhow!(
"tag detect 503 {}: {}",
detail_code,
body_text
));
}
// 408 / 413 / 429 are operator-fixable infra issues — defer so the
// next pass retries naturally once the proxy is fixed (see
// face_client::classify_error_response for the cautionary tale).
if matches!(status, 408 | 413 | 429) {
return TagDetectError::Transient(anyhow::anyhow!(
"tag detect {} {}: {}",
status,
detail_code,
body_text
));
}
if (400..500).contains(&status) {
TagDetectError::Permanent(anyhow::anyhow!(
"tag detect {} {}: {}",
status,
detail_code,
body_text
))
} else {
TagDetectError::Transient(anyhow::anyhow!(
"tag detect {} {}: {}",
status,
detail_code,
body_text
))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn is_permanent(e: &TagDetectError) -> bool {
matches!(e, TagDetectError::Permanent(_))
}
fn is_transient(e: &TagDetectError) -> bool {
matches!(e, TagDetectError::Transient(_))
}
#[test]
fn classify_422_decode_failed_is_permanent() {
let e = classify_error_response(422, r#"{"detail":"decode_failed: bad bytes"}"#);
assert!(is_permanent(&e));
assert!(format!("{e}").contains("decode_failed"));
}
#[test]
fn classify_503_cuda_oom_is_transient() {
let e = classify_error_response(
503,
r#"{"detail":{"code":"cuda_oom","error":"out of memory"}}"#,
);
assert!(is_transient(&e));
assert!(format!("{e}").contains("cuda_oom"));
}
#[test]
fn classify_5xx_is_transient_other_4xx_is_permanent() {
assert!(is_transient(&classify_error_response(500, "")));
assert!(is_permanent(&classify_error_response(400, "{}")));
assert!(is_permanent(&classify_error_response(404, "{}")));
}
#[test]
fn classify_infra_4xx_is_transient() {
assert!(is_transient(&classify_error_response(408, "")));
assert!(is_transient(&classify_error_response(413, "<html>")));
assert!(is_transient(&classify_error_response(429, "{}")));
}
}

250
src/bin/probe_auto_tags.rs Normal file
View File

@@ -0,0 +1,250 @@
//! Probe binary for RAM++ auto-tagging.
//!
//! No DB writes. Walks a library's `image_exif` rows, sends a sample
//! through Apollo's `/api/internal/tags/auto`, and prints `(path, tags)`
//! to stdout so the operator can eyeball whether the model's vocabulary
//! and threshold defaults are appropriate for this library before
//! committing to the persistence phase (new table, per-tick drain, UI).
//!
//! Usage:
//! cargo run --release --bin probe_auto_tags -- \
//! --library 1 --limit 50 --threshold 0.7
//!
//! Env: standard ImageApi `.env`. Requires either
//! `APOLLO_TAG_API_BASE_URL` or `APOLLO_API_BASE_URL` to be set
//! (otherwise the client is disabled and the probe bails).
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use clap::Parser;
use log::{info, warn};
use image_api::ai::tag_client::{TagClient, TagDetectError, TagMeta};
use image_api::database::{ExifDao, SqliteExifDao, connect};
use image_api::exif;
use image_api::file_types;
use image_api::libraries::{self, Library};
#[derive(Parser, Debug)]
#[command(name = "probe_auto_tags")]
#[command(about = "Print RAM++ auto-tags for a sample of image_exif rows")]
struct Args {
/// Library id to sample from.
#[arg(long)]
library: i32,
/// Max files to probe. The binary scans more rows internally because
/// non-image rows (videos, junk) are skipped client-side.
#[arg(long, default_value_t = 25)]
limit: usize,
/// Per-call threshold sent to Apollo. Overrides the engine default.
/// Lower = more tags per photo, more noise. 0.50.75 is the useful
/// sweep range for ram_plus_swin_large_14m.
#[arg(long, default_value_t = 0.65)]
threshold: f32,
/// Offset into the library's rel_path listing (sorted by id ASC).
/// Bump on re-runs to sample a different slice.
#[arg(long, default_value_t = 0)]
offset: i64,
/// How many DB rows to scan before giving up on hitting the limit.
/// Useful when a library is mostly videos.
#[arg(long, default_value_t = 2000)]
max_scan: i64,
}
/// Mirror of `face_watch::read_image_bytes_for_detect` — it's pub(crate)
/// so we can't import it across the bin boundary. The probe is throwaway
/// scope; inlining is cleaner than changing the visibility.
fn read_image_bytes(path: &Path) -> std::io::Result<Vec<u8>> {
if file_types::needs_ffmpeg_thumbnail(path)
&& let Some(preview) = exif::extract_embedded_jpeg_preview(path)
{
return Ok(preview);
}
std::fs::read(path)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
dotenv::dotenv().ok();
let args = Args::parse();
let client = TagClient::from_env();
if !client.is_enabled() {
anyhow::bail!(
"TagClient disabled: set APOLLO_TAG_API_BASE_URL or APOLLO_API_BASE_URL in .env"
);
}
// Quick health probe so we fail fast on a misconfig before grinding
// through a thousand rows.
match client.health().await {
Ok(h) => info!(
"tag engine: loaded={} device={} model={} threshold_default={}",
h.loaded, h.device, h.model_version, h.threshold
),
Err(e) => warn!("health probe failed (continuing): {e}"),
}
let mut seed_conn = connect();
if let Some(base) = dotenv::var("BASE_PATH").ok().as_deref() {
libraries::seed_or_patch_from_env(&mut seed_conn, base);
}
let libs = libraries::load_all(&mut seed_conn);
drop(seed_conn);
let lib: Library = libs
.into_iter()
.find(|l| l.id == args.library)
.ok_or_else(|| anyhow::anyhow!("library id {} not found", args.library))?;
info!("probing library #{} ({}) at {}", lib.id, lib.name, lib.root_path);
let dao: Arc<Mutex<Box<dyn ExifDao>>> = Arc::new(Mutex::new(Box::new(SqliteExifDao::new())));
let ctx = opentelemetry::Context::new();
// Paginate through (id, rel_path) for this library, filter to images
// on disk, take `limit`. Page size is tuned so we don't slam the DB
// when a library is video-heavy.
const PAGE: i64 = 500;
let mut offset = args.offset;
let mut scanned: i64 = 0;
let mut probed = 0usize;
let mut ok_count = 0usize;
let mut empty_count = 0usize;
let mut perm_fail = 0usize;
let mut transient_fail = 0usize;
let started = Instant::now();
let root = PathBuf::from(&lib.root_path);
'outer: loop {
if scanned >= args.max_scan {
warn!(
"scan cap ({}) reached before hitting limit ({}); bump --max-scan to scan deeper",
args.max_scan, args.limit
);
break;
}
let rows = {
let mut guard = dao.lock().expect("dao lock");
guard
.list_rel_paths_for_library_page(&ctx, lib.id, PAGE, offset)
.map_err(|e| anyhow::anyhow!("list rel_paths: {:?}", e))?
};
if rows.is_empty() {
info!("no more rows after offset {}", offset);
break;
}
offset += rows.len() as i64;
scanned += rows.len() as i64;
for (_id, rel_path) in rows {
if probed >= args.limit {
break 'outer;
}
let abs = root.join(&rel_path);
// Skip non-images and videos at the path level — same logic
// the face backlog drain uses, just inlined.
if !file_types::is_image_file(&abs) {
continue;
}
if !abs.exists() {
continue;
}
let bytes = match read_image_bytes(&abs) {
Ok(b) => b,
Err(e) => {
warn!("read {rel_path}: {e}");
continue;
}
};
// The probe doesn't need a real content_hash — Apollo only
// logs it. Pass an empty marker so we don't trip on no-hash
// image_exif rows.
let meta = TagMeta {
content_hash: String::new(),
library_id: lib.id,
rel_path: rel_path.clone(),
threshold: Some(args.threshold),
};
let call_start = Instant::now();
match client.auto_tag(bytes, meta).await {
Ok(resp) => {
probed += 1;
if resp.tags.is_empty() {
empty_count += 1;
println!(
"[{:>3}] (no tags) {}ms {}",
probed, resp.duration_ms, rel_path
);
} else {
ok_count += 1;
let preview = resp
.tags
.iter()
.map(|t| format!("{}({:.2})", t.name, t.confidence))
.collect::<Vec<_>>()
.join(", ");
println!(
"[{:>3}] {} tags {}ms {}\n {}",
probed,
resp.tags.len(),
resp.duration_ms,
rel_path,
preview
);
}
}
Err(TagDetectError::Permanent(e)) => {
probed += 1;
perm_fail += 1;
println!(
"[{:>3}] PERMANENT FAIL ({:>4}ms) {}\n {}",
probed,
call_start.elapsed().as_millis(),
rel_path,
e
);
}
Err(TagDetectError::Transient(e)) => {
probed += 1;
transient_fail += 1;
println!(
"[{:>3}] TRANSIENT FAIL ({:>4}ms) {}\n {}",
probed,
call_start.elapsed().as_millis(),
rel_path,
e
);
}
Err(TagDetectError::Disabled) => {
anyhow::bail!("tag client became disabled mid-run; impossible");
}
}
}
}
let elapsed = started.elapsed();
println!();
println!("── summary ───────────────────────────────────────");
println!("scanned rows : {scanned}");
println!("probed files : {probed}");
println!(" with tags : {ok_count}");
println!(" empty (no tags) : {empty_count}");
println!(" permanent failures : {perm_fail}");
println!(" transient failures : {transient_fail}");
println!("elapsed : {:.1}s", elapsed.as_secs_f32());
if probed > 0 {
println!(
"throughput : {:.2} photos/s",
probed as f32 / elapsed.as_secs_f32().max(0.001)
);
}
Ok(())
}