Files
ImageApi/src/database/search_dao.rs
2026-01-14 13:17:58 -05:00

517 lines
18 KiB
Rust

use diesel::prelude::*;
use diesel::sqlite::SqliteConnection;
use serde::Serialize;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex};
use crate::database::{DbError, DbErrorKind, connect};
use crate::otel::trace_db_call;
/// Represents a search history record as returned to callers.
///
/// Carries the same columns that are read from the `search_history` table,
/// except for the stored embedding blob, which is not exposed here.
#[derive(Serialize, Clone, Debug)]
pub struct SearchRecord {
/// Row id of the record.
pub id: i32,
/// Time of the search. Presumably Unix epoch seconds (the hybrid query
/// offsets it by `days * 86400`) — TODO confirm against the importer.
pub timestamp: i64,
/// The search text.
pub query: String,
/// Search engine associated with the query, if one was recorded.
pub search_engine: Option<String>,
/// Creation time supplied at insert; units are not visible in this file — TODO confirm.
pub created_at: i64,
/// File the record came from, if one was recorded.
pub source_file: Option<String>,
}
/// Data for inserting a new search record.
///
/// Unlike [`SearchRecord`], this carries the embedding and has no `id`
/// (the id is only known after the row has been written).
#[derive(Clone, Debug)]
pub struct InsertSearchRecord {
/// Time of the search (same units as `SearchRecord::timestamp`).
pub timestamp: i64,
/// The search text.
pub query: String,
/// Search engine associated with the query, if known.
pub search_engine: Option<String>,
/// Embedding of the query. The SQLite DAO in this file rejects or skips
/// records whose embedding is not exactly 768 values long.
pub embedding: Vec<f32>, // 768-dim, REQUIRED
/// Creation time to store with the row.
pub created_at: i64,
/// File the record came from, if known.
pub source_file: Option<String>,
}
/// Storage interface for search-history records and their embeddings.
///
/// Every method takes an OpenTelemetry context that the implementation in
/// this file passes to `trace_db_call`, and reports failures as [`DbError`].
pub trait SearchHistoryDao: Sync + Send {
/// Stores a single search together with its embedding and returns the
/// resulting record.
///
/// The SQLite implementation in this file rejects embeddings that are not
/// 768-dimensional.
fn store_search(
&mut self,
context: &opentelemetry::Context,
search: InsertSearchRecord,
) -> Result<SearchRecord, DbError>;
/// Stores many searches at once and returns how many rows were inserted.
///
/// In the SQLite implementation, records with a wrong embedding size are
/// skipped with a warning, and rows ignored by `INSERT OR IGNORE` are not
/// counted.
fn store_searches_batch(
&mut self,
context: &opentelemetry::Context,
searches: Vec<InsertSearchRecord>,
) -> Result<usize, DbError>;
/// Finds searches whose timestamp lies in `start_ts..=end_ts`
/// (both bounds inclusive), for temporal context.
fn find_searches_in_range(
&mut self,
context: &opentelemetry::Context,
start_ts: i64,
end_ts: i64,
) -> Result<Vec<SearchRecord>, DbError>;
/// Finds the searches most semantically similar to `query_embedding`
/// (the primary, embedding-driven lookup), returning at most `limit`
/// records ordered from most to least similar.
fn find_similar_searches(
&mut self,
context: &opentelemetry::Context,
query_embedding: &[f32],
limit: usize,
) -> Result<Vec<SearchRecord>, DbError>;
/// Hybrid lookup: restricts to a time window, then optionally ranks by
/// semantic similarity.
///
/// The window spans `time_window_days` days on either side of
/// `center_timestamp`. When `query_embedding` is `Some`, the records in
/// the window are ranked by similarity to it; when it is `None`, no
/// similarity ranking is applied. At most `limit` records are returned.
fn find_relevant_searches_hybrid(
&mut self,
context: &opentelemetry::Context,
center_timestamp: i64,
time_window_days: i64,
query_embedding: Option<&[f32]>,
limit: usize,
) -> Result<Vec<SearchRecord>, DbError>;
/// Deduplication check: reports whether a record with exactly this
/// `timestamp` and `query` is already stored.
fn search_exists(
&mut self,
context: &opentelemetry::Context,
timestamp: i64,
query: &str,
) -> Result<bool, DbError>;
/// Returns the total number of stored search records.
fn get_search_count(&mut self, context: &opentelemetry::Context) -> Result<i64, DbError>;
}
/// SQLite-backed implementation of [`SearchHistoryDao`].
pub struct SqliteSearchHistoryDao {
// Single connection behind a mutex; each DAO method locks it for the
// duration of its database work.
connection: Arc<Mutex<SqliteConnection>>,
}
impl Default for SqliteSearchHistoryDao {
fn default() -> Self {
Self::new()
}
}
impl SqliteSearchHistoryDao {
pub fn new() -> Self {
SqliteSearchHistoryDao {
connection: Arc::new(Mutex::new(connect())),
}
}
/// Encodes an embedding as a byte blob: each `f32` becomes 4 little-endian
/// bytes, in order.
///
/// The byte order is fixed to little-endian (rather than the host's native
/// order) so the blob always matches `deserialize_vector`, which decodes
/// with `f32::from_le_bytes`. On little-endian hosts the output is
/// identical to a raw memory copy of the slice.
fn serialize_vector(vec: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(vec.len() * 4);
    for value in vec {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
/// Decodes a byte blob into `f32` values, reading every 4-byte group as a
/// little-endian float.
///
/// # Errors
/// Returns a `DbErrorKind::QueryError` error when the blob length is not a
/// multiple of 4.
fn deserialize_vector(bytes: &[u8]) -> Result<Vec<f32>, DbError> {
    let groups = bytes.chunks_exact(4);
    // A non-empty remainder means the length is not divisible by 4.
    if !groups.remainder().is_empty() {
        return Err(DbError::new(DbErrorKind::QueryError));
    }
    let floats = groups
        .map(|group| f32::from_le_bytes([group[0], group[1], group[2], group[3]]))
        .collect();
    Ok(floats)
}
/// Cosine similarity of two vectors: dot product divided by the product of
/// their Euclidean norms.
///
/// Returns `0.0` when the lengths differ or when either vector has zero
/// norm (which includes empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
}
/// Raw result row for `sql_query` selects on `search_history`, including the
/// embedding blob. Fields are matched to result columns by name, so the
/// queries must select columns with exactly these names.
#[derive(QueryableByName)]
struct SearchRecordWithVectorRow {
#[diesel(sql_type = diesel::sql_types::Integer)]
id: i32,
#[diesel(sql_type = diesel::sql_types::BigInt)]
timestamp: i64,
#[diesel(sql_type = diesel::sql_types::Text)]
query: String,
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
search_engine: Option<String>,
// Serialized embedding; decoded with `deserialize_vector` when a query
// needs to score similarity.
#[diesel(sql_type = diesel::sql_types::Binary)]
embedding: Vec<u8>,
#[diesel(sql_type = diesel::sql_types::BigInt)]
created_at: i64,
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
source_file: Option<String>,
}
impl SearchRecordWithVectorRow {
    /// Builds the caller-facing [`SearchRecord`] from this row, cloning the
    /// owned fields and leaving the embedding blob behind.
    fn to_search_record(&self) -> SearchRecord {
        let Self {
            id,
            timestamp,
            query,
            search_engine,
            created_at,
            source_file,
            ..
        } = self;
        SearchRecord {
            id: *id,
            timestamp: *timestamp,
            query: query.clone(),
            search_engine: search_engine.clone(),
            created_at: *created_at,
            source_file: source_file.clone(),
        }
    }
}
/// Single-column result row used to read `last_insert_rowid()`; the query
/// must alias the value as `id`.
#[derive(QueryableByName)]
struct LastInsertRowId {
#[diesel(sql_type = diesel::sql_types::Integer)]
id: i32,
}
impl SearchHistoryDao for SqliteSearchHistoryDao {
fn store_search(
&mut self,
context: &opentelemetry::Context,
search: InsertSearchRecord,
) -> Result<SearchRecord, DbError> {
trace_db_call(context, "insert", "store_search", |_span| {
let mut conn = self
.connection
.lock()
.expect("Unable to get SearchHistoryDao");
// Validate embedding dimensions (REQUIRED for searches)
if search.embedding.len() != 768 {
return Err(anyhow::anyhow!(
"Invalid embedding dimensions: {} (expected 768)",
search.embedding.len()
));
}
let embedding_bytes = Self::serialize_vector(&search.embedding);
// INSERT OR IGNORE to handle re-imports (UNIQUE constraint on timestamp+query)
diesel::sql_query(
"INSERT OR IGNORE INTO search_history
(timestamp, query, search_engine, embedding, created_at, source_file)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
)
.bind::<diesel::sql_types::BigInt, _>(search.timestamp)
.bind::<diesel::sql_types::Text, _>(&search.query)
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&search.search_engine)
.bind::<diesel::sql_types::Binary, _>(&embedding_bytes)
.bind::<diesel::sql_types::BigInt, _>(search.created_at)
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&search.source_file)
.execute(conn.deref_mut())
.map_err(|e| anyhow::anyhow!("Insert error: {:?}", e))?;
let row_id: i32 = diesel::sql_query("SELECT last_insert_rowid() as id")
.get_result::<LastInsertRowId>(conn.deref_mut())
.map(|r| r.id)
.map_err(|e| anyhow::anyhow!("Failed to get last insert ID: {:?}", e))?;
Ok(SearchRecord {
id: row_id,
timestamp: search.timestamp,
query: search.query,
search_engine: search.search_engine,
created_at: search.created_at,
source_file: search.source_file,
})
})
.map_err(|_| DbError::new(DbErrorKind::InsertError))
}
fn store_searches_batch(
&mut self,
context: &opentelemetry::Context,
searches: Vec<InsertSearchRecord>,
) -> Result<usize, DbError> {
trace_db_call(context, "insert", "store_searches_batch", |_span| {
let mut conn = self
.connection
.lock()
.expect("Unable to get SearchHistoryDao");
let mut inserted = 0;
conn.transaction::<_, anyhow::Error, _>(|conn| {
for search in searches {
// Validate embedding (REQUIRED)
if search.embedding.len() != 768 {
log::warn!(
"Skipping search with invalid embedding dimensions: {}",
search.embedding.len()
);
continue;
}
let embedding_bytes = Self::serialize_vector(&search.embedding);
let rows_affected = diesel::sql_query(
"INSERT OR IGNORE INTO search_history
(timestamp, query, search_engine, embedding, created_at, source_file)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
)
.bind::<diesel::sql_types::BigInt, _>(search.timestamp)
.bind::<diesel::sql_types::Text, _>(&search.query)
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
&search.search_engine,
)
.bind::<diesel::sql_types::Binary, _>(&embedding_bytes)
.bind::<diesel::sql_types::BigInt, _>(search.created_at)
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
&search.source_file,
)
.execute(conn)
.map_err(|e| anyhow::anyhow!("Batch insert error: {:?}", e))?;
if rows_affected > 0 {
inserted += 1;
}
}
Ok(())
})
.map_err(|e| anyhow::anyhow!("Transaction error: {:?}", e))?;
Ok(inserted)
})
.map_err(|_| DbError::new(DbErrorKind::InsertError))
}
fn find_searches_in_range(
&mut self,
context: &opentelemetry::Context,
start_ts: i64,
end_ts: i64,
) -> Result<Vec<SearchRecord>, DbError> {
trace_db_call(context, "query", "find_searches_in_range", |_span| {
let mut conn = self
.connection
.lock()
.expect("Unable to get SearchHistoryDao");
diesel::sql_query(
"SELECT id, timestamp, query, search_engine, embedding, created_at, source_file
FROM search_history
WHERE timestamp >= ?1 AND timestamp <= ?2
ORDER BY timestamp DESC",
)
.bind::<diesel::sql_types::BigInt, _>(start_ts)
.bind::<diesel::sql_types::BigInt, _>(end_ts)
.load::<SearchRecordWithVectorRow>(conn.deref_mut())
.map(|rows| rows.into_iter().map(|r| r.to_search_record()).collect())
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))
})
.map_err(|_| DbError::new(DbErrorKind::QueryError))
}
/// Returns up to `limit` stored searches, ordered from most to least
/// similar to `query_embedding` by cosine similarity.
///
/// Every row of `search_history` is loaded and scored in memory. Rows whose
/// embedding blob cannot be decoded are left out. A similarity that comes
/// out as NaN (for example from corrupt stored floats) cannot be ranked, so
/// such rows are placed after all rankable ones.
///
/// # Errors
/// Returns `DbErrorKind::QueryError` if `query_embedding` is not
/// 768-dimensional or the query fails.
fn find_similar_searches(
    &mut self,
    context: &opentelemetry::Context,
    query_embedding: &[f32],
    limit: usize,
) -> Result<Vec<SearchRecord>, DbError> {
    trace_db_call(context, "query", "find_similar_searches", |_span| {
        let mut conn = self
            .connection
            .lock()
            .expect("Unable to get SearchHistoryDao");
        if query_embedding.len() != 768 {
            return Err(anyhow::anyhow!(
                "Invalid query embedding dimensions: {} (expected 768)",
                query_embedding.len()
            ));
        }
        // Load all searches with embeddings
        let results = diesel::sql_query(
            "SELECT id, timestamp, query, search_engine, embedding, created_at, source_file
FROM search_history",
        )
        .load::<SearchRecordWithVectorRow>(conn.deref_mut())
        .map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
        // Compute similarities
        let mut scored_searches: Vec<(f32, SearchRecord)> = results
            .into_iter()
            .filter_map(|row| {
                let emb = Self::deserialize_vector(&row.embedding).ok()?;
                let similarity = Self::cosine_similarity(query_embedding, &emb);
                // NaN has no place in an ordering; demote it below every
                // real score so it can never outrank a genuine match.
                let similarity = if similarity.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    similarity
                };
                Some((similarity, row.to_search_record()))
            })
            .collect();
        // Sort by similarity descending. `total_cmp` is a true total order,
        // which `sort_by` requires; `partial_cmp` with an `Equal` fallback
        // is not one once NaN is involved.
        scored_searches.sort_by(|a, b| b.0.total_cmp(&a.0));
        log::info!("Found {} similar searches", scored_searches.len());
        if let Some((top_score, top_search)) = scored_searches.first() {
            log::info!(
                "Top similarity: {:.4} for query: '{}'",
                top_score,
                top_search.query
            );
        }
        Ok(scored_searches
            .into_iter()
            .take(limit)
            .map(|(_, search)| search)
            .collect())
    })
    .map_err(|_| DbError::new(DbErrorKind::QueryError))
}
/// Hybrid lookup: restricts to a time window, then optionally ranks by
/// semantic similarity.
///
/// The window covers `time_window_days` days (86400 seconds each) on either
/// side of `center_timestamp`, bounds inclusive. The window arithmetic
/// saturates instead of overflowing for extreme inputs.
///
/// * With `query_embedding = Some(..)`, rows in the window are ranked by
///   cosine similarity, most similar first. Rows whose embedding blob
///   cannot be decoded are left out, and rows whose similarity is NaN are
///   placed last.
/// * With `query_embedding = None`, rows are returned most recent first.
///
/// At most `limit` records are returned.
///
/// # Errors
/// Returns `DbErrorKind::QueryError` if a provided embedding is not
/// 768-dimensional or the query fails.
fn find_relevant_searches_hybrid(
    &mut self,
    context: &opentelemetry::Context,
    center_timestamp: i64,
    time_window_days: i64,
    query_embedding: Option<&[f32]>,
    limit: usize,
) -> Result<Vec<SearchRecord>, DbError> {
    trace_db_call(context, "query", "find_relevant_searches_hybrid", |_span| {
        // Reject a bad embedding up front, before any database work.
        if let Some(query_emb) = query_embedding {
            if query_emb.len() != 768 {
                return Err(anyhow::anyhow!(
                    "Invalid query embedding dimensions: {} (expected 768)",
                    query_emb.len()
                ));
            }
        }
        // Saturating arithmetic: a huge window or extreme center must clamp
        // to the i64 range rather than overflow.
        let window_seconds = time_window_days.saturating_mul(86400);
        let start_ts = center_timestamp.saturating_sub(window_seconds);
        let end_ts = center_timestamp.saturating_add(window_seconds);
        let mut conn = self
            .connection
            .lock()
            .expect("Unable to get SearchHistoryDao");
        // Step 1: Time-based filter. ORDER BY makes the time-only branch
        // genuinely "most recent first" and gives equally-similar rows a
        // deterministic (newest first) order in the ranked branch.
        let searches_in_range = diesel::sql_query(
            "SELECT id, timestamp, query, search_engine, embedding, created_at, source_file
FROM search_history
WHERE timestamp >= ?1 AND timestamp <= ?2
ORDER BY timestamp DESC",
        )
        .bind::<diesel::sql_types::BigInt, _>(start_ts)
        .bind::<diesel::sql_types::BigInt, _>(end_ts)
        .load::<SearchRecordWithVectorRow>(conn.deref_mut())
        .map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
        // Step 2: If query embedding provided, rank by semantic similarity
        if let Some(query_emb) = query_embedding {
            let mut scored_searches: Vec<(f32, SearchRecord)> = searches_in_range
                .into_iter()
                .filter_map(|row| {
                    let emb = Self::deserialize_vector(&row.embedding).ok()?;
                    let similarity = Self::cosine_similarity(query_emb, &emb);
                    // NaN cannot be ranked; demote it below every real score.
                    let similarity = if similarity.is_nan() {
                        f32::NEG_INFINITY
                    } else {
                        similarity
                    };
                    Some((similarity, row.to_search_record()))
                })
                .collect();
            // Sort by similarity descending using a true total order.
            scored_searches.sort_by(|a, b| b.0.total_cmp(&a.0));
            log::info!(
                "Hybrid query: {} searches in time range, ranked by similarity",
                scored_searches.len()
            );
            if let Some((top_score, top_search)) = scored_searches.first() {
                log::info!(
                    "Top similarity: {:.4} for '{}'",
                    top_score,
                    top_search.query
                );
            }
            Ok(scored_searches
                .into_iter()
                .take(limit)
                .map(|(_, search)| search)
                .collect())
        } else {
            // No semantic ranking, just return time-sorted (most recent first)
            log::info!(
                "Time-only query: {} searches in range",
                searches_in_range.len()
            );
            Ok(searches_in_range
                .into_iter()
                .take(limit)
                .map(|r| r.to_search_record())
                .collect())
        }
    })
    .map_err(|_| DbError::new(DbErrorKind::QueryError))
}
fn search_exists(
&mut self,
context: &opentelemetry::Context,
timestamp: i64,
query: &str,
) -> Result<bool, DbError> {
trace_db_call(context, "query", "search_exists", |_span| {
let mut conn = self
.connection
.lock()
.expect("Unable to get SearchHistoryDao");
#[derive(QueryableByName)]
struct CountResult {
#[diesel(sql_type = diesel::sql_types::Integer)]
count: i32,
}
let result: CountResult = diesel::sql_query(
"SELECT COUNT(*) as count FROM search_history WHERE timestamp = ?1 AND query = ?2",
)
.bind::<diesel::sql_types::BigInt, _>(timestamp)
.bind::<diesel::sql_types::Text, _>(query)
.get_result(conn.deref_mut())
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
Ok(result.count > 0)
})
.map_err(|_| DbError::new(DbErrorKind::QueryError))
}
fn get_search_count(&mut self, context: &opentelemetry::Context) -> Result<i64, DbError> {
trace_db_call(context, "query", "get_search_count", |_span| {
let mut conn = self
.connection
.lock()
.expect("Unable to get SearchHistoryDao");
#[derive(QueryableByName)]
struct CountResult {
#[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64,
}
let result: CountResult =
diesel::sql_query("SELECT COUNT(*) as count FROM search_history")
.get_result(conn.deref_mut())
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
Ok(result.count)
})
.map_err(|_| DbError::new(DbErrorKind::QueryError))
}
}