Feature/unified nl search #106
@@ -21,12 +21,6 @@
|
||||
//! `geo::forward_geocode`), and person filtering is deferred until a
|
||||
//! person→photos resolver exists.
|
||||
|
||||
// Phase 1: this module is fully implemented and unit-tested, but its first
|
||||
// consumer (the `/photos/search/unified` endpoint) lands in Phase 2. Mirrors
|
||||
// llm_client.rs's allow-until-wired pattern so the bin target stays
|
||||
// clippy-clean in the interim; remove when the endpoint is added.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use crate::ai::llm_client::{ChatMessage, LlmClient, Tool, strip_think_blocks};
|
||||
use anyhow::{Result, anyhow};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
+214
-184
@@ -124,65 +124,161 @@ fn dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
|
||||
}
|
||||
|
||||
pub async fn search_photos(
|
||||
state: web::Data<AppState>,
|
||||
exif_dao: web::Data<Mutex<Box<dyn ExifDao>>>,
|
||||
query: web::Query<SearchQuery>,
|
||||
) -> ActixResult<HttpResponse> {
|
||||
let q_text = query.q.trim().to_string();
|
||||
if q_text.is_empty() {
|
||||
return Ok(HttpResponse::BadRequest().json(SearchError {
|
||||
error: "query parameter `q` is required".into(),
|
||||
}));
|
||||
}
|
||||
/// Failure modes of [`score_photos`]. Carries enough to let each caller pick
|
||||
/// an appropriate HTTP status (the CLIP service being down is a 502, a
|
||||
/// disabled feature is a 503, a rejected query is a 400, a DB failure 500).
|
||||
pub enum ScoreError {
|
||||
/// CLIP search isn't configured at all (no Apollo endpoint).
|
||||
Disabled,
|
||||
/// The query was rejected by the encoder (client error).
|
||||
Rejected(String),
|
||||
/// The CLIP service is transiently unavailable (upstream error).
|
||||
Unavailable(String),
|
||||
/// The encoder returned an embedding we couldn't decode.
|
||||
MalformedEmbedding,
|
||||
/// A database / index load failure.
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
/// Result of scoring the whole library against a query embedding: the
|
||||
/// resolved model version, how many embeddings were considered, and every
|
||||
/// `(score, content_hash)` above threshold, sorted by descending score.
|
||||
/// Pagination and path resolution are the caller's job (see [`resolve_hits`])
|
||||
/// so this core can be reused for both the plain search endpoint and the
|
||||
/// unified endpoint (which filters by hash before paginating).
|
||||
pub struct ScoredPhotos {
|
||||
pub model_version: String,
|
||||
pub considered: usize,
|
||||
/// `(cosine_score, content_hash)` pairs, descending by score.
|
||||
pub hits: Vec<(f32, String)>,
|
||||
}
|
||||
|
||||
/// Encode `q_text` via CLIP and score it against every stored embedding in
|
||||
/// the given library scope. Returns all matches above `threshold`, sorted by
|
||||
/// descending similarity. Pure of HTTP concerns so it's shared by
|
||||
/// `search_photos` and the unified search endpoint.
|
||||
pub async fn score_photos(
|
||||
state: &AppState,
|
||||
exif_dao: &Mutex<Box<dyn ExifDao>>,
|
||||
q_text: &str,
|
||||
library_ids: &[i32],
|
||||
threshold: f32,
|
||||
model_version: Option<&str>,
|
||||
) -> Result<ScoredPhotos, ScoreError> {
|
||||
if !state.clip_client.is_enabled() {
|
||||
return Ok(HttpResponse::ServiceUnavailable().json(SearchError {
|
||||
error: "CLIP search is disabled (no Apollo CLIP endpoint configured)".into(),
|
||||
}));
|
||||
return Err(ScoreError::Disabled);
|
||||
}
|
||||
|
||||
let limit = query.limit.clamp(1, 200);
|
||||
let offset = query.offset;
|
||||
let threshold = query.threshold.clamp(-1.0, 1.0);
|
||||
|
||||
// 1. Encode the query text. Fast — Apollo's text encoder is ~50ms
|
||||
// on CPU. Bail with a clear error message if Apollo's down so the
|
||||
// user sees "service unavailable" rather than empty results.
|
||||
let query_resp = match state.clip_client.encode_text(&q_text).await {
|
||||
// 1. Encode the query text. Fast — Apollo's text encoder is ~50ms on CPU.
|
||||
let query_resp = match state.clip_client.encode_text(q_text).await {
|
||||
Ok(r) => r,
|
||||
Err(ClipError::Permanent(e)) => {
|
||||
return Ok(HttpResponse::BadRequest().json(SearchError {
|
||||
error: format!("query rejected: {e}"),
|
||||
}));
|
||||
}
|
||||
Err(ClipError::Transient(e)) => {
|
||||
return Ok(HttpResponse::BadGateway().json(SearchError {
|
||||
error: format!("CLIP service unavailable: {e}"),
|
||||
}));
|
||||
}
|
||||
Err(ClipError::Disabled) => {
|
||||
return Ok(HttpResponse::ServiceUnavailable().json(SearchError {
|
||||
error: "CLIP service disabled".into(),
|
||||
}));
|
||||
}
|
||||
Err(ClipError::Permanent(e)) => return Err(ScoreError::Rejected(e.to_string())),
|
||||
Err(ClipError::Transient(e)) => return Err(ScoreError::Unavailable(e.to_string())),
|
||||
Err(ClipError::Disabled) => return Err(ScoreError::Disabled),
|
||||
};
|
||||
// decode_embedding works on raw bytes; the wire format is b64.
|
||||
let query_bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(query_resp.embedding.as_bytes())
|
||||
.unwrap_or_default();
|
||||
let query_vec = match decode_embedding(&query_bytes) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Ok(HttpResponse::BadGateway().json(SearchError {
|
||||
error: "CLIP service returned a malformed query embedding".into(),
|
||||
}));
|
||||
}
|
||||
};
|
||||
let query_vec = decode_embedding(&query_bytes).ok_or(ScoreError::MalformedEmbedding)?;
|
||||
|
||||
// 2. Decide which library scope to search. `library_ids` (multi)
|
||||
// wins over the legacy `library` (single) when both are present;
|
||||
// either / both empty falls back to "every enabled library".
|
||||
let library_ids: Vec<i32> = if let Some(raw) = query.library_ids.as_deref() {
|
||||
// 2. Pull the (hash, embedding) matrix under the dao lock, release
|
||||
// before scoring. The caller-supplied `model_version` (or the live
|
||||
// engine's) forces a strict join so a mid-flight model swap can't mix
|
||||
// geometries.
|
||||
let ctx = opentelemetry::Context::current();
|
||||
let rows: Vec<(String, Vec<u8>)> = {
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
dao.list_clip_index(
|
||||
&ctx,
|
||||
library_ids,
|
||||
model_version.or(Some(&query_resp.model_version)),
|
||||
)
|
||||
.map_err(|e| {
|
||||
log::warn!("clip_search: list_clip_index failed: {:?}", e);
|
||||
ScoreError::Internal("failed to load search index".into())
|
||||
})?
|
||||
};
|
||||
let considered = rows.len();
|
||||
|
||||
// 3. Score. Keep all matches and sort at the end (~microseconds at 14k).
|
||||
let mut hits: Vec<(f32, String)> = Vec::with_capacity(considered);
|
||||
for (hash, blob) in rows {
|
||||
let Some(emb) = decode_embedding(&blob) else {
|
||||
continue;
|
||||
};
|
||||
if emb.len() != query_vec.len() {
|
||||
continue;
|
||||
}
|
||||
let sim = dot(&emb, &query_vec);
|
||||
if sim < threshold {
|
||||
continue;
|
||||
}
|
||||
hits.push((sim, hash));
|
||||
}
|
||||
hits.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
Ok(ScoredPhotos {
|
||||
model_version: query_resp.model_version,
|
||||
considered,
|
||||
hits,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve a page of `(score, content_hash)` pairs back to [`SearchHit`]s
|
||||
/// (each carrying `library_id` + `rel_path`). Hashes that no longer resolve
|
||||
/// to a row are skipped. Shared by both endpoints.
|
||||
pub fn resolve_hits(
|
||||
exif_dao: &Mutex<Box<dyn ExifDao>>,
|
||||
scored: &[(f32, String)],
|
||||
) -> Vec<SearchHit> {
|
||||
if scored.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let ctx = opentelemetry::Context::current();
|
||||
let hashes: Vec<String> = scored.iter().map(|(_, h)| h.clone()).collect();
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
let path_map = dao
|
||||
.get_rel_paths_for_hashes(&ctx, &hashes)
|
||||
.unwrap_or_else(|e| {
|
||||
log::warn!("clip_search: get_rel_paths_for_hashes failed: {:?}", e);
|
||||
std::collections::HashMap::new()
|
||||
});
|
||||
|
||||
let mut results = Vec::with_capacity(scored.len());
|
||||
for (score, hash) in scored {
|
||||
let row = match dao.find_by_content_hash(&ctx, hash) {
|
||||
Ok(Some(r)) => r,
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
log::warn!("clip_search: find_by_content_hash failed for {hash}: {e:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Prefer get_rel_paths_for_hashes's first entry (shares image_exif's
|
||||
// natural order), falling back to the ImageExif row.
|
||||
let rel_path = path_map
|
||||
.get(hash)
|
||||
.and_then(|paths| paths.first().cloned())
|
||||
.unwrap_or(row.file_path);
|
||||
results.push(SearchHit {
|
||||
library_id: row.library_id,
|
||||
rel_path,
|
||||
content_hash: hash.clone(),
|
||||
score: *score,
|
||||
});
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
/// Parse the `library_ids` (multi) / `library` (single) scope params into a
|
||||
/// deduped id list. Empty = "every enabled library". Shared so the unified
|
||||
/// endpoint scopes CLIP identically.
|
||||
pub fn parse_library_scope(
|
||||
library_ids: Option<&str>,
|
||||
library: Option<i32>,
|
||||
) -> Result<Vec<i32>, String> {
|
||||
if let Some(raw) = library_ids {
|
||||
let mut out: Vec<i32> = Vec::new();
|
||||
for piece in raw.split(',') {
|
||||
let trimmed = piece.trim();
|
||||
@@ -195,158 +291,92 @@ pub async fn search_photos(
|
||||
out.push(id);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
return Ok(HttpResponse::BadRequest().json(SearchError {
|
||||
error: format!("invalid library_ids entry: {trimmed:?}"),
|
||||
}));
|
||||
}
|
||||
Err(_) => return Err(format!("invalid library_ids entry: {trimmed:?}")),
|
||||
}
|
||||
}
|
||||
out
|
||||
} else if let Some(id) = query.library {
|
||||
vec![id]
|
||||
Ok(out)
|
||||
} else if let Some(id) = library {
|
||||
Ok(vec![id])
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Pull the (hash, embedding) matrix. Lock contention here is
|
||||
// bounded — one big SELECT under a mutex Arc<Mutex<dyn ExifDao>>
|
||||
// and then we release before scoring. If this becomes a hotspot
|
||||
// we'll cache the decoded matrix in AppState with TTL.
|
||||
let ctx = opentelemetry::Context::current();
|
||||
let rows: Vec<(String, Vec<u8>)> = {
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
match dao.list_clip_index(
|
||||
&ctx,
|
||||
&library_ids,
|
||||
query
|
||||
.model_version
|
||||
.as_deref()
|
||||
.or(Some(&query_resp.model_version)),
|
||||
) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
log::warn!("clip_search: list_clip_index failed: {:?}", e);
|
||||
return Ok(HttpResponse::InternalServerError().json(SearchError {
|
||||
error: "failed to load search index".into(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
};
|
||||
let considered = rows.len();
|
||||
if considered == 0 {
|
||||
return Ok(HttpResponse::Ok().json(SearchResponse {
|
||||
query: q_text,
|
||||
model_version: query_resp.model_version,
|
||||
threshold,
|
||||
considered,
|
||||
total_matching: 0,
|
||||
offset,
|
||||
results: Vec::new(),
|
||||
pub async fn search_photos(
|
||||
state: web::Data<AppState>,
|
||||
exif_dao: web::Data<Mutex<Box<dyn ExifDao>>>,
|
||||
query: web::Query<SearchQuery>,
|
||||
) -> ActixResult<HttpResponse> {
|
||||
let q_text = query.q.trim().to_string();
|
||||
if q_text.is_empty() {
|
||||
return Ok(HttpResponse::BadRequest().json(SearchError {
|
||||
error: "query parameter `q` is required".into(),
|
||||
}));
|
||||
}
|
||||
|
||||
// 4. Score. Cap the loop's transient allocation; we keep all scores
|
||||
// and sort at the end. With ~14k entries the sort is microseconds.
|
||||
let mut scored: Vec<(f32, String)> = Vec::with_capacity(considered);
|
||||
for (hash, blob) in rows {
|
||||
let Some(emb) = decode_embedding(&blob) else {
|
||||
continue;
|
||||
};
|
||||
if emb.len() != query_vec.len() {
|
||||
continue;
|
||||
}
|
||||
let sim = dot(&emb, &query_vec);
|
||||
if sim < threshold {
|
||||
continue;
|
||||
}
|
||||
scored.push((sim, hash));
|
||||
}
|
||||
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let total_matching = scored.len();
|
||||
// Pagination — slice the sorted list at `[offset, offset+limit)`.
|
||||
// Offsets past the end produce empty pages rather than an error so
|
||||
// the client can stop fetching naturally on "load more" past the end.
|
||||
let scored: Vec<(f32, String)> = if offset >= total_matching {
|
||||
let limit = query.limit.clamp(1, 200);
|
||||
let offset = query.offset;
|
||||
let threshold = query.threshold.clamp(-1.0, 1.0);
|
||||
|
||||
let library_ids = match parse_library_scope(query.library_ids.as_deref(), query.library) {
|
||||
Ok(ids) => ids,
|
||||
Err(msg) => return Ok(HttpResponse::BadRequest().json(SearchError { error: msg })),
|
||||
};
|
||||
|
||||
let scored = match score_photos(
|
||||
&state,
|
||||
&exif_dao,
|
||||
&q_text,
|
||||
&library_ids,
|
||||
threshold,
|
||||
query.model_version.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(e) => return Ok(score_error_response(e)),
|
||||
};
|
||||
|
||||
let total_matching = scored.hits.len();
|
||||
// Pagination — slice the sorted list at `[offset, offset+limit)`. Offsets
|
||||
// past the end produce empty pages so "load more" stops naturally.
|
||||
let page: Vec<(f32, String)> = if offset >= total_matching {
|
||||
Vec::new()
|
||||
} else {
|
||||
let end = (offset + limit).min(total_matching);
|
||||
scored[offset..end].to_vec()
|
||||
scored.hits[offset..end].to_vec()
|
||||
};
|
||||
|
||||
if scored.is_empty() {
|
||||
return Ok(HttpResponse::Ok().json(SearchResponse {
|
||||
query: q_text,
|
||||
model_version: query_resp.model_version,
|
||||
threshold,
|
||||
considered,
|
||||
total_matching,
|
||||
offset,
|
||||
results: Vec::new(),
|
||||
}));
|
||||
}
|
||||
|
||||
// 5. Resolve each surviving hash back to a `(library_id, rel_path)`.
|
||||
// `get_rel_paths_by_hash` returns every rel_path; we pick the first
|
||||
// one for the result. Apollo / the UI can fetch alternatives via
|
||||
// /image/metadata when needed.
|
||||
let hashes: Vec<String> = scored.iter().map(|(_, h)| h.clone()).collect();
|
||||
let path_map = {
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
match dao.get_rel_paths_for_hashes(&ctx, &hashes) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
log::warn!("clip_search: get_rel_paths_for_hashes failed: {:?}", e);
|
||||
return Ok(HttpResponse::InternalServerError().json(SearchError {
|
||||
error: "failed to resolve photo paths".into(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// We need (library_id, rel_path) — get_rel_paths_for_hashes only
|
||||
// returns rel_paths. Cross-reference via find_by_content_hash to
|
||||
// pick the library too. Single call per surviving hash; cheap at
|
||||
// top-20.
|
||||
let mut results = Vec::with_capacity(scored.len());
|
||||
{
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
for (score, hash) in scored {
|
||||
let row = match dao.find_by_content_hash(&ctx, &hash) {
|
||||
Ok(Some(r)) => r,
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"clip_search: find_by_content_hash failed for {}: {:?}",
|
||||
hash,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Prefer get_rel_paths_for_hashes's first entry if it
|
||||
// exists (it shares semantics with `image_exif`'s natural
|
||||
// order), falling back to the ImageExif row.
|
||||
let rel_path = path_map
|
||||
.get(&hash)
|
||||
.and_then(|paths| paths.first().cloned())
|
||||
.unwrap_or(row.file_path);
|
||||
results.push(SearchHit {
|
||||
library_id: row.library_id,
|
||||
rel_path,
|
||||
content_hash: hash,
|
||||
score,
|
||||
});
|
||||
}
|
||||
}
|
||||
let results = resolve_hits(&exif_dao, &page);
|
||||
|
||||
Ok(HttpResponse::Ok().json(SearchResponse {
|
||||
query: q_text,
|
||||
model_version: query_resp.model_version,
|
||||
model_version: scored.model_version,
|
||||
threshold,
|
||||
considered,
|
||||
considered: scored.considered,
|
||||
total_matching,
|
||||
offset,
|
||||
results,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Map a [`ScoreError`] to the HTTP response `search_photos` historically
|
||||
/// returned for each failure mode. Reused by the unified endpoint.
|
||||
pub fn score_error_response(e: ScoreError) -> HttpResponse {
|
||||
match e {
|
||||
ScoreError::Disabled => HttpResponse::ServiceUnavailable().json(SearchError {
|
||||
error: "CLIP search is disabled (no Apollo CLIP endpoint configured)".into(),
|
||||
}),
|
||||
ScoreError::Rejected(msg) => HttpResponse::BadRequest().json(SearchError {
|
||||
error: format!("query rejected: {msg}"),
|
||||
}),
|
||||
ScoreError::Unavailable(msg) => HttpResponse::BadGateway().json(SearchError {
|
||||
error: format!("CLIP service unavailable: {msg}"),
|
||||
}),
|
||||
ScoreError::MalformedEmbedding => HttpResponse::BadGateway().json(SearchError {
|
||||
error: "CLIP service returned a malformed query embedding".into(),
|
||||
}),
|
||||
ScoreError::Internal(msg) => {
|
||||
HttpResponse::InternalServerError().json(SearchError { error: msg })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,10 +69,6 @@ pub fn gps_bounding_box(lat: f64, lon: f64, radius_km: f64) -> (f64, f64, f64, f
|
||||
/// a whole country. We collapse Nominatim's bounding box into the smallest
|
||||
/// circle that circumscribes it (see [`bbox_to_circle`]) so "Portland" and
|
||||
/// "Italy" both map onto the existing circle filter without a schema change.
|
||||
// Phase 1: forward geocoding is implemented and unit-tested here, but its
|
||||
// first consumer (the `/photos/search/unified` endpoint) lands in Phase 2.
|
||||
// allow-until-wired (mirrors llm_client.rs); remove when the endpoint is added.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct GeoPlace {
|
||||
/// Nominatim's canonical name for the match (e.g. "Italia").
|
||||
@@ -90,7 +86,6 @@ pub struct GeoPlace {
|
||||
/// Floor for a geocoded place's radius. Point results (a street address)
|
||||
/// come back with a near-zero bounding box; without a floor the circle
|
||||
/// filter would match nothing.
|
||||
#[allow(dead_code)]
|
||||
pub const MIN_PLACE_RADIUS_KM: f64 = 0.5;
|
||||
|
||||
/// Collapse a bounding box into the centroid + circumscribing radius.
|
||||
@@ -105,7 +100,6 @@ pub const MIN_PLACE_RADIUS_KM: f64 = 0.5;
|
||||
///
|
||||
/// Pure and exact (no flooring) so it can be unit-tested directly; callers
|
||||
/// apply [`MIN_PLACE_RADIUS_KM`] when turning the result into a filter.
|
||||
#[allow(dead_code)]
|
||||
pub fn bbox_to_circle(south: f64, north: f64, west: f64, east: f64) -> (f64, f64, f64) {
|
||||
let center_lat = (south + north) / 2.0;
|
||||
let center_lon = (west + east) / 2.0;
|
||||
@@ -118,7 +112,6 @@ pub fn bbox_to_circle(south: f64, north: f64, west: f64, east: f64) -> (f64, f64
|
||||
|
||||
/// Raw Nominatim `/search` result. `lat`/`lon` arrive as strings and
|
||||
/// `boundingbox` as a 4-element string array `[south, north, west, east]`.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Deserialize)]
|
||||
struct NominatimSearchResult {
|
||||
lat: String,
|
||||
@@ -136,7 +129,6 @@ struct NominatimSearchResult {
|
||||
///
|
||||
/// Nominatim's usage policy requires a `User-Agent` and rate-limits to ~1
|
||||
/// request/second; callers doing this interactively should cache results.
|
||||
#[allow(dead_code)]
|
||||
pub async fn forward_geocode(query: &str) -> Option<GeoPlace> {
|
||||
let q = query.trim();
|
||||
if q.is_empty() {
|
||||
|
||||
@@ -35,6 +35,7 @@ pub mod tags;
|
||||
#[cfg(test)]
|
||||
pub mod testhelpers;
|
||||
pub mod thumbnails;
|
||||
pub mod unified_search;
|
||||
pub mod utils;
|
||||
pub mod video;
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ mod perceptual_hash;
|
||||
mod state;
|
||||
mod tags;
|
||||
mod thumbnails;
|
||||
mod unified_search;
|
||||
mod utils;
|
||||
mod video;
|
||||
mod watcher;
|
||||
@@ -333,6 +334,13 @@ fn main() -> std::io::Result<()> {
|
||||
web::resource("/photos/search")
|
||||
.route(web::get().to(clip_search::search_photos)),
|
||||
)
|
||||
.service(
|
||||
// Unified natural-language search: LLM translates the
|
||||
// query into structured filters + a semantic term, then
|
||||
// filters constrain and CLIP ranks. See src/unified_search.rs.
|
||||
web::resource("/photos/search/unified")
|
||||
.route(web::get().to(unified_search::unified_search::<SqliteTagDao>)),
|
||||
)
|
||||
.service(web::resource("/file/move").post(move_file::<RealFileSystem>))
|
||||
.service(handlers::image::get_image)
|
||||
.service(handlers::image::upload_image)
|
||||
|
||||
@@ -0,0 +1,452 @@
|
||||
//! `/photos/search/unified?q=<natural language>` — unified NL photo search.
|
||||
//!
|
||||
//! One free-text box that composes the two existing engines instead of making
|
||||
//! the user pick between them:
|
||||
//! 1. A grounded local-LLM call ([`crate::ai::nl_query`]) translates the
|
||||
//! query into a structured filter + a semantic term.
|
||||
//! 2. Structured filters (tags / EXIF / geo / date / media-type) define the
|
||||
//! candidate set; the semantic term ranks within it via CLIP.
|
||||
//!
|
||||
//! Path A (orchestration): we reuse `clip_search`'s scoring core and the
|
||||
//! existing `ExifDao` / `TagDao` queries, joining on `content_hash`. EXIF rows
|
||||
//! are the universal candidate carrier — each has `(library_id, file_path,
|
||||
//! content_hash, date_taken)` — so the structured filter is just a predicate
|
||||
//! over them, and the CLIP hits (which key on `content_hash`) intersect by
|
||||
//! hash. No new schema, no surgery on `list_photos`.
|
||||
//!
|
||||
//! Degenerate cases collapse to the existing behavior: semantic-only → plain
|
||||
//! CLIP search; filters-only → a date-sorted filtered listing.
|
||||
//!
|
||||
//! Person filtering is intentionally deferred (no person→photos resolver yet).
|
||||
|
||||
use crate::AppState;
|
||||
use crate::ai::backend::{BackendKind, SamplingOverrides};
|
||||
use crate::ai::nl_query::{StructuredQuery, translate_nl_query};
|
||||
use crate::clip_search::{
|
||||
SearchHit, parse_library_scope, resolve_hits, score_error_response, score_photos,
|
||||
};
|
||||
use crate::data::Claims;
|
||||
use crate::database::ExifDao;
|
||||
use crate::file_types::{is_image_file, is_video_file};
|
||||
use crate::geo::{forward_geocode, gps_bounding_box, haversine_distance};
|
||||
use crate::tags::TagDao;
|
||||
use actix_web::HttpResponse;
|
||||
use actix_web::web::{Data, Query};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UnifiedQuery {
|
||||
/// Natural-language query. Required; empty triggers 400.
|
||||
pub q: String,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: usize,
|
||||
#[serde(default)]
|
||||
pub offset: usize,
|
||||
/// CLIP cosine floor for the semantic ranking stage. Same default as the
|
||||
/// plain search endpoint.
|
||||
#[serde(default = "default_threshold")]
|
||||
pub threshold: f32,
|
||||
/// Legacy single-library scope (see clip_search).
|
||||
pub library: Option<i32>,
|
||||
/// Multi-library scope, comma-separated ids.
|
||||
pub library_ids: Option<String>,
|
||||
}
|
||||
|
||||
fn default_limit() -> usize {
|
||||
20
|
||||
}
|
||||
fn default_threshold() -> f32 {
|
||||
0.20
|
||||
}
|
||||
|
||||
/// A geocoded place echoed back so the client can show / edit the location
|
||||
/// filter it actually searched.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ResolvedPlace {
|
||||
display_name: String,
|
||||
lat: f64,
|
||||
lon: f64,
|
||||
radius_km: f64,
|
||||
}
|
||||
|
||||
/// How the server interpreted the NL query — echoed to the client to render
|
||||
/// editable filter chips. tag ids map to the client's existing tag list.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Interpreted {
|
||||
semantic: Option<String>,
|
||||
tag_ids: Vec<i32>,
|
||||
exclude_tag_ids: Vec<i32>,
|
||||
/// Words the model treated as tags that don't exist in the vocab; folded
|
||||
/// into the semantic term and surfaced here so the UI can explain it.
|
||||
unmatched_tags: Vec<String>,
|
||||
camera_make: Option<String>,
|
||||
camera_model: Option<String>,
|
||||
lens_model: Option<String>,
|
||||
date_from: Option<i64>,
|
||||
date_to: Option<i64>,
|
||||
media_type: Option<String>,
|
||||
place: Option<ResolvedPlace>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UnifiedResponse {
|
||||
query: String,
|
||||
interpreted: Interpreted,
|
||||
/// CLIP model version used for ranking; `None` when the query had no
|
||||
/// semantic term (filters-only).
|
||||
model_version: Option<String>,
|
||||
/// Embeddings scored by CLIP (0 when filters-only).
|
||||
considered: usize,
|
||||
/// Matches before pagination.
|
||||
total_matching: usize,
|
||||
offset: usize,
|
||||
results: Vec<SearchHit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ErrorBody {
|
||||
error: String,
|
||||
}
|
||||
|
||||
fn bad_request(msg: impl Into<String>) -> HttpResponse {
|
||||
HttpResponse::BadRequest().json(ErrorBody { error: msg.into() })
|
||||
}
|
||||
|
||||
/// Combine the model's semantic term with any tag words that didn't match the
|
||||
/// vocab, so a hallucinated/non-vocab tag becomes a soft semantic signal
|
||||
/// rather than being dropped.
|
||||
fn effective_semantic(sq: &StructuredQuery) -> Option<String> {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
if let Some(s) = sq.semantic.as_deref() {
|
||||
parts.push(s.to_string());
|
||||
}
|
||||
parts.extend(sq.unmatched_tags.iter().cloned());
|
||||
if parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(parts.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unified_search<TagD: TagDao>(
|
||||
_: Claims,
|
||||
state: Data<AppState>,
|
||||
exif_dao: Data<Mutex<Box<dyn ExifDao>>>,
|
||||
tag_dao: Data<Mutex<TagD>>,
|
||||
query: Query<UnifiedQuery>,
|
||||
) -> HttpResponse {
|
||||
let nl = query.q.trim().to_string();
|
||||
if nl.is_empty() {
|
||||
return bad_request("query parameter `q` is required");
|
||||
}
|
||||
|
||||
let limit = query.limit.clamp(1, 200);
|
||||
let offset = query.offset;
|
||||
let threshold = query.threshold.clamp(-1.0, 1.0);
|
||||
|
||||
let library_ids = match parse_library_scope(query.library_ids.as_deref(), query.library) {
|
||||
Ok(ids) => ids,
|
||||
Err(msg) => return bad_request(msg),
|
||||
};
|
||||
|
||||
let ctx = opentelemetry::Context::current();
|
||||
|
||||
// ── 1. Translate the NL query, grounded on the real tag vocabulary ──
|
||||
let tag_vocab: Vec<(i32, String)> = {
|
||||
let mut dao = tag_dao.lock().expect("tag dao");
|
||||
match dao.get_all_tags(&ctx, None) {
|
||||
Ok(tags) => tags.into_iter().map(|(_, t)| (t.id, t.name)).collect(),
|
||||
Err(e) => {
|
||||
log::warn!("unified_search: get_all_tags failed: {e:?}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Respect env/config for the LLM backend (LLM_BACKEND → ollama or
|
||||
// llama-swap); local only, no hybrid, per the feature's design.
|
||||
let overrides = SamplingOverrides {
|
||||
model: None,
|
||||
num_ctx: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
min_p: None,
|
||||
};
|
||||
let backend = match state
|
||||
.insight_generator
|
||||
.resolve_backend(BackendKind::Local, &overrides)
|
||||
.await
|
||||
{
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("unified_search: resolve_backend failed: {e:?}");
|
||||
return HttpResponse::ServiceUnavailable().json(ErrorBody {
|
||||
error: "LLM backend unavailable".into(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let today = chrono::Utc::now().date_naive();
|
||||
let sq = match translate_nl_query(backend.chat(), &nl, &tag_vocab, today).await {
|
||||
Ok(sq) => sq,
|
||||
Err(e) => {
|
||||
log::warn!("unified_search: translate_nl_query failed: {e:?}");
|
||||
return HttpResponse::BadGateway().json(ErrorBody {
|
||||
error: "could not interpret the query".into(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// ── 2. Forward-geocode the place name into a gps circle ──
|
||||
let resolved_place = match sq.place.as_deref() {
|
||||
Some(p) => forward_geocode(p).await.map(|g| ResolvedPlace {
|
||||
display_name: g.display_name,
|
||||
lat: g.lat,
|
||||
lon: g.lon,
|
||||
radius_km: g.radius_km,
|
||||
}),
|
||||
None => None,
|
||||
};
|
||||
let gps = resolved_place.as_ref().map(|p| (p.lat, p.lon, p.radius_km));
|
||||
|
||||
let semantic = effective_semantic(&sq);
|
||||
|
||||
let has_exif_filter = sq.camera_make.is_some()
|
||||
|| sq.camera_model.is_some()
|
||||
|| sq.lens_model.is_some()
|
||||
|| sq.date_from.is_some()
|
||||
|| sq.date_to.is_some();
|
||||
let has_struct =
|
||||
has_exif_filter || gps.is_some() || !sq.tag_ids.is_empty() || sq.media_type.is_some();
|
||||
|
||||
// ── 3. Build the structured candidate set (EXIF rows passing every
|
||||
// filter). Skipped entirely for a pure-semantic query. ──
|
||||
let mut candidate: Vec<crate::database::models::ImageExif> = Vec::new();
|
||||
let mut allowed_hashes: HashSet<String> = HashSet::new();
|
||||
if has_struct {
|
||||
// Tag membership set (rel_path only — same cross-library imprecision
|
||||
// as the existing /photos tag listing). ALL-mode: the photo must
|
||||
// carry every named tag.
|
||||
let tag_set: Option<HashSet<String>> = if sq.tag_ids.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut dao = tag_dao.lock().expect("tag dao");
|
||||
match dao.get_files_with_all_tag_ids(
|
||||
sq.tag_ids.clone(),
|
||||
sq.exclude_tag_ids.clone(),
|
||||
&ctx,
|
||||
) {
|
||||
Ok(files) => Some(files.into_iter().map(|f| f.file_name).collect()),
|
||||
Err(e) => {
|
||||
log::warn!("unified_search: tag filter failed: {e:?}");
|
||||
Some(HashSet::new())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// EXIF query handles camera/lens/gps-box/date. With no EXIF filters
|
||||
// it returns the whole table, which we then narrow by the predicates
|
||||
// below (tags / media / scope). Fine at personal-library scale.
|
||||
let gps_bounds = gps.map(|(lat, lon, r)| gps_bounding_box(lat, lon, r));
|
||||
let rows = {
|
||||
let mut dao = exif_dao.lock().expect("exif dao");
|
||||
dao.query_by_exif(
|
||||
&ctx,
|
||||
None, // scope filtered in-Rust to support multi-library
|
||||
sq.camera_make.as_deref(),
|
||||
sq.camera_model.as_deref(),
|
||||
sq.lens_model.as_deref(),
|
||||
gps_bounds,
|
||||
sq.date_from,
|
||||
sq.date_to,
|
||||
)
|
||||
.unwrap_or_else(|e| {
|
||||
log::warn!("unified_search: query_by_exif failed: {e:?}");
|
||||
Vec::new()
|
||||
})
|
||||
};
|
||||
|
||||
candidate = rows
|
||||
.into_iter()
|
||||
.filter(|row| {
|
||||
// Library scope.
|
||||
if !library_ids.is_empty() && !library_ids.contains(&row.library_id) {
|
||||
return false;
|
||||
}
|
||||
// Precise GPS distance (the EXIF query only did a coarse box).
|
||||
if let Some((lat, lon, radius_km)) = gps {
|
||||
match (row.gps_latitude, row.gps_longitude) {
|
||||
(Some(plat), Some(plon)) => {
|
||||
if haversine_distance(lat, lon, plat as f64, plon as f64) > radius_km {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
_ => return false,
|
||||
}
|
||||
}
|
||||
// Media type.
|
||||
if let Some(mt) = sq.media_type.as_deref() {
|
||||
let p = Path::new(&row.file_path);
|
||||
let ok = if mt == "video" {
|
||||
is_video_file(p)
|
||||
} else {
|
||||
is_image_file(p)
|
||||
};
|
||||
if !ok {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Tag membership.
|
||||
if let Some(ts) = &tag_set
|
||||
&& !ts.contains(&row.file_path)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
allowed_hashes = candidate
|
||||
.iter()
|
||||
.filter_map(|r| r.content_hash.clone())
|
||||
.collect();
|
||||
}
|
||||
|
||||
// ── 4. Rank ──
|
||||
match semantic {
|
||||
Some(ref sem) => {
|
||||
// Semantic term present: CLIP-rank, then keep only hits that pass
|
||||
// the structured filters (by content_hash).
|
||||
let scored =
|
||||
match score_photos(&state, &exif_dao, sem, &library_ids, threshold, None).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => return score_error_response(e),
|
||||
};
|
||||
let hits: Vec<(f32, String)> = if has_struct {
|
||||
scored
|
||||
.hits
|
||||
.into_iter()
|
||||
.filter(|(_, h)| allowed_hashes.contains(h))
|
||||
.collect()
|
||||
} else {
|
||||
scored.hits
|
||||
};
|
||||
let total_matching = hits.len();
|
||||
let page = paginate(&hits, offset, limit);
|
||||
let results = resolve_hits(&exif_dao, &page);
|
||||
HttpResponse::Ok().json(UnifiedResponse {
|
||||
query: nl,
|
||||
interpreted: interpreted(&sq, resolved_place),
|
||||
model_version: Some(scored.model_version),
|
||||
considered: scored.considered,
|
||||
total_matching,
|
||||
offset,
|
||||
results,
|
||||
})
|
||||
}
|
||||
None => {
|
||||
// Filters-only: no semantic term. Require at least one filter,
|
||||
// then return the candidate set newest-first.
|
||||
if !has_struct {
|
||||
return bad_request("query had no searchable terms");
|
||||
}
|
||||
candidate.sort_by(|a, b| b.date_taken.cmp(&a.date_taken));
|
||||
let total_matching = candidate.len();
|
||||
let end = (offset + limit).min(total_matching);
|
||||
let results: Vec<SearchHit> = if offset >= total_matching {
|
||||
Vec::new()
|
||||
} else {
|
||||
candidate[offset..end]
|
||||
.iter()
|
||||
.map(|r| SearchHit {
|
||||
library_id: r.library_id,
|
||||
rel_path: r.file_path.clone(),
|
||||
content_hash: r.content_hash.clone().unwrap_or_default(),
|
||||
score: 0.0,
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
HttpResponse::Ok().json(UnifiedResponse {
|
||||
query: nl,
|
||||
interpreted: interpreted(&sq, resolved_place),
|
||||
model_version: None,
|
||||
considered: 0,
|
||||
total_matching,
|
||||
offset,
|
||||
results,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice a sorted hit list at `[offset, offset+limit)`, tolerating
|
||||
/// out-of-range offsets (empty page).
|
||||
fn paginate(hits: &[(f32, String)], offset: usize, limit: usize) -> Vec<(f32, String)> {
|
||||
if offset >= hits.len() {
|
||||
return Vec::new();
|
||||
}
|
||||
let end = (offset + limit).min(hits.len());
|
||||
hits[offset..end].to_vec()
|
||||
}
|
||||
|
||||
fn interpreted(sq: &StructuredQuery, place: Option<ResolvedPlace>) -> Interpreted {
|
||||
Interpreted {
|
||||
semantic: sq.semantic.clone(),
|
||||
tag_ids: sq.tag_ids.clone(),
|
||||
exclude_tag_ids: sq.exclude_tag_ids.clone(),
|
||||
unmatched_tags: sq.unmatched_tags.clone(),
|
||||
camera_make: sq.camera_make.clone(),
|
||||
camera_model: sq.camera_model.clone(),
|
||||
lens_model: sq.lens_model.clone(),
|
||||
date_from: sq.date_from,
|
||||
date_to: sq.date_to,
|
||||
media_type: sq.media_type.clone(),
|
||||
place,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ai::nl_query::StructuredQuery;
|
||||
|
||||
#[test]
|
||||
fn effective_semantic_combines_semantic_and_unmatched() {
|
||||
let sq = StructuredQuery {
|
||||
semantic: Some("sunset".into()),
|
||||
unmatched_tags: vec!["golden hour".into()],
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(
|
||||
effective_semantic(&sq).as_deref(),
|
||||
Some("sunset golden hour")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_semantic_none_when_empty() {
|
||||
let sq = StructuredQuery::default();
|
||||
assert_eq!(effective_semantic(&sq), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_semantic_unmatched_only() {
|
||||
let sq = StructuredQuery {
|
||||
unmatched_tags: vec!["disco".into()],
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(effective_semantic(&sq).as_deref(), Some("disco"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn paginate_handles_out_of_range_offset() {
|
||||
let hits = vec![(0.9, "a".to_string()), (0.8, "b".to_string())];
|
||||
assert_eq!(paginate(&hits, 5, 10).len(), 0);
|
||||
assert_eq!(paginate(&hits, 0, 1).len(), 1);
|
||||
assert_eq!(paginate(&hits, 1, 10).len(), 1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user