clip-search: migration + client + probe binary
Probe-phase scaffolding for CLIP semantic search. Adds the column that will hold per-photo embeddings, the HTTP client to Apollo's inference service, and a throwaway probe binary so we can eyeball search-result quality on the live library before building the persistence layer (backlog drain, /photos/search endpoint, UI). - migrations/2026-05-14-000000_add_clip_embedding/ — adds image_exif.clip_embedding (BLOB) and clip_model_version (TEXT), plus a partial index on (clip_embedding IS NULL AND content_hash IS NOT NULL) for the future backfill drain. - src/database/models.rs — extends ImageExif struct to match. - src/ai/clip_client.rs — encode_image / encode_text / health, same Permanent/Transient/Disabled taxonomy as face_client. - src/bin/probe_clip_search.rs — --query <q> --library N --limit M --top K. Encodes a sample and prints top-K cosine similarities. No DB writes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
393
src/ai/clip_client.rs
Normal file
393
src/ai/clip_client.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
//! Thin async HTTP client for Apollo's `/api/internal/clip/*` endpoints.
|
||||
//!
|
||||
//! Apollo hosts the OpenAI CLIP inference service (ViT-L/14 by default,
|
||||
//! configurable via `APOLLO_CLIP_MODEL`). This client is the ImageApi side
|
||||
//! of the contract: shove image bytes through `/encode_image` to populate
|
||||
//! `image_exif.clip_embedding` during backfill, and call `/encode_text` to
|
||||
//! encode a user's natural-language query at search time. The actual
|
||||
//! cosine-similarity rerank runs locally in ImageApi.
|
||||
//!
|
||||
//! Mirrors `face_client.rs` / `tag_client.rs` shape: optional base URL
|
||||
//! (None = disabled — feature off, drain and search no-op), reqwest
|
||||
//! client with a generous timeout because GPU inference under a backlog
|
||||
//! can queue server-side (Apollo's threadpool is bounded to 1 worker on
|
||||
//! CUDA).
|
||||
//!
|
||||
//! Configured via `APOLLO_CLIP_API_BASE_URL`, falling back to
|
||||
//! `APOLLO_API_BASE_URL` when the dedicated var is unset (single-Apollo
|
||||
//! deploys are the common case).
|
||||
//!
|
||||
//! Wire format:
|
||||
//! - `/encode_image`: multipart/form-data with `file=<bytes>` and
|
||||
//! `meta=<json>` (content_hash / library_id / rel_path for logging).
|
||||
//! - `/encode_text`: JSON `{"text": "<query>"}`.
|
||||
//!
|
||||
//! Both return `{model_version, embedding_dim, duration_ms, embedding}`
|
||||
//! where `embedding` is base64 of `dim×4` little-endian float32 bytes,
|
||||
//! L2-normalized so the rerank reduces to a plain dot product.
|
||||
//!
|
||||
//! Error mapping (reflected in [`ClipError`]):
|
||||
//! - 422 `decode_failed` / `empty_text` → permanent: ImageApi marks the
|
||||
//! row failed or surfaces the empty-query error to the search caller.
|
||||
//! - 503 `cuda_oom` / `engine_unavailable` → defer-and-retry: no marker.
|
||||
//! - Any other 5xx / network error → defer.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::Engine;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EncodeImageMeta {
|
||||
pub content_hash: String,
|
||||
pub library_id: i32,
|
||||
pub rel_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)] // duration_ms logged by the backfill drain
|
||||
pub struct EncodeResponse {
|
||||
pub model_version: String,
|
||||
pub embedding_dim: i32,
|
||||
pub duration_ms: i64,
|
||||
/// base64 of `embedding_dim * 4` bytes (LE float32). ImageApi stores
|
||||
/// the decoded bytes verbatim as a BLOB.
|
||||
pub embedding: String,
|
||||
}
|
||||
|
||||
impl EncodeResponse {
|
||||
/// Decode the wire-format embedding back into raw bytes for storage.
|
||||
/// Validates the buffer is `embedding_dim * 4` bytes long so a
|
||||
/// malformed response surfaces here rather than as a downstream
|
||||
/// silent length mismatch.
|
||||
pub fn decode_embedding(&self) -> Result<Vec<u8>> {
|
||||
let bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(self.embedding.as_bytes())
|
||||
.context("clip embedding base64 decode")?;
|
||||
let expected = (self.embedding_dim as usize) * 4;
|
||||
if bytes.len() != expected {
|
||||
anyhow::bail!(
|
||||
"clip embedding wrong size: got {} bytes, expected {} ({} * 4)",
|
||||
bytes.len(),
|
||||
expected,
|
||||
self.embedding_dim
|
||||
);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)] // load_error consumed by future health probe
|
||||
pub struct ClipHealth {
|
||||
pub loaded: bool,
|
||||
pub device: String,
|
||||
pub model_version: String,
|
||||
pub embedding_dim: i32,
|
||||
#[serde(default)]
|
||||
pub load_error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClipError {
|
||||
/// Apollo refused for a reason that won't change on retry (decode
|
||||
/// failure on /encode_image, empty text on /encode_text).
|
||||
Permanent(anyhow::Error),
|
||||
/// Apollo couldn't process this turn but might next time (CUDA OOM,
|
||||
/// engine not loaded, network hiccup).
|
||||
Transient(anyhow::Error),
|
||||
/// Feature is disabled (no `APOLLO_CLIP_API_BASE_URL` /
|
||||
/// `APOLLO_API_BASE_URL`).
|
||||
Disabled,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ClipError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ClipError::Permanent(e) => write!(f, "permanent: {e}"),
|
||||
ClipError::Transient(e) => write!(f, "transient: {e}"),
|
||||
ClipError::Disabled => write!(f, "clip client disabled"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ClipError {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClipClient {
|
||||
client: Client,
|
||||
base_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ClipClient {
|
||||
pub fn new(base_url: Option<String>) -> Self {
|
||||
let timeout_secs = std::env::var("CLIP_REQUEST_TIMEOUT_SEC")
|
||||
.ok()
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap_or(60);
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.expect("reqwest client build");
|
||||
Self {
|
||||
client,
|
||||
base_url: base_url.map(|u| u.trim_end_matches('/').to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read both standard env vars. `APOLLO_CLIP_API_BASE_URL` wins;
|
||||
/// fallback to `APOLLO_API_BASE_URL`. Both unset → disabled.
|
||||
pub fn from_env() -> Self {
|
||||
let base = std::env::var("APOLLO_CLIP_API_BASE_URL")
|
||||
.ok()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.or_else(|| {
|
||||
std::env::var("APOLLO_API_BASE_URL")
|
||||
.ok()
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
});
|
||||
Self::new(base)
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.base_url.is_some()
|
||||
}
|
||||
|
||||
/// Encode an image to a 768-d (ViT-L/14) or 512-d (ViT-B/32)
|
||||
/// L2-normalized embedding. Used by the backfill drain.
|
||||
pub async fn encode_image(
|
||||
&self,
|
||||
bytes: Vec<u8>,
|
||||
meta: EncodeImageMeta,
|
||||
) -> std::result::Result<EncodeResponse, ClipError> {
|
||||
let Some(base) = self.base_url.as_deref() else {
|
||||
return Err(ClipError::Disabled);
|
||||
};
|
||||
let url = format!("{}/api/internal/clip/encode_image", base);
|
||||
let meta_json = serde_json::to_string(&meta)
|
||||
.map_err(|e| ClipError::Permanent(anyhow::anyhow!("meta serialize: {e}")))?;
|
||||
let form = reqwest::multipart::Form::new()
|
||||
.text("meta", meta_json)
|
||||
.part(
|
||||
"file",
|
||||
reqwest::multipart::Part::bytes(bytes)
|
||||
.file_name(meta.rel_path.clone())
|
||||
.mime_str("application/octet-stream")
|
||||
.unwrap_or_else(|_| reqwest::multipart::Part::bytes(Vec::new())),
|
||||
);
|
||||
self.send_multipart(&url, form).await
|
||||
}
|
||||
|
||||
/// Encode a natural-language query to an embedding. Used by the
|
||||
/// search route to rank stored image embeddings by cosine sim.
|
||||
pub async fn encode_text(
|
||||
&self,
|
||||
text: &str,
|
||||
) -> std::result::Result<EncodeResponse, ClipError> {
|
||||
let Some(base) = self.base_url.as_deref() else {
|
||||
return Err(ClipError::Disabled);
|
||||
};
|
||||
let url = format!("{}/api/internal/clip/encode_text", base);
|
||||
let body = serde_json::json!({ "text": text });
|
||||
|
||||
let resp = match self.client.post(&url).json(&body).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) if e.is_timeout() || e.is_connect() => {
|
||||
return Err(ClipError::Transient(anyhow::anyhow!(
|
||||
"clip client network: {e}"
|
||||
)));
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(ClipError::Transient(anyhow::anyhow!(
|
||||
"clip client request: {e}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
let status = resp.status();
|
||||
if status.is_success() {
|
||||
let body: EncodeResponse = resp.json().await.map_err(|e| {
|
||||
ClipError::Transient(anyhow::anyhow!("clip response decode: {e}"))
|
||||
})?;
|
||||
return Ok(body);
|
||||
}
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
Err(classify_error_response(status.as_u16(), &body_text))
|
||||
}
|
||||
|
||||
/// Engine reachability + device/model report. Used as a startup
|
||||
/// sanity check from the probe binary and (later) the backlog drain.
|
||||
#[allow(dead_code)] // consumed by probe + drain
|
||||
pub async fn health(&self) -> Result<ClipHealth> {
|
||||
let base = self.base_url.as_deref().context("clip client disabled")?;
|
||||
let url = format!("{}/api/internal/clip/health", base);
|
||||
let resp = self.client.get(&url).send().await?.error_for_status()?;
|
||||
let body: ClipHealth = resp.json().await?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
async fn send_multipart(
|
||||
&self,
|
||||
url: &str,
|
||||
form: reqwest::multipart::Form,
|
||||
) -> std::result::Result<EncodeResponse, ClipError> {
|
||||
let resp = match self.client.post(url).multipart(form).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) if e.is_timeout() || e.is_connect() => {
|
||||
return Err(ClipError::Transient(anyhow::anyhow!(
|
||||
"clip client network: {e}"
|
||||
)));
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(ClipError::Transient(anyhow::anyhow!(
|
||||
"clip client request: {e}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
let status = resp.status();
|
||||
if status.is_success() {
|
||||
let body: EncodeResponse = resp.json().await.map_err(|e| {
|
||||
ClipError::Transient(anyhow::anyhow!("clip response decode: {e}"))
|
||||
})?;
|
||||
return Ok(body);
|
||||
}
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
Err(classify_error_response(status.as_u16(), &body_text))
|
||||
}
|
||||
}
|
||||
|
||||
/// Pulled out as a pure function so the marker-row contract is unit-
|
||||
/// testable without spinning up an HTTP server. Matches the shape used
|
||||
/// by face_client::classify_error_response so future retry policies
|
||||
/// can share code.
|
||||
fn classify_error_response(status: u16, body_text: &str) -> ClipError {
|
||||
let detail_code = serde_json::from_str::<serde_json::Value>(body_text)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("detail")
|
||||
.and_then(|d| d.as_str().map(str::to_string))
|
||||
.or_else(|| {
|
||||
v.get("detail")
|
||||
.and_then(|d| d.get("code"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if status == 422 {
|
||||
return ClipError::Permanent(anyhow::anyhow!(
|
||||
"clip {} {}: {}",
|
||||
status,
|
||||
detail_code,
|
||||
body_text
|
||||
));
|
||||
}
|
||||
if status == 503 {
|
||||
return ClipError::Transient(anyhow::anyhow!(
|
||||
"clip {} {}: {}",
|
||||
status,
|
||||
detail_code,
|
||||
body_text
|
||||
));
|
||||
}
|
||||
// 408 / 413 / 429 are operator-fixable infra issues; defer.
|
||||
if matches!(status, 408 | 413 | 429) {
|
||||
return ClipError::Transient(anyhow::anyhow!(
|
||||
"clip {} {}: {}",
|
||||
status,
|
||||
detail_code,
|
||||
body_text
|
||||
));
|
||||
}
|
||||
if (400..500).contains(&status) {
|
||||
ClipError::Permanent(anyhow::anyhow!(
|
||||
"clip {} {}: {}",
|
||||
status,
|
||||
detail_code,
|
||||
body_text
|
||||
))
|
||||
} else {
|
||||
ClipError::Transient(anyhow::anyhow!(
|
||||
"clip {} {}: {}",
|
||||
status,
|
||||
detail_code,
|
||||
body_text
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn is_permanent(e: &ClipError) -> bool {
|
||||
matches!(e, ClipError::Permanent(_))
|
||||
}
|
||||
fn is_transient(e: &ClipError) -> bool {
|
||||
matches!(e, ClipError::Transient(_))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_422_decode_failed_is_permanent() {
|
||||
assert!(is_permanent(&classify_error_response(
|
||||
422,
|
||||
r#"{"detail":"decode_failed: bad bytes"}"#
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_422_empty_text_is_permanent() {
|
||||
assert!(is_permanent(&classify_error_response(
|
||||
422,
|
||||
r#"{"detail":"empty_text"}"#
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_503_cuda_oom_is_transient() {
|
||||
assert!(is_transient(&classify_error_response(
|
||||
503,
|
||||
r#"{"detail":{"code":"cuda_oom","error":"out of memory"}}"#,
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_5xx_is_transient_other_4xx_is_permanent() {
|
||||
assert!(is_transient(&classify_error_response(500, "")));
|
||||
assert!(is_permanent(&classify_error_response(404, "{}")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_infra_4xx_is_transient() {
|
||||
assert!(is_transient(&classify_error_response(408, "")));
|
||||
assert!(is_transient(&classify_error_response(413, "<html>")));
|
||||
assert!(is_transient(&classify_error_response(429, "{}")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_embedding_size_mismatch_errors() {
|
||||
// dim=4 says we expect 16 bytes (4 floats × 4 bytes). Encode 8.
|
||||
use base64::Engine;
|
||||
let resp = EncodeResponse {
|
||||
model_version: "ViT-L/14".into(),
|
||||
embedding_dim: 4,
|
||||
duration_ms: 0,
|
||||
embedding: base64::engine::general_purpose::STANDARD.encode([0u8; 8]),
|
||||
};
|
||||
assert!(resp.decode_embedding().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_embedding_round_trip() {
|
||||
use base64::Engine;
|
||||
let bytes: Vec<u8> = (0..16).collect();
|
||||
let resp = EncodeResponse {
|
||||
model_version: "ViT-L/14".into(),
|
||||
embedding_dim: 4,
|
||||
duration_ms: 0,
|
||||
embedding: base64::engine::general_purpose::STANDARD.encode(&bytes),
|
||||
};
|
||||
assert_eq!(resp.decode_embedding().unwrap(), bytes);
|
||||
}
|
||||
}
|
||||
@@ -2184,6 +2184,8 @@ mod tests {
|
||||
date_taken_source: None,
|
||||
original_date_taken: None,
|
||||
original_date_taken_source: None,
|
||||
clip_embedding: None,
|
||||
clip_model_version: None,
|
||||
});
|
||||
let out = resolve_date_taken_for_context(&exif, "Screenshot_2014-06-01.png");
|
||||
assert_eq!(out.as_deref(), Some("2021-08-15"));
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod apollo_client;
|
||||
pub mod clip_client;
|
||||
pub mod daily_summary_job;
|
||||
pub mod face_client;
|
||||
pub mod handlers;
|
||||
|
||||
Reference in New Issue
Block a user