552 lines
21 KiB
Rust
552 lines
21 KiB
Rust
use diesel::prelude::*;
|
|
use diesel::sqlite::SqliteConnection;
|
|
use serde::Serialize;
|
|
use std::ops::DerefMut;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use crate::database::{DbError, DbErrorKind, connect};
|
|
use crate::otel::trace_db_call;
|
|
|
|
/// Represents a calendar event
|
|
#[derive(Serialize, Clone, Debug)]
|
|
pub struct CalendarEvent {
|
|
pub id: i32,
|
|
pub event_uid: Option<String>,
|
|
pub summary: String,
|
|
pub description: Option<String>,
|
|
pub location: Option<String>,
|
|
pub start_time: i64,
|
|
pub end_time: i64,
|
|
pub all_day: bool,
|
|
pub organizer: Option<String>,
|
|
pub attendees: Option<String>, // JSON string
|
|
pub created_at: i64,
|
|
pub source_file: Option<String>,
|
|
}
|
|
|
|
/// Payload for inserting a new calendar event.
///
/// Mirrors [`CalendarEvent`] minus the database-assigned `id`, plus an
/// optional embedding that is stored alongside the row.
#[derive(Clone, Debug)]
pub struct InsertCalendarEvent {
    /// Optional external identifier for the event.
    pub event_uid: Option<String>,
    /// Short title of the event.
    pub summary: String,
    /// Optional long-form description.
    pub description: Option<String>,
    /// Optional location text.
    pub location: Option<String>,
    /// Start timestamp.
    pub start_time: i64,
    /// End timestamp.
    pub end_time: i64,
    /// Whether this is an all-day event.
    pub all_day: bool,
    /// Optional organizer text.
    pub organizer: Option<String>,
    /// Attendee list, expected to be a JSON string.
    pub attendees: Option<String>,
    /// Optional semantic embedding; must have 768 components when present.
    pub embedding: Option<Vec<f32>>,
    /// Creation timestamp to store with the row.
    pub created_at: i64,
    /// Optional name/path of the file the event came from.
    pub source_file: Option<String>,
}

pub trait CalendarEventDao: Sync + Send {
|
|
/// Store calendar event with optional embedding
|
|
fn store_event(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
event: InsertCalendarEvent,
|
|
) -> Result<CalendarEvent, DbError>;
|
|
|
|
/// Batch insert events (for import efficiency)
|
|
fn store_events_batch(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
events: Vec<InsertCalendarEvent>,
|
|
) -> Result<usize, DbError>;
|
|
|
|
/// Find events in time range (PRIMARY query method)
|
|
fn find_events_in_range(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
start_ts: i64,
|
|
end_ts: i64,
|
|
) -> Result<Vec<CalendarEvent>, DbError>;
|
|
|
|
/// Find semantically similar events (SECONDARY - requires embeddings)
|
|
fn find_similar_events(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Result<Vec<CalendarEvent>, DbError>;
|
|
|
|
/// Hybrid: Time-filtered + semantic ranking
|
|
/// "Events during photo timestamp ±N days, ranked by similarity to context"
|
|
fn find_relevant_events_hybrid(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
center_timestamp: i64,
|
|
time_window_days: i64,
|
|
query_embedding: Option<&[f32]>,
|
|
limit: usize,
|
|
) -> Result<Vec<CalendarEvent>, DbError>;
|
|
|
|
/// Check if event exists (idempotency)
|
|
fn event_exists(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
event_uid: &str,
|
|
start_time: i64,
|
|
) -> Result<bool, DbError>;
|
|
|
|
/// Get count of events
|
|
fn get_event_count(&mut self, context: &opentelemetry::Context) -> Result<i64, DbError>;
|
|
}
|
|
|
|
pub struct SqliteCalendarEventDao {
|
|
connection: Arc<Mutex<SqliteConnection>>,
|
|
}
|
|
|
|
impl Default for SqliteCalendarEventDao {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl SqliteCalendarEventDao {
|
|
pub fn new() -> Self {
|
|
SqliteCalendarEventDao {
|
|
connection: Arc::new(Mutex::new(connect())),
|
|
}
|
|
}
|
|
|
|
fn serialize_vector(vec: &[f32]) -> Vec<u8> {
|
|
use zerocopy::IntoBytes;
|
|
vec.as_bytes().to_vec()
|
|
}
|
|
|
|
fn deserialize_vector(bytes: &[u8]) -> Result<Vec<f32>, DbError> {
|
|
if !bytes.len().is_multiple_of(4) {
|
|
return Err(DbError::new(DbErrorKind::QueryError));
|
|
}
|
|
|
|
let count = bytes.len() / 4;
|
|
let mut vec = Vec::with_capacity(count);
|
|
|
|
for chunk in bytes.chunks_exact(4) {
|
|
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
|
vec.push(float);
|
|
}
|
|
|
|
Ok(vec)
|
|
}
|
|
|
|
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
if a.len() != b.len() {
|
|
return 0.0;
|
|
}
|
|
|
|
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
|
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
|
|
if magnitude_a == 0.0 || magnitude_b == 0.0 {
|
|
return 0.0;
|
|
}
|
|
|
|
dot_product / (magnitude_a * magnitude_b)
|
|
}
|
|
}
|
|
|
|
#[derive(QueryableByName)]
|
|
struct CalendarEventWithVectorRow {
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
id: i32,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
event_uid: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
summary: String,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
description: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
location: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
start_time: i64,
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
end_time: i64,
|
|
#[diesel(sql_type = diesel::sql_types::Bool)]
|
|
all_day: bool,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
organizer: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
attendees: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Binary>)]
|
|
embedding: Option<Vec<u8>>,
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
created_at: i64,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
source_file: Option<String>,
|
|
}
|
|
|
|
impl CalendarEventWithVectorRow {
|
|
fn to_calendar_event(&self) -> CalendarEvent {
|
|
CalendarEvent {
|
|
id: self.id,
|
|
event_uid: self.event_uid.clone(),
|
|
summary: self.summary.clone(),
|
|
description: self.description.clone(),
|
|
location: self.location.clone(),
|
|
start_time: self.start_time,
|
|
end_time: self.end_time,
|
|
all_day: self.all_day,
|
|
organizer: self.organizer.clone(),
|
|
attendees: self.attendees.clone(),
|
|
created_at: self.created_at,
|
|
source_file: self.source_file.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(QueryableByName)]
|
|
struct LastInsertRowId {
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
id: i32,
|
|
}
|
|
|
|
impl CalendarEventDao for SqliteCalendarEventDao {
|
|
fn store_event(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
event: InsertCalendarEvent,
|
|
) -> Result<CalendarEvent, DbError> {
|
|
trace_db_call(context, "insert", "store_event", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get CalendarEventDao");
|
|
|
|
// Validate embedding dimensions if provided
|
|
if let Some(ref emb) = event.embedding
|
|
&& emb.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid embedding dimensions: {} (expected 768)",
|
|
emb.len()
|
|
));
|
|
}
|
|
|
|
let embedding_bytes = event.embedding.as_ref().map(|e| Self::serialize_vector(e));
|
|
|
|
// INSERT OR REPLACE to handle re-imports
|
|
diesel::sql_query(
|
|
"INSERT OR REPLACE INTO calendar_events
|
|
(event_uid, summary, description, location, start_time, end_time, all_day,
|
|
organizer, attendees, embedding, created_at, source_file)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
|
|
)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.event_uid)
|
|
.bind::<diesel::sql_types::Text, _>(&event.summary)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.description)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.location)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.start_time)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.end_time)
|
|
.bind::<diesel::sql_types::Bool, _>(event.all_day)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.organizer)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.attendees)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Binary>, _>(&embedding_bytes)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.created_at)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(&event.source_file)
|
|
.execute(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Insert error: {:?}", e))?;
|
|
|
|
let row_id: i32 = diesel::sql_query("SELECT last_insert_rowid() as id")
|
|
.get_result::<LastInsertRowId>(conn.deref_mut())
|
|
.map(|r| r.id)
|
|
.map_err(|e| anyhow::anyhow!("Failed to get last insert ID: {:?}", e))?;
|
|
|
|
Ok(CalendarEvent {
|
|
id: row_id,
|
|
event_uid: event.event_uid,
|
|
summary: event.summary,
|
|
description: event.description,
|
|
location: event.location,
|
|
start_time: event.start_time,
|
|
end_time: event.end_time,
|
|
all_day: event.all_day,
|
|
organizer: event.organizer,
|
|
attendees: event.attendees,
|
|
created_at: event.created_at,
|
|
source_file: event.source_file,
|
|
})
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::InsertError))
|
|
}
|
|
|
|
fn store_events_batch(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
events: Vec<InsertCalendarEvent>,
|
|
) -> Result<usize, DbError> {
|
|
trace_db_call(context, "insert", "store_events_batch", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get CalendarEventDao");
|
|
let mut inserted = 0;
|
|
|
|
conn.transaction::<_, anyhow::Error, _>(|conn| {
|
|
for event in events {
|
|
// Validate embedding if provided
|
|
if let Some(ref emb) = event.embedding
|
|
&& emb.len() != 768 {
|
|
log::warn!(
|
|
"Skipping event with invalid embedding dimensions: {}",
|
|
emb.len()
|
|
);
|
|
continue;
|
|
}
|
|
|
|
let embedding_bytes =
|
|
event.embedding.as_ref().map(|e| Self::serialize_vector(e));
|
|
|
|
diesel::sql_query(
|
|
"INSERT OR REPLACE INTO calendar_events
|
|
(event_uid, summary, description, location, start_time, end_time, all_day,
|
|
organizer, attendees, embedding, created_at, source_file)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
|
|
)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.event_uid,
|
|
)
|
|
.bind::<diesel::sql_types::Text, _>(&event.summary)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.description,
|
|
)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.location,
|
|
)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.start_time)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.end_time)
|
|
.bind::<diesel::sql_types::Bool, _>(event.all_day)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.organizer,
|
|
)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.attendees,
|
|
)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Binary>, _>(
|
|
&embedding_bytes,
|
|
)
|
|
.bind::<diesel::sql_types::BigInt, _>(event.created_at)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(
|
|
&event.source_file,
|
|
)
|
|
.execute(conn)
|
|
.map_err(|e| anyhow::anyhow!("Batch insert error: {:?}", e))?;
|
|
|
|
inserted += 1;
|
|
}
|
|
Ok(())
|
|
})
|
|
.map_err(|e| anyhow::anyhow!("Transaction error: {:?}", e))?;
|
|
|
|
Ok(inserted)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::InsertError))
|
|
}
|
|
|
|
fn find_events_in_range(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
start_ts: i64,
|
|
end_ts: i64,
|
|
) -> Result<Vec<CalendarEvent>, DbError> {
|
|
trace_db_call(context, "query", "find_events_in_range", |_span| {
|
|
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
|
|
|
diesel::sql_query(
|
|
"SELECT id, event_uid, summary, description, location, start_time, end_time, all_day,
|
|
organizer, attendees, NULL as embedding, created_at, source_file
|
|
FROM calendar_events
|
|
WHERE start_time >= ?1 AND start_time <= ?2
|
|
ORDER BY start_time ASC"
|
|
)
|
|
.bind::<diesel::sql_types::BigInt, _>(start_ts)
|
|
.bind::<diesel::sql_types::BigInt, _>(end_ts)
|
|
.load::<CalendarEventWithVectorRow>(conn.deref_mut())
|
|
.map(|rows| rows.into_iter().map(|r| r.to_calendar_event()).collect())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn find_similar_events(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Result<Vec<CalendarEvent>, DbError> {
|
|
trace_db_call(context, "query", "find_similar_events", |_span| {
|
|
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
|
|
|
if query_embedding.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid query embedding dimensions: {} (expected 768)",
|
|
query_embedding.len()
|
|
));
|
|
}
|
|
|
|
// Load all events with embeddings
|
|
let results = diesel::sql_query(
|
|
"SELECT id, event_uid, summary, description, location, start_time, end_time, all_day,
|
|
organizer, attendees, embedding, created_at, source_file
|
|
FROM calendar_events
|
|
WHERE embedding IS NOT NULL"
|
|
)
|
|
.load::<CalendarEventWithVectorRow>(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
// Compute similarities
|
|
let mut scored_events: Vec<(f32, CalendarEvent)> = results
|
|
.into_iter()
|
|
.filter_map(|row| {
|
|
if let Some(ref emb_bytes) = row.embedding {
|
|
if let Ok(emb) = Self::deserialize_vector(emb_bytes) {
|
|
let similarity = Self::cosine_similarity(query_embedding, &emb);
|
|
Some((similarity, row.to_calendar_event()))
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
// Sort by similarity descending
|
|
scored_events.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
log::info!("Found {} similar calendar events", scored_events.len());
|
|
if !scored_events.is_empty() {
|
|
log::info!("Top similarity: {:.4}", scored_events[0].0);
|
|
}
|
|
|
|
Ok(scored_events.into_iter().take(limit).map(|(_, event)| event).collect())
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn find_relevant_events_hybrid(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
center_timestamp: i64,
|
|
time_window_days: i64,
|
|
query_embedding: Option<&[f32]>,
|
|
limit: usize,
|
|
) -> Result<Vec<CalendarEvent>, DbError> {
|
|
trace_db_call(context, "query", "find_relevant_events_hybrid", |_span| {
|
|
let window_seconds = time_window_days * 86400;
|
|
let start_ts = center_timestamp - window_seconds;
|
|
let end_ts = center_timestamp + window_seconds;
|
|
|
|
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
|
|
|
// Step 1: Time-based filter (fast, indexed)
|
|
let events_in_range = diesel::sql_query(
|
|
"SELECT id, event_uid, summary, description, location, start_time, end_time, all_day,
|
|
organizer, attendees, embedding, created_at, source_file
|
|
FROM calendar_events
|
|
WHERE start_time >= ?1 AND start_time <= ?2"
|
|
)
|
|
.bind::<diesel::sql_types::BigInt, _>(start_ts)
|
|
.bind::<diesel::sql_types::BigInt, _>(end_ts)
|
|
.load::<CalendarEventWithVectorRow>(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
// Step 2: If query embedding provided, rank by semantic similarity
|
|
if let Some(query_emb) = query_embedding {
|
|
if query_emb.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid query embedding dimensions: {} (expected 768)",
|
|
query_emb.len()
|
|
));
|
|
}
|
|
|
|
let mut scored_events: Vec<(f32, CalendarEvent)> = events_in_range
|
|
.into_iter()
|
|
.map(|row| {
|
|
// Events with embeddings get semantic scoring
|
|
let similarity = if let Some(ref emb_bytes) = row.embedding {
|
|
if let Ok(emb) = Self::deserialize_vector(emb_bytes) {
|
|
Self::cosine_similarity(query_emb, &emb)
|
|
} else {
|
|
0.5 // Neutral score for deserialization errors
|
|
}
|
|
} else {
|
|
0.5 // Neutral score for events without embeddings
|
|
};
|
|
(similarity, row.to_calendar_event())
|
|
})
|
|
.collect();
|
|
|
|
// Sort by similarity descending
|
|
scored_events.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
log::info!("Hybrid query: {} events in time range, ranked by similarity", scored_events.len());
|
|
if !scored_events.is_empty() {
|
|
log::info!("Top similarity: {:.4}", scored_events[0].0);
|
|
}
|
|
|
|
Ok(scored_events.into_iter().take(limit).map(|(_, event)| event).collect())
|
|
} else {
|
|
// No semantic ranking, just return time-sorted (limit applied)
|
|
log::info!("Time-only query: {} events in range", events_in_range.len());
|
|
Ok(events_in_range.into_iter().take(limit).map(|r| r.to_calendar_event()).collect())
|
|
}
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn event_exists(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
event_uid: &str,
|
|
start_time: i64,
|
|
) -> Result<bool, DbError> {
|
|
trace_db_call(context, "query", "event_exists", |_span| {
|
|
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
|
|
|
#[derive(QueryableByName)]
|
|
struct CountResult {
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
count: i32,
|
|
}
|
|
|
|
let result: CountResult = diesel::sql_query(
|
|
"SELECT COUNT(*) as count FROM calendar_events WHERE event_uid = ?1 AND start_time = ?2"
|
|
)
|
|
.bind::<diesel::sql_types::Text, _>(event_uid)
|
|
.bind::<diesel::sql_types::BigInt, _>(start_time)
|
|
.get_result(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
Ok(result.count > 0)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn get_event_count(&mut self, context: &opentelemetry::Context) -> Result<i64, DbError> {
|
|
trace_db_call(context, "query", "get_event_count", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get CalendarEventDao");
|
|
|
|
#[derive(QueryableByName)]
|
|
struct CountResult {
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
count: i64,
|
|
}
|
|
|
|
let result: CountResult =
|
|
diesel::sql_query("SELECT COUNT(*) as count FROM calendar_events")
|
|
.get_result(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
Ok(result.count)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
}
|