Files
ImageApi/src/ai/clip_client.rs
T
Cameron Cordes 66267cc345 clip-search: fmt + clippy clamp + test AppState arg
Pulls cargo fmt + clippy pass over the new files only — pre-existing
files left untouched even though fmt has drift on them. clamp(1,200)
swaps a manual min/max chain that clippy flagged. test AppState
constructor needed ClipClient::new(None) so the lib-test target
compiles.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 16:10:52 -04:00

393 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Thin async HTTP client for Apollo's `/api/internal/clip/*` endpoints.
//!
//! Apollo hosts the OpenAI CLIP inference service (ViT-L/14 by default,
//! configurable via `APOLLO_CLIP_MODEL`). This client is the ImageApi side
//! of the contract: shove image bytes through `/encode_image` to populate
//! `image_exif.clip_embedding` during backfill, and call `/encode_text` to
//! encode a user's natural-language query at search time. The actual
//! cosine-similarity rerank runs locally in ImageApi.
//!
//! Mirrors `face_client.rs` / `tag_client.rs` shape: optional base URL
//! (None = disabled — feature off, drain and search no-op), reqwest
//! client with a generous timeout because GPU inference under a backlog
//! can queue server-side (Apollo's threadpool is bounded to 1 worker on
//! CUDA).
//!
//! Configured via `APOLLO_CLIP_API_BASE_URL`, falling back to
//! `APOLLO_API_BASE_URL` when the dedicated var is unset (single-Apollo
//! deploys are the common case).
//!
//! Wire format:
//! - `/encode_image`: multipart/form-data with `file=<bytes>` and
//! `meta=<json>` (content_hash / library_id / rel_path for logging).
//! - `/encode_text`: JSON `{"text": "<query>"}`.
//!
//! Both return `{model_version, embedding_dim, duration_ms, embedding}`
//! where `embedding` is base64 of `dim×4` little-endian float32 bytes,
//! L2-normalized so the rerank reduces to a plain dot product.
//!
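//! Error mapping (reflected in [`ClipError`]):
//! - 422 `decode_failed` / `empty_text` → permanent: ImageApi marks the
//! row failed or surfaces the empty-query error to the search caller.
//! - 503 `cuda_oom` / `engine_unavailable` → defer-and-retry: no marker.
//! - 408 / 413 / 429 → defer: operator-fixable infra issues (timeouts,
//! body-size limits, rate limiting) rather than a problem with the input.
//! - Any other 4xx → permanent.
//! - Any other 5xx / network error → defer.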
use anyhow::{Context, Result};
use base64::Engine;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Metadata attached to an `/encode_image` request for server-side logging.
///
/// Serialized to JSON and sent as the `meta` multipart field. serde emits
/// the keys in declaration order, so keep the field order stable.
#[derive(Clone, Debug, Serialize)]
pub struct EncodeImageMeta {
    /// Content hash identifying the image being encoded.
    pub content_hash: String,
    /// Library the image belongs to.
    pub library_id: i32,
    /// Library-relative path of the image; also used as the multipart
    /// file name of the uploaded bytes.
    pub rel_path: String,
}
/// Success body shared by `/encode_image` and `/encode_text`.
#[derive(Clone, Debug, Deserialize)]
#[allow(dead_code)] // duration_ms logged by the backfill drain
pub struct EncodeResponse {
    /// Identifier of the model that produced the embedding.
    pub model_version: String,
    /// Number of float32 components carried in `embedding`.
    pub embedding_dim: i32,
    /// Server-reported inference duration (milliseconds, per the field name).
    pub duration_ms: i64,
    /// base64 of `embedding_dim * 4` bytes (LE float32). ImageApi stores
    /// the decoded bytes verbatim as a BLOB.
    pub embedding: String,
}
impl EncodeResponse {
    /// Decode the wire-format embedding back into raw bytes for storage.
    ///
    /// Validates that the buffer is exactly `embedding_dim * 4` bytes long,
    /// so a malformed response surfaces here rather than as a downstream
    /// silent length mismatch.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `embedding_dim` is zero or negative, or `embedding_dim * 4`
    ///   overflows `usize`;
    /// - `embedding` is not valid standard base64;
    /// - the decoded byte count differs from `embedding_dim * 4`.
    pub fn decode_embedding(&self) -> Result<Vec<u8>> {
        // `embedding_dim` comes straight off the wire as an i32. A negative
        // value would wrap under `as usize` and then overflow the `* 4`
        // (a panic in debug builds), so convert and multiply with checks.
        // A zero dim would "validate" an empty embedding, which is equally
        // useless downstream, so reject it too.
        let expected = usize::try_from(self.embedding_dim)
            .ok()
            .filter(|&dim| dim > 0)
            .and_then(|dim| dim.checked_mul(4))
            .with_context(|| format!("clip embedding_dim invalid: {}", self.embedding_dim))?;
        let bytes = base64::engine::general_purpose::STANDARD
            .decode(self.embedding.as_bytes())
            .context("clip embedding base64 decode")?;
        if bytes.len() != expected {
            anyhow::bail!(
                "clip embedding wrong size: got {} bytes, expected {} ({} * 4)",
                bytes.len(),
                expected,
                self.embedding_dim
            );
        }
        Ok(bytes)
    }
}
/// Body returned by `GET /api/internal/clip/health`.
#[derive(Clone, Debug, Deserialize)]
#[allow(dead_code)] // load_error consumed by future health probe
pub struct ClipHealth {
    /// Whether Apollo reports the CLIP model as loaded.
    pub loaded: bool,
    /// Device the model runs on, as reported by Apollo (exact format not
    /// checked here).
    pub device: String,
    /// Identifier of the loaded model.
    pub model_version: String,
    /// Embedding dimensionality of the loaded model.
    pub embedding_dim: i32,
    /// Load failure message, if Apollo reports one. Missing field → `None`.
    #[serde(default)]
    pub load_error: Option<String>,
}
/// Failure modes of a CLIP encode call, split by what the caller should
/// do next: mark-and-move-on, retry later, or skip the feature entirely.
#[derive(Debug)]
pub enum ClipError {
    /// Apollo refused for a reason that won't change on retry (decode
    /// failure on /encode_image, empty text on /encode_text).
    Permanent(anyhow::Error),
    /// Apollo couldn't process this turn but might next time (CUDA OOM,
    /// engine not loaded, network hiccup).
    Transient(anyhow::Error),
    /// Feature is disabled (no `APOLLO_CLIP_API_BASE_URL` /
    /// `APOLLO_API_BASE_URL`).
    Disabled,
}
impl std::fmt::Display for ClipError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Prefix the wrapped error's message with its retry class so log
        // lines are self-describing.
        match self {
            Self::Permanent(inner) => write!(f, "permanent: {inner}"),
            Self::Transient(inner) => write!(f, "transient: {inner}"),
            Self::Disabled => f.write_str("clip client disabled"),
        }
    }
}
// `Display` already embeds the wrapped error's message, so `source()` is
// left at its default (`None`) rather than reporting it a second time.
impl std::error::Error for ClipError {}
/// Async client for Apollo's `/api/internal/clip/*` endpoints.
///
/// Clone freely: reqwest's `Client` shares its connection pool internally.
#[derive(Clone)]
pub struct ClipClient {
    /// HTTP client configured with the per-request timeout.
    client: Client,
    /// Apollo base URL without a trailing `/`; `None` disables the feature.
    base_url: Option<String>,
}
impl ClipClient {
pub fn new(base_url: Option<String>) -> Self {
let timeout_secs = std::env::var("CLIP_REQUEST_TIMEOUT_SEC")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(60);
let client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.expect("reqwest client build");
Self {
client,
base_url: base_url.map(|u| u.trim_end_matches('/').to_string()),
}
}
/// Read both standard env vars. `APOLLO_CLIP_API_BASE_URL` wins;
/// fallback to `APOLLO_API_BASE_URL`. Both unset → disabled.
pub fn from_env() -> Self {
let base = std::env::var("APOLLO_CLIP_API_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
.or_else(|| {
std::env::var("APOLLO_API_BASE_URL")
.ok()
.filter(|s| !s.trim().is_empty())
});
Self::new(base)
}
pub fn is_enabled(&self) -> bool {
self.base_url.is_some()
}
/// Encode an image to a 768-d (ViT-L/14) or 512-d (ViT-B/32)
/// L2-normalized embedding. Used by the backfill drain.
pub async fn encode_image(
&self,
bytes: Vec<u8>,
meta: EncodeImageMeta,
) -> std::result::Result<EncodeResponse, ClipError> {
let Some(base) = self.base_url.as_deref() else {
return Err(ClipError::Disabled);
};
let url = format!("{}/api/internal/clip/encode_image", base);
let meta_json = serde_json::to_string(&meta)
.map_err(|e| ClipError::Permanent(anyhow::anyhow!("meta serialize: {e}")))?;
let form = reqwest::multipart::Form::new()
.text("meta", meta_json)
.part(
"file",
reqwest::multipart::Part::bytes(bytes)
.file_name(meta.rel_path.clone())
.mime_str("application/octet-stream")
.unwrap_or_else(|_| reqwest::multipart::Part::bytes(Vec::new())),
);
self.send_multipart(&url, form).await
}
/// Encode a natural-language query to an embedding. Used by the
/// search route to rank stored image embeddings by cosine sim.
pub async fn encode_text(&self, text: &str) -> std::result::Result<EncodeResponse, ClipError> {
let Some(base) = self.base_url.as_deref() else {
return Err(ClipError::Disabled);
};
let url = format!("{}/api/internal/clip/encode_text", base);
let body = serde_json::json!({ "text": text });
let resp = match self.client.post(&url).json(&body).send().await {
Ok(r) => r,
Err(e) if e.is_timeout() || e.is_connect() => {
return Err(ClipError::Transient(anyhow::anyhow!(
"clip client network: {e}"
)));
}
Err(e) => {
return Err(ClipError::Transient(anyhow::anyhow!(
"clip client request: {e}"
)));
}
};
let status = resp.status();
if status.is_success() {
let body: EncodeResponse = resp
.json()
.await
.map_err(|e| ClipError::Transient(anyhow::anyhow!("clip response decode: {e}")))?;
return Ok(body);
}
let body_text = resp.text().await.unwrap_or_default();
Err(classify_error_response(status.as_u16(), &body_text))
}
/// Engine reachability + device/model report. Used as a startup
/// sanity check from the probe binary and (later) the backlog drain.
#[allow(dead_code)] // consumed by probe + drain
pub async fn health(&self) -> Result<ClipHealth> {
let base = self.base_url.as_deref().context("clip client disabled")?;
let url = format!("{}/api/internal/clip/health", base);
let resp = self.client.get(&url).send().await?.error_for_status()?;
let body: ClipHealth = resp.json().await?;
Ok(body)
}
async fn send_multipart(
&self,
url: &str,
form: reqwest::multipart::Form,
) -> std::result::Result<EncodeResponse, ClipError> {
let resp = match self.client.post(url).multipart(form).send().await {
Ok(r) => r,
Err(e) if e.is_timeout() || e.is_connect() => {
return Err(ClipError::Transient(anyhow::anyhow!(
"clip client network: {e}"
)));
}
Err(e) => {
return Err(ClipError::Transient(anyhow::anyhow!(
"clip client request: {e}"
)));
}
};
let status = resp.status();
if status.is_success() {
let body: EncodeResponse = resp
.json()
.await
.map_err(|e| ClipError::Transient(anyhow::anyhow!("clip response decode: {e}")))?;
return Ok(body);
}
let body_text = resp.text().await.unwrap_or_default();
Err(classify_error_response(status.as_u16(), &body_text))
}
}
/// Map a non-2xx Apollo response onto [`ClipError`].
///
/// Pulled out as a pure function so the marker-row contract is unit-
/// testable without spinning up an HTTP server. Matches the shape used
/// by face_client::classify_error_response so future retry policies
/// can share code.
///
/// Classification:
/// - 408 / 413 / 429 → [`ClipError::Transient`]: operator-fixable infra
///   issues, so defer.
/// - every other 4xx (including 422 `decode_failed` / `empty_text`) →
///   [`ClipError::Permanent`].
/// - everything else (503 `cuda_oom` / `engine_unavailable`, other 5xx,
///   any unexpected non-success code) → [`ClipError::Transient`].
///
/// The message always has the form `clip <status> <detail>: <body_text>`.
/// `<detail>` is the JSON `detail` field when it is a string, else
/// `detail.code` when that is a string, else empty. This covers non-JSON
/// bodies too.
fn classify_error_response(status: u16, body_text: &str) -> ClipError {
    let detail_code = serde_json::from_str::<serde_json::Value>(body_text)
        .ok()
        .and_then(|v| {
            let detail = v.get("detail")?;
            detail
                .as_str()
                .or_else(|| detail.get("code").and_then(|c| c.as_str()))
                .map(str::to_string)
        })
        .unwrap_or_default();
    let err = anyhow::anyhow!("clip {} {}: {}", status, detail_code, body_text);
    // 408 / 413 / 429 are operator-fixable infra issues; defer.
    let retryable_client_error = matches!(status, 408 | 413 | 429);
    if (400..500).contains(&status) && !retryable_client_error {
        ClipError::Permanent(err)
    } else {
        ClipError::Transient(err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn is_permanent(e: &ClipError) -> bool {
        matches!(e, ClipError::Permanent(_))
    }

    fn is_transient(e: &ClipError) -> bool {
        matches!(e, ClipError::Transient(_))
    }

    /// Build a response claiming `dim` floats whose payload is `raw`,
    /// base64-encoded the same way Apollo does.
    fn response_with(dim: i32, raw: &[u8]) -> EncodeResponse {
        use base64::Engine;
        EncodeResponse {
            model_version: "ViT-L/14".into(),
            embedding_dim: dim,
            duration_ms: 0,
            embedding: base64::engine::general_purpose::STANDARD.encode(raw),
        }
    }

    #[test]
    fn classify_422_decode_failed_is_permanent() {
        let err = classify_error_response(422, r#"{"detail":"decode_failed: bad bytes"}"#);
        assert!(is_permanent(&err));
    }

    #[test]
    fn classify_422_empty_text_is_permanent() {
        let err = classify_error_response(422, r#"{"detail":"empty_text"}"#);
        assert!(is_permanent(&err));
    }

    #[test]
    fn classify_503_cuda_oom_is_transient() {
        let err = classify_error_response(
            503,
            r#"{"detail":{"code":"cuda_oom","error":"out of memory"}}"#,
        );
        assert!(is_transient(&err));
    }

    #[test]
    fn classify_5xx_is_transient_other_4xx_is_permanent() {
        let server_side = classify_error_response(500, "");
        let not_found = classify_error_response(404, "{}");
        assert!(is_transient(&server_side));
        assert!(is_permanent(&not_found));
    }

    #[test]
    fn classify_infra_4xx_is_transient() {
        for (status, body) in [(408, ""), (413, "<html>"), (429, "{}")] {
            assert!(
                is_transient(&classify_error_response(status, body)),
                "status {status} should be transient"
            );
        }
    }

    #[test]
    fn decode_embedding_size_mismatch_errors() {
        // dim=4 says we expect 16 bytes (4 floats × 4 bytes). Encode 8.
        assert!(response_with(4, &[0u8; 8]).decode_embedding().is_err());
    }

    #[test]
    fn decode_embedding_round_trip() {
        let raw: Vec<u8> = (0..16).collect();
        let decoded = response_with(4, &raw).decode_embedding().unwrap();
        assert_eq!(decoded, raw);
    }
}