diff --git a/src/ai/mod.rs b/src/ai/mod.rs index c5302fb..7d0802e 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -10,6 +10,7 @@ pub mod insight_generator; pub mod llamacpp; pub mod llm_client; pub mod local_llm; +pub mod nl_query; pub mod ollama; pub mod openrouter; pub mod pronunciation; diff --git a/src/ai/nl_query.rs b/src/ai/nl_query.rs new file mode 100644 index 0000000..a94fc06 --- /dev/null +++ b/src/ai/nl_query.rs @@ -0,0 +1,414 @@ +//! Natural-language → structured-query translation for unified photo search. +//! +//! The unified search endpoint (`/photos/search/unified`, Phase 2) needs to +//! turn a free-text query like *"sunset photos in Italy from last summer"* +//! into the structured filter the existing `/photos` engine understands plus +//! a semantic term for CLIP ranking. That translation is a single grounded +//! LLM call, isolated here so it can be unit-tested without a network or the +//! full `InsightGenerator`. +//! +//! Two-stage design: +//! 1. The LLM emits a [`RawNlQuery`] — references are by *name* (tags) and +//! dates as ISO strings, never numeric ids it could hallucinate. +//! 2. [`resolve_raw_query`] maps names against the real tag vocabulary and +//! converts ISO dates to unix seconds, producing a [`StructuredQuery`]. +//! A tag the model invents that isn't in the vocab is surfaced in +//! `unmatched_tags` (the caller folds it back into the semantic term) +//! rather than silently dropped — this is the anti-noise guard. +//! +//! Geocoding of `place` and person filtering are intentionally *not* handled +//! here: `place` stays as text for the caller to forward-geocode (async, see +//! `geo::forward_geocode`), and person filtering is deferred until a +//! person→photos resolver exists. + +// Phase 1: this module is fully implemented and unit-tested, but its first +// consumer (the `/photos/search/unified` endpoint) lands in Phase 2. 
Mirrors +// llm_client.rs's allow-until-wired pattern so the bin target stays +// clippy-clean in the interim; remove when the endpoint is added. +#![allow(dead_code)] + +use crate::ai::llm_client::{ChatMessage, LlmClient, Tool, strip_think_blocks}; +use anyhow::{Result, anyhow}; +use serde::{Deserialize, Serialize}; + +/// Raw query object as emitted by the LLM. Tag references are by name +/// (resolved against the real vocab in Rust); dates are ISO `YYYY-MM-DD`. +/// Every field is optional so a partial / minimal model response still +/// deserializes. +#[derive(Debug, Clone, Default, Deserialize, PartialEq)] +pub struct RawNlQuery { + /// Visual/scene description handed to CLIP for ranking. The descriptive + /// remainder after structured filters are peeled off. + #[serde(default)] + pub semantic: Option, + /// Tag names the photos must have. Matched case-insensitively against + /// the supplied vocabulary; non-matches land in `unmatched_tags`. + #[serde(default)] + pub tags: Vec, + /// Tag names the photos must NOT have. + #[serde(default)] + pub exclude_tags: Vec, + #[serde(default)] + pub camera_make: Option, + #[serde(default)] + pub camera_model: Option, + #[serde(default)] + pub lens_model: Option, + /// Free-text place/location name to forward-geocode (e.g. "Italy"). + #[serde(default)] + pub place: Option, + /// Inclusive start date, ISO `YYYY-MM-DD`. + #[serde(default)] + pub date_from: Option, + /// Inclusive end date, ISO `YYYY-MM-DD`. + #[serde(default)] + pub date_to: Option, + /// "photo" | "video" — normalized in [`resolve_raw_query`]. + #[serde(default)] + pub media_type: Option, +} + +/// Resolved structured query: tag names mapped to ids against the real +/// vocab, ISO dates converted to unix seconds. `place` stays as text for the +/// caller to forward-geocode into a gps circle. Serializable so the endpoint +/// can echo it back to the client as "this is how I read your query" +/// (editable filter chips). 
+#[derive(Debug, Clone, Default, PartialEq, Serialize)] +pub struct StructuredQuery { + pub semantic: Option, + pub tag_ids: Vec, + pub exclude_tag_ids: Vec, + /// Tag names the model produced that don't exist in the vocabulary. + /// The caller folds these back into the semantic term so the concept + /// isn't lost — and surfacing them keeps a hallucinated tag from + /// silently filtering the whole library to nothing. + pub unmatched_tags: Vec, + pub camera_make: Option, + pub camera_model: Option, + pub lens_model: Option, + /// Raw place name awaiting forward-geocoding by the caller. + pub place: Option, + pub date_from: Option, + pub date_to: Option, + /// Normalized to "photo" | "video"; `None` means no media-type filter. + pub media_type: Option, +} + +/// Convert an ISO `YYYY-MM-DD` date to a unix timestamp (seconds). With +/// `end_of_day`, returns 23:59:59 of that day so a `date_to` filter is +/// inclusive of the whole day; otherwise 00:00:00. Returns `None` for any +/// unparseable input (the filter is simply omitted rather than erroring). +pub fn iso_to_unix(date: &str, end_of_day: bool) -> Option { + let d = chrono::NaiveDate::parse_from_str(date.trim(), "%Y-%m-%d").ok()?; + let time = if end_of_day { + chrono::NaiveTime::from_hms_opt(23, 59, 59)? + } else { + chrono::NaiveTime::from_hms_opt(0, 0, 0)? + }; + Some(d.and_time(time).and_utc().timestamp()) +} + +/// Normalize a free-form media-type string to the engine's vocabulary. +/// Anything that isn't clearly photo or video (including "all") yields +/// `None` — no filter. +fn normalize_media_type(raw: &str) -> Option { + match raw.trim().to_lowercase().as_str() { + "photo" | "photos" | "image" | "images" | "picture" | "pictures" => { + Some("photo".to_string()) + } + "video" | "videos" | "movie" | "movies" | "clip" | "clips" => Some("video".to_string()), + _ => None, + } +} + +/// Resolve a raw LLM query against the real tag vocabulary, producing the +/// structured filter. 
Pure — no network, no LLM — so it carries the +/// correctness-critical mapping logic under unit test. +/// +/// `tag_vocab` is `(tag_id, tag_name)` pairs (the shape `TagDao::get_all_tags` +/// yields once the count is dropped). Matching is case-insensitive and exact +/// on the trimmed name. +pub fn resolve_raw_query(raw: RawNlQuery, tag_vocab: &[(i32, String)]) -> StructuredQuery { + // Case-insensitive name → id lookup. Built once per call. + let lookup: std::collections::HashMap = tag_vocab + .iter() + .map(|(id, name)| (name.trim().to_lowercase(), *id)) + .collect(); + + let resolve_names = |names: &[String], ids: &mut Vec, unmatched: &mut Vec| { + for name in names { + let key = name.trim().to_lowercase(); + if key.is_empty() { + continue; + } + match lookup.get(&key) { + Some(id) if !ids.contains(id) => ids.push(*id), + Some(_) => {} // duplicate, already collected + None => { // dedupe on the trimmed form, which is what gets stored + if !unmatched.iter().any(|u| u.eq_ignore_ascii_case(name.trim())) { + unmatched.push(name.trim().to_string()); + } + } + } + } + }; + + let mut tag_ids = Vec::new(); + let mut unmatched_tags = Vec::new(); + resolve_names(&raw.tags, &mut tag_ids, &mut unmatched_tags); + + // Excluded tags that don't match a real tag are simply ignored — you + // can't exclude a tag that doesn't exist, and folding them into + // `semantic` would make no sense. 
+ let mut exclude_tag_ids = Vec::new(); + let mut exclude_unmatched = Vec::new(); + resolve_names( + &raw.exclude_tags, + &mut exclude_tag_ids, + &mut exclude_unmatched, + ); + + let clean = |s: Option| s.map(|v| v.trim().to_string()).filter(|v| !v.is_empty()); + + StructuredQuery { + semantic: clean(raw.semantic), + tag_ids, + exclude_tag_ids, + unmatched_tags, + camera_make: clean(raw.camera_make), + camera_model: clean(raw.camera_model), + lens_model: clean(raw.lens_model), + place: clean(raw.place), + date_from: raw.date_from.as_deref().and_then(|d| iso_to_unix(d, false)), + date_to: raw.date_to.as_deref().and_then(|d| iso_to_unix(d, true)), + media_type: raw.media_type.as_deref().and_then(normalize_media_type), + } +} + +/// Build the grounded system prompt. The model is told the current date (so +/// "last summer" resolves) and the exact tag vocabulary (so it uses real +/// tags or routes the concept to `semantic` instead of inventing one). +fn build_system_prompt(tag_vocab: &[(i32, String)], today: chrono::NaiveDate) -> String { + // Cap the vocab dump so a huge library doesn't blow the context window. + // NOTE(review): names are sorted alphabetically below, so the cap keeps the first MAX_TAGS by name, not the most-used tags. + const MAX_TAGS: usize = 400; + let mut names: Vec<&str> = tag_vocab.iter().map(|(_, n)| n.as_str()).collect(); + names.sort_unstable(); + names.dedup(); + let shown = names.len().min(MAX_TAGS); + let vocab = names[..shown].join(", "); + let truncation = if names.len() > MAX_TAGS { + format!(" (showing {MAX_TAGS} of {} tags)", names.len()) + } else { + String::new() + }; + + format!( + "You translate a user's natural-language photo-search request into a JSON \ +filter. Today's date is {today}. 
Respond with ONLY a JSON object, no prose, no \ +code fences.\n\n\ +Schema (all fields optional):\n\ +{{\n \ +\"semantic\": string|null, // visual scene/subject for image similarity search\n \ +\"tags\": string[], // ONLY names from the tag list below\n \ +\"exclude_tags\": string[], // ONLY names from the tag list below\n \ +\"camera_make\": string|null,\n \ +\"camera_model\": string|null,\n \ +\"lens_model\": string|null,\n \ +\"place\": string|null, // a location name to look up (city, country, landmark)\n \ +\"date_from\": \"YYYY-MM-DD\"|null, // inclusive\n \ +\"date_to\": \"YYYY-MM-DD\"|null, // inclusive\n \ +\"media_type\": \"photo\"|\"video\"|null\n\ +}}\n\n\ +Rules:\n\ +- Put descriptive/visual concepts (\"sunset\", \"crowded beach\", \"red car\") in \"semantic\".\n\ +- Only use \"tags\"/\"exclude_tags\" values that appear EXACTLY in the tag list. If a \ +concept isn't a listed tag, put it in \"semantic\" instead — never invent a tag.\n\ +- Resolve relative dates against today's date (\"last summer\", \"2023\", \"last month\").\n\ +- Put place/location names in \"place\" (not \"semantic\").\n\ +- Omit (use null / empty array) anything the request doesn't mention.\n\n\ +Available tags{truncation}: {vocab}" + ) +} + +/// Extract the JSON object from a model response that may include a leading +/// think block, code fences, or trailing prose. Strips the think block +/// first (so reasoning that mentions braces can't fool the scan), then +/// returns the substring from the first `{` to the last `}` inclusive — or +/// the trimmed text if no braces are found (which then fails to parse with a +/// clear error). +fn extract_json(raw: &str) -> String { + let s = strip_think_blocks(raw); + let start = s.find('{'); + let end = s.rfind('}'); + match (start, end) { + (Some(a), Some(b)) if b >= a => s[a..=b].to_string(), + _ => s.trim().to_string(), + } +} + +/// Parse a model response string into a [`StructuredQuery`], resolving names +/// against the vocab. 
Separated from the LLM call so it's unit-testable. +pub fn parse_response(response: &str, tag_vocab: &[(i32, String)]) -> Result { + let json = extract_json(response); + let raw: RawNlQuery = serde_json::from_str(&json) + .map_err(|e| anyhow!("failed to parse NL query JSON: {e}; raw response: {response:?}"))?; + Ok(resolve_raw_query(raw, tag_vocab)) +} + +/// Translate a natural-language query into a [`StructuredQuery`] via one +/// grounded LLM call. The `client` is any configured backend (the unified +/// endpoint passes the resolved chat backend); `tag_vocab` grounds the tag +/// mapping; `today` anchors relative-date resolution. +pub async fn translate_nl_query( + client: &dyn LlmClient, + nl: &str, + tag_vocab: &[(i32, String)], + today: chrono::NaiveDate, +) -> Result { + let system = build_system_prompt(tag_vocab, today); + let messages = vec![ChatMessage::system(system), ChatMessage::user(nl)]; + let (msg, _, _) = client.chat_with_tools(messages, Vec::::new()).await?; + parse_response(&msg.content, tag_vocab) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn vocab() -> Vec<(i32, String)> { + vec![ + (1, "beach".to_string()), + (2, "Sunset".to_string()), // mixed case to exercise case-insensitivity + (3, "family".to_string()), + ] + } + + #[test] + fn iso_to_unix_start_and_end_of_day() { + // 2023-01-01 UTC midnight = 1672531200. + assert_eq!(iso_to_unix("2023-01-01", false), Some(1_672_531_200)); + // End of that day is 86399 seconds later. 
+ assert_eq!( + iso_to_unix("2023-01-01", true), + Some(1_672_531_200 + 86_399) + ); + } + + #[test] + fn iso_to_unix_rejects_garbage() { + assert_eq!(iso_to_unix("last summer", false), None); + assert_eq!(iso_to_unix("2023-13-99", false), None); + assert_eq!(iso_to_unix("", false), None); + } + + #[test] + fn resolve_matches_tags_case_insensitively() { + let raw = RawNlQuery { + tags: vec!["BEACH".to_string(), "sunset".to_string()], + ..Default::default() + }; + let q = resolve_raw_query(raw, &vocab()); + assert_eq!(q.tag_ids, vec![1, 2]); + assert!(q.unmatched_tags.is_empty()); + } + + #[test] + fn resolve_surfaces_unmatched_tags_not_silently_dropped() { + // A hallucinated / non-vocab tag must be surfaced so the caller can + // fold it into semantic — never silently used as a hard filter. + let raw = RawNlQuery { + tags: vec!["beach".to_string(), "golden hour".to_string()], + ..Default::default() + }; + let q = resolve_raw_query(raw, &vocab()); + assert_eq!(q.tag_ids, vec![1]); + assert_eq!(q.unmatched_tags, vec!["golden hour".to_string()]); + } + + #[test] + fn resolve_dedups_repeated_tags() { + let raw = RawNlQuery { + tags: vec![ + "beach".to_string(), + "Beach".to_string(), + "beach".to_string(), + ], + ..Default::default() + }; + let q = resolve_raw_query(raw, &vocab()); + assert_eq!(q.tag_ids, vec![1]); + } + + #[test] + fn resolve_normalizes_media_type_and_dates() { + let raw = RawNlQuery { + media_type: Some("Videos".to_string()), + date_from: Some("2023-06-01".to_string()), + date_to: Some("2023-06-30".to_string()), + ..Default::default() + }; + let q = resolve_raw_query(raw, &vocab()); + assert_eq!(q.media_type.as_deref(), Some("video")); + assert_eq!(q.date_from, iso_to_unix("2023-06-01", false)); + assert_eq!(q.date_to, iso_to_unix("2023-06-30", true)); + } + + #[test] + fn resolve_media_type_all_is_no_filter() { + let raw = RawNlQuery { + media_type: Some("all".to_string()), + ..Default::default() + }; + assert_eq!(resolve_raw_query(raw, 
&vocab()).media_type, None); + } + + #[test] + fn resolve_trims_and_empties_to_none() { + let raw = RawNlQuery { + semantic: Some(" ".to_string()), + camera_make: Some(" Fujifilm ".to_string()), + place: Some("".to_string()), + ..Default::default() + }; + let q = resolve_raw_query(raw, &vocab()); + assert_eq!(q.semantic, None); + assert_eq!(q.camera_make.as_deref(), Some("Fujifilm")); + assert_eq!(q.place, None); + } + + #[test] + fn parse_response_handles_code_fences_and_prose() { + let resp = "Here is the filter:\n```json\n{\"semantic\":\"sunset\",\"tags\":[\"beach\"]}\n```\nDone."; + let q = parse_response(resp, &vocab()).expect("parse"); + assert_eq!(q.semantic.as_deref(), Some("sunset")); + assert_eq!(q.tag_ids, vec![1]); + } + + #[test] + fn parse_response_handles_think_block_then_json() { + let resp = "user wants beach sunsets{\"tags\":[\"beach\",\"sunset\"]}"; + let q = parse_response(resp, &vocab()).expect("parse"); + assert_eq!(q.tag_ids, vec![1, 2]); + } + + #[test] + fn parse_response_errors_on_non_json() { + assert!(parse_response("I cannot help with that.", &vocab()).is_err()); + } + + #[test] + fn build_system_prompt_includes_date_and_vocab() { + let today = chrono::NaiveDate::from_ymd_opt(2026, 6, 14).unwrap(); + let prompt = build_system_prompt(&vocab(), today); + assert!( + prompt.contains("2026-06-14"), + "prompt should state today's date" + ); + assert!(prompt.contains("beach"), "prompt should list the vocab"); + assert!( + prompt.contains("never invent a tag"), + "prompt should warn against inventing tags" + ); + } +} diff --git a/src/geo.rs b/src/geo.rs index 46cc1dc..b7ef9d1 100644 --- a/src/geo.rs +++ b/src/geo.rs @@ -1,4 +1,5 @@ /// Geographic calculation utilities for GPS-based search +use serde::Deserialize; use std::f64; /// Calculate distance between two GPS coordinates using the Haversine formula. 
@@ -61,6 +62,148 @@ pub fn gps_bounding_box(lat: f64, lon: f64, radius_km: f64) -> (f64, f64, f64, f ) } +/// A place resolved from a free-text query via forward geocoding. +/// +/// The filter pipeline searches a *circle* (`gps_lat`/`gps_lon`/ +/// `gps_radius_km`), but a place can be anything from a single address to +/// a whole country. We collapse Nominatim's bounding box into the smallest +/// circle that circumscribes it (see [`bbox_to_circle`]) so "Portland" and +/// "Italy" both map onto the existing circle filter without a schema change. +// Phase 1: forward geocoding is implemented and unit-tested here, but its +// first consumer (the `/photos/search/unified` endpoint) lands in Phase 2. +// allow-until-wired (mirrors llm_client.rs); remove when the endpoint is added. +#[allow(dead_code)] +#[derive(Debug, Clone, PartialEq)] +pub struct GeoPlace { + /// Nominatim's canonical name for the match (e.g. "Italia"). + pub display_name: String, + /// Centroid latitude in decimal degrees. + pub lat: f64, + /// Centroid longitude in decimal degrees. + pub lon: f64, + /// Radius (km) of a circle centred on the centroid that covers the + /// matched area. Floored to [`MIN_PLACE_RADIUS_KM`] so a point result + /// (whose bounding box is microscopic) still yields a usable circle. + pub radius_km: f64, +} + +/// Floor for a geocoded place's radius. Point results (a street address) +/// come back with a near-zero bounding box; without a floor the circle +/// filter would match nothing. +#[allow(dead_code)] +pub const MIN_PLACE_RADIUS_KM: f64 = 0.5; + +/// Collapse a bounding box into the centroid + circumscribing radius. +/// +/// Input is Nominatim's `boundingbox` order: `(south_lat, north_lat, +/// west_lon, east_lon)`. The radius is the *largest* great-circle distance +/// from the centroid to any of the four corners, so the resulting circle +/// fully covers the box. 
(The corners aren't equidistant on a sphere — +/// longitude lines converge toward the poles, so the equator-facing edge's +/// corners are farthest; taking the max guarantees coverage in either +/// hemisphere.) +/// +/// Pure and exact (no flooring) so it can be unit-tested directly; callers +/// apply [`MIN_PLACE_RADIUS_KM`] when turning the result into a filter. +#[allow(dead_code)] +pub fn bbox_to_circle(south: f64, north: f64, west: f64, east: f64) -> (f64, f64, f64) { + let center_lat = (south + north) / 2.0; + let center_lon = (west + east) / 2.0; + let radius_km = [(south, west), (south, east), (north, west), (north, east)] + .iter() + .map(|(clat, clon)| haversine_distance(center_lat, center_lon, *clat, *clon)) + .fold(0.0_f64, f64::max); + (center_lat, center_lon, radius_km) +} + +/// Raw Nominatim `/search` result. `lat`/`lon` arrive as strings and +/// `boundingbox` as a 4-element string array `[south, north, west, east]`. +#[allow(dead_code)] +#[derive(Deserialize)] +struct NominatimSearchResult { + lat: String, + lon: String, + display_name: String, + boundingbox: Option<[String; 4]>, +} + +/// Forward-geocode a free-text place name to a [`GeoPlace`] via the public +/// OpenStreetMap Nominatim `/search` endpoint. +/// +/// Mirrors `InsightGenerator::reverse_geocode`'s error posture: any network, +/// HTTP, or parse failure returns `None` rather than propagating, so a flaky +/// geocoder degrades the query to "no location filter" instead of failing it. +/// +/// Nominatim's usage policy requires a `User-Agent` and rate-limits to ~1 +/// request/second; callers doing this interactively should cache results. 
+#[allow(dead_code)] +pub async fn forward_geocode(query: &str) -> Option { + let q = query.trim(); + if q.is_empty() { + return None; + } + + let client = reqwest::Client::new(); + let response = match client + .get("https://nominatim.openstreetmap.org/search") + .query(&[("format", "json"), ("limit", "1"), ("q", q)]) + .header("User-Agent", "ImageAPI/1.0") // Nominatim requires User-Agent + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + log::warn!("Forward geocoding network error for {q:?}: {e}"); + return None; + } + }; + + if !response.status().is_success() { + log::warn!( + "Forward geocoding HTTP error for {q:?}: {}", + response.status() + ); + return None; + } + + let results: Vec = match response.json().await { + Ok(r) => r, + Err(e) => { + log::warn!("Forward geocoding JSON parse error for {q:?}: {e}"); + return None; + } + }; + + let top = results.into_iter().next()?; + let lat: f64 = top.lat.parse().ok()?; + let lon: f64 = top.lon.parse().ok()?; + + // Prefer the bounding box (handles large places); fall back to a + // point + floor radius when Nominatim omits it. + let (center_lat, center_lon, radius_km) = match &top.boundingbox { + Some([s, n, w, e]) => match (s.parse(), n.parse(), w.parse(), e.parse()) { + (Ok(s), Ok(n), Ok(w), Ok(e)) => bbox_to_circle(s, n, w, e), + _ => (lat, lon, 0.0), + }, + None => (lat, lon, 0.0), + }; + + let place = GeoPlace { + display_name: top.display_name, + lat: center_lat, + lon: center_lon, + radius_km: radius_km.max(MIN_PLACE_RADIUS_KM), + }; + log::info!( + "Forward geocoded {q:?} -> {} ({:.4}, {:.4}, r={:.1}km)", + place.display_name, + place.lat, + place.lon, + place.radius_km + ); + Some(place) +} + #[cfg(test)] mod tests { use super::*; @@ -118,4 +261,41 @@ mod tests { distance ); } + + #[test] + fn test_bbox_to_circle_centroid() { + // Symmetric box around (10, 20): centroid should land dead centre. 
+ let (lat, lon, radius) = bbox_to_circle(9.0, 11.0, 19.0, 21.0); + assert!((lat - 10.0).abs() < 1e-9, "centroid lat, got {lat}"); + assert!((lon - 20.0).abs() < 1e-9, "centroid lon, got {lon}"); + assert!(radius > 0.0, "radius should be positive, got {radius}"); + } + + #[test] + fn test_bbox_to_circle_covers_corner() { + // The radius must reach every corner of the box. Verify that no + // centroid-to-corner distance exceeds the returned radius (the + // radius is the largest of the four, so every corner is covered). + let (south, north, west, east) = (40.0, 42.0, -74.0, -72.0); + let (lat, lon, radius) = bbox_to_circle(south, north, west, east); + for (clat, clon) in [(south, west), (south, east), (north, west), (north, east)] { + let d = haversine_distance(lat, lon, clat, clon); + assert!( + d <= radius + 1e-6, + "corner ({clat},{clon}) at {d}km should be within radius {radius}km" + ); + } + } + + #[test] + fn test_bbox_to_circle_country_vs_city_scale() { + // A country-sized box yields a far larger radius than a city-sized + // one — confirming the bbox approach scales with place size. + let (_, _, country) = bbox_to_circle(35.5, 47.1, 6.6, 18.5); // ~Italy + let (_, _, city) = bbox_to_circle(45.4, 45.6, -122.8, -122.5); // ~Portland + assert!( + country > city * 10.0, + "country radius {country}km should dwarf city radius {city}km" + ); + } }