490 lines
18 KiB
Rust
490 lines
18 KiB
Rust
use chrono::NaiveDate;
|
|
use diesel::prelude::*;
|
|
use diesel::sqlite::SqliteConnection;
|
|
use serde::Serialize;
|
|
use std::ops::DerefMut;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use crate::database::{DbError, DbErrorKind, connect};
|
|
use crate::otel::trace_db_call;
|
|
|
|
/// Represents a daily conversation summary
|
|
#[derive(Serialize, Clone, Debug)]
|
|
pub struct DailySummary {
|
|
pub id: i32,
|
|
pub date: String,
|
|
pub contact: String,
|
|
pub summary: String,
|
|
pub message_count: i32,
|
|
pub created_at: i64,
|
|
pub model_version: String,
|
|
}
|
|
|
|
/// Input for [`DailySummaryDao::store_summary`]: a new daily summary plus its embedding.
#[derive(Debug, Clone)]
pub struct InsertDailySummary {
    /// Day the summary covers (the time-weighted search expects `%Y-%m-%d`).
    pub date: String,
    /// Contact the summary belongs to.
    pub contact: String,
    /// The summary text.
    pub summary: String,
    /// Message count to store with the summary.
    pub message_count: i32,
    /// Embedding of the summary; the SQLite implementation requires exactly 768 components.
    pub embedding: Vec<f32>,
    /// Creation timestamp (unit chosen by the caller — presumably Unix seconds; confirm).
    pub created_at: i64,
    /// Model identifier stored alongside the summary.
    pub model_version: String,
}
|
|
|
|
pub trait DailySummaryDao: Sync + Send {
|
|
/// Store a daily summary with its embedding
|
|
fn store_summary(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
summary: InsertDailySummary,
|
|
) -> Result<DailySummary, DbError>;
|
|
|
|
/// Find semantically similar daily summaries using vector similarity
|
|
fn find_similar_summaries(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Result<Vec<DailySummary>, DbError>;
|
|
|
|
/// Find semantically similar daily summaries with time-based weighting
|
|
/// Combines cosine similarity with temporal proximity to target_date
|
|
/// Final score = similarity * time_weight, where time_weight decays with distance from target_date
|
|
fn find_similar_summaries_with_time_weight(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
target_date: &str,
|
|
limit: usize,
|
|
) -> Result<Vec<DailySummary>, DbError>;
|
|
|
|
/// Check if a summary exists for a given date and contact
|
|
fn summary_exists(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
date: &str,
|
|
contact: &str,
|
|
) -> Result<bool, DbError>;
|
|
|
|
/// Get count of summaries for a contact
|
|
fn get_summary_count(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
contact: &str,
|
|
) -> Result<i64, DbError>;
|
|
}
|
|
|
|
pub struct SqliteDailySummaryDao {
|
|
connection: Arc<Mutex<SqliteConnection>>,
|
|
}
|
|
|
|
impl Default for SqliteDailySummaryDao {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl SqliteDailySummaryDao {
|
|
pub fn new() -> Self {
|
|
SqliteDailySummaryDao {
|
|
connection: Arc::new(Mutex::new(connect())),
|
|
}
|
|
}
|
|
|
|
fn serialize_vector(vec: &[f32]) -> Vec<u8> {
|
|
use zerocopy::IntoBytes;
|
|
vec.as_bytes().to_vec()
|
|
}
|
|
|
|
fn deserialize_vector(bytes: &[u8]) -> Result<Vec<f32>, DbError> {
|
|
if !bytes.len().is_multiple_of(4) {
|
|
return Err(DbError::new(DbErrorKind::QueryError));
|
|
}
|
|
|
|
let count = bytes.len() / 4;
|
|
let mut vec = Vec::with_capacity(count);
|
|
|
|
for chunk in bytes.chunks_exact(4) {
|
|
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
|
vec.push(float);
|
|
}
|
|
|
|
Ok(vec)
|
|
}
|
|
|
|
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
if a.len() != b.len() {
|
|
return 0.0;
|
|
}
|
|
|
|
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
|
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
|
|
if magnitude_a == 0.0 || magnitude_b == 0.0 {
|
|
return 0.0;
|
|
}
|
|
|
|
dot_product / (magnitude_a * magnitude_b)
|
|
}
|
|
}
|
|
|
|
impl DailySummaryDao for SqliteDailySummaryDao {
|
|
fn store_summary(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
summary: InsertDailySummary,
|
|
) -> Result<DailySummary, DbError> {
|
|
trace_db_call(context, "insert", "store_summary", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get DailySummaryDao");
|
|
|
|
// Validate embedding dimensions
|
|
if summary.embedding.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid embedding dimensions: {} (expected 768)",
|
|
summary.embedding.len()
|
|
));
|
|
}
|
|
|
|
let embedding_bytes = Self::serialize_vector(&summary.embedding);
|
|
|
|
// INSERT OR REPLACE to handle updates if summary needs regeneration
|
|
diesel::sql_query(
|
|
"INSERT OR REPLACE INTO daily_conversation_summaries
|
|
(date, contact, summary, message_count, embedding, created_at, model_version)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
|
|
)
|
|
.bind::<diesel::sql_types::Text, _>(&summary.date)
|
|
.bind::<diesel::sql_types::Text, _>(&summary.contact)
|
|
.bind::<diesel::sql_types::Text, _>(&summary.summary)
|
|
.bind::<diesel::sql_types::Integer, _>(summary.message_count)
|
|
.bind::<diesel::sql_types::Binary, _>(&embedding_bytes)
|
|
.bind::<diesel::sql_types::BigInt, _>(summary.created_at)
|
|
.bind::<diesel::sql_types::Text, _>(&summary.model_version)
|
|
.execute(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Insert error: {:?}", e))?;
|
|
|
|
let row_id: i32 = diesel::sql_query("SELECT last_insert_rowid() as id")
|
|
.get_result::<LastInsertRowId>(conn.deref_mut())
|
|
.map(|r| r.id as i32)
|
|
.map_err(|e| anyhow::anyhow!("Failed to get last insert ID: {:?}", e))?;
|
|
|
|
Ok(DailySummary {
|
|
id: row_id,
|
|
date: summary.date,
|
|
contact: summary.contact,
|
|
summary: summary.summary,
|
|
message_count: summary.message_count,
|
|
created_at: summary.created_at,
|
|
model_version: summary.model_version,
|
|
})
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::InsertError))
|
|
}
|
|
|
|
fn find_similar_summaries(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Result<Vec<DailySummary>, DbError> {
|
|
trace_db_call(context, "query", "find_similar_summaries", |_span| {
|
|
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
|
|
|
if query_embedding.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid query embedding dimensions: {} (expected 768)",
|
|
query_embedding.len()
|
|
));
|
|
}
|
|
|
|
// Load all summaries with embeddings
|
|
let results = diesel::sql_query(
|
|
"SELECT id, date, contact, summary, message_count, embedding, created_at, model_version
|
|
FROM daily_conversation_summaries"
|
|
)
|
|
.load::<DailySummaryWithVectorRow>(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
log::info!("Loaded {} daily summaries for similarity comparison", results.len());
|
|
|
|
// Compute similarity for each summary
|
|
let mut scored_summaries: Vec<(f32, DailySummary)> = results
|
|
.into_iter()
|
|
.filter_map(|row| {
|
|
match Self::deserialize_vector(&row.embedding) {
|
|
Ok(embedding) => {
|
|
let similarity = Self::cosine_similarity(query_embedding, &embedding);
|
|
Some((
|
|
similarity,
|
|
DailySummary {
|
|
id: row.id,
|
|
date: row.date,
|
|
contact: row.contact,
|
|
summary: row.summary,
|
|
message_count: row.message_count,
|
|
created_at: row.created_at,
|
|
model_version: row.model_version,
|
|
},
|
|
))
|
|
}
|
|
Err(e) => {
|
|
log::warn!("Failed to deserialize embedding for summary {}: {:?}", row.id, e);
|
|
None
|
|
}
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
// Sort by similarity (highest first)
|
|
scored_summaries.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
// Filter out poor matches (similarity < 0.3 is likely noise)
|
|
scored_summaries.retain(|(similarity, _)| *similarity >= 0.3);
|
|
|
|
// Log similarity distribution
|
|
if !scored_summaries.is_empty() {
|
|
let top_score = scored_summaries.first().map(|(s, _)| *s).unwrap_or(0.0);
|
|
let median_score = scored_summaries.get(scored_summaries.len() / 2).map(|(s, _)| *s).unwrap_or(0.0);
|
|
|
|
log::info!(
|
|
"Daily summary similarity - Top: {:.3}, Median: {:.3}, Count: {} (after 0.3 threshold)",
|
|
top_score,
|
|
median_score,
|
|
scored_summaries.len()
|
|
);
|
|
} else {
|
|
log::warn!("No daily summaries met the 0.3 similarity threshold");
|
|
}
|
|
|
|
// Take top N and log matches
|
|
let top_results: Vec<DailySummary> = scored_summaries
|
|
.into_iter()
|
|
.take(limit)
|
|
.map(|(similarity, summary)| {
|
|
log::info!(
|
|
"Summary match: similarity={:.3}, date={}, contact={}, summary=\"{}\"",
|
|
similarity,
|
|
summary.date,
|
|
summary.contact,
|
|
summary.summary.chars().take(100).collect::<String>()
|
|
);
|
|
summary
|
|
})
|
|
.collect();
|
|
|
|
Ok(top_results)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn find_similar_summaries_with_time_weight(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
query_embedding: &[f32],
|
|
target_date: &str,
|
|
limit: usize,
|
|
) -> Result<Vec<DailySummary>, DbError> {
|
|
trace_db_call(context, "query", "find_similar_summaries_with_time_weight", |_span| {
|
|
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
|
|
|
if query_embedding.len() != 768 {
|
|
return Err(anyhow::anyhow!(
|
|
"Invalid query embedding dimensions: {} (expected 768)",
|
|
query_embedding.len()
|
|
));
|
|
}
|
|
|
|
// Parse target date
|
|
let target = NaiveDate::parse_from_str(target_date, "%Y-%m-%d")
|
|
.map_err(|e| anyhow::anyhow!("Invalid target date: {}", e))?;
|
|
|
|
// Load all summaries with embeddings
|
|
let results = diesel::sql_query(
|
|
"SELECT id, date, contact, summary, message_count, embedding, created_at, model_version
|
|
FROM daily_conversation_summaries"
|
|
)
|
|
.load::<DailySummaryWithVectorRow>(conn.deref_mut())
|
|
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?;
|
|
|
|
log::info!("Loaded {} daily summaries for time-weighted similarity (target: {})", results.len(), target_date);
|
|
|
|
// Compute time-weighted similarity for each summary
|
|
// Score = cosine_similarity * time_weight
|
|
// time_weight = 1 / (1 + days_distance/30) - decays with ~30 day half-life
|
|
let mut scored_summaries: Vec<(f32, f32, i64, DailySummary)> = results
|
|
.into_iter()
|
|
.filter_map(|row| {
|
|
match Self::deserialize_vector(&row.embedding) {
|
|
Ok(embedding) => {
|
|
let similarity = Self::cosine_similarity(query_embedding, &embedding);
|
|
|
|
// Calculate time weight
|
|
let summary_date = NaiveDate::parse_from_str(&row.date, "%Y-%m-%d").ok()?;
|
|
let days_distance = (target - summary_date).num_days().abs();
|
|
|
|
// Exponential decay with 30-day half-life
|
|
// At 0 days: weight = 1.0
|
|
// At 30 days: weight = 0.5
|
|
// At 60 days: weight = 0.25
|
|
// At 365 days: weight ~= 0.0001
|
|
let time_weight = 0.5_f32.powf(days_distance as f32 / 30.0);
|
|
|
|
// Combined score - but ensure semantic similarity still matters
|
|
// We use sqrt to soften the time weight's impact
|
|
let combined_score = similarity * time_weight.sqrt();
|
|
|
|
Some((
|
|
combined_score,
|
|
similarity,
|
|
days_distance,
|
|
DailySummary {
|
|
id: row.id,
|
|
date: row.date,
|
|
contact: row.contact,
|
|
summary: row.summary,
|
|
message_count: row.message_count,
|
|
created_at: row.created_at,
|
|
model_version: row.model_version,
|
|
},
|
|
))
|
|
}
|
|
Err(e) => {
|
|
log::warn!("Failed to deserialize embedding for summary {}: {:?}", row.id, e);
|
|
None
|
|
}
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
// Sort by combined score (highest first)
|
|
scored_summaries.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
// Filter out poor matches (base similarity < 0.5 - stricter than before since we have time weighting)
|
|
scored_summaries.retain(|(_, similarity, _, _)| *similarity >= 0.5);
|
|
|
|
// Log similarity distribution
|
|
if !scored_summaries.is_empty() {
|
|
let (top_combined, top_sim, top_days, _) = &scored_summaries[0];
|
|
log::info!(
|
|
"Time-weighted similarity - Top: combined={:.3} (sim={:.3}, days={}), Count: {} matches",
|
|
top_combined,
|
|
top_sim,
|
|
top_days,
|
|
scored_summaries.len()
|
|
);
|
|
} else {
|
|
log::warn!("No daily summaries met the 0.5 similarity threshold");
|
|
}
|
|
|
|
// Take top N and log matches
|
|
let top_results: Vec<DailySummary> = scored_summaries
|
|
.into_iter()
|
|
.take(limit)
|
|
.map(|(combined, similarity, days, summary)| {
|
|
log::info!(
|
|
"Summary match: combined={:.3} (sim={:.3}, days={}), date={}, contact={}, summary=\"{}\"",
|
|
combined,
|
|
similarity,
|
|
days,
|
|
summary.date,
|
|
summary.contact,
|
|
summary.summary.chars().take(80).collect::<String>()
|
|
);
|
|
summary
|
|
})
|
|
.collect();
|
|
|
|
Ok(top_results)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn summary_exists(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
date: &str,
|
|
contact: &str,
|
|
) -> Result<bool, DbError> {
|
|
trace_db_call(context, "query", "summary_exists", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get DailySummaryDao");
|
|
|
|
let count = diesel::sql_query(
|
|
"SELECT COUNT(*) as count FROM daily_conversation_summaries
|
|
WHERE date = ?1 AND contact = ?2",
|
|
)
|
|
.bind::<diesel::sql_types::Text, _>(date)
|
|
.bind::<diesel::sql_types::Text, _>(contact)
|
|
.get_result::<CountResult>(conn.deref_mut())
|
|
.map(|r| r.count)
|
|
.map_err(|e| anyhow::anyhow!("Count query error: {:?}", e))?;
|
|
|
|
Ok(count > 0)
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
|
|
fn get_summary_count(
|
|
&mut self,
|
|
context: &opentelemetry::Context,
|
|
contact: &str,
|
|
) -> Result<i64, DbError> {
|
|
trace_db_call(context, "query", "get_summary_count", |_span| {
|
|
let mut conn = self
|
|
.connection
|
|
.lock()
|
|
.expect("Unable to get DailySummaryDao");
|
|
|
|
diesel::sql_query(
|
|
"SELECT COUNT(*) as count FROM daily_conversation_summaries WHERE contact = ?1",
|
|
)
|
|
.bind::<diesel::sql_types::Text, _>(contact)
|
|
.get_result::<CountResult>(conn.deref_mut())
|
|
.map(|r| r.count)
|
|
.map_err(|e| anyhow::anyhow!("Count query error: {:?}", e))
|
|
})
|
|
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
|
}
|
|
}
|
|
|
|
// Helper structs for raw SQL queries
|
|
|
|
#[derive(QueryableByName)]
|
|
struct LastInsertRowId {
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
id: i64,
|
|
}
|
|
|
|
#[derive(QueryableByName)]
|
|
struct DailySummaryWithVectorRow {
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
id: i32,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
date: String,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
contact: String,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
summary: String,
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
message_count: i32,
|
|
#[diesel(sql_type = diesel::sql_types::Binary)]
|
|
embedding: Vec<u8>,
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
created_at: i64,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
model_version: String,
|
|
}
|
|
|
|
#[derive(QueryableByName)]
|
|
struct CountResult {
|
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
|
count: i64,
|
|
}
|