Clean up unused message embedding code
Fix up some warnings
This commit is contained in:
@@ -26,6 +26,7 @@ pub struct CalendarEvent {
|
||||
|
||||
/// Data for inserting a new calendar event
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct InsertCalendarEvent {
|
||||
pub event_uid: Option<String>,
|
||||
pub summary: String,
|
||||
@@ -219,12 +220,13 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
||||
|
||||
// Validate embedding dimensions if provided
|
||||
if let Some(ref emb) = event.embedding
|
||||
&& emb.len() != 768 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid embedding dimensions: {} (expected 768)",
|
||||
emb.len()
|
||||
));
|
||||
}
|
||||
&& emb.len() != 768
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid embedding dimensions: {} (expected 768)",
|
||||
emb.len()
|
||||
));
|
||||
}
|
||||
|
||||
let embedding_bytes = event.embedding.as_ref().map(|e| Self::serialize_vector(e));
|
||||
|
||||
@@ -289,13 +291,14 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
||||
for event in events {
|
||||
// Validate embedding if provided
|
||||
if let Some(ref emb) = event.embedding
|
||||
&& emb.len() != 768 {
|
||||
log::warn!(
|
||||
"Skipping event with invalid embedding dimensions: {}",
|
||||
emb.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
&& emb.len() != 768
|
||||
{
|
||||
log::warn!(
|
||||
"Skipping event with invalid embedding dimensions: {}",
|
||||
emb.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let embedding_bytes =
|
||||
event.embedding.as_ref().map(|e| Self::serialize_vector(e));
|
||||
|
||||
@@ -1,569 +0,0 @@
|
||||
use diesel::prelude::*;
|
||||
use diesel::sqlite::SqliteConnection;
|
||||
use serde::Serialize;
|
||||
use std::ops::DerefMut;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::database::{DbError, DbErrorKind, connect};
|
||||
use crate::otel::trace_db_call;
|
||||
|
||||
/// Represents a stored message embedding
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct MessageEmbedding {
|
||||
pub id: i32,
|
||||
pub contact: String,
|
||||
pub body: String,
|
||||
pub timestamp: i64,
|
||||
pub is_sent: bool,
|
||||
pub created_at: i64,
|
||||
pub model_version: String,
|
||||
}
|
||||
|
||||
/// Input payload for inserting a new message embedding.
#[derive(Debug, Clone)]
pub struct InsertMessageEmbedding {
    /// Contact the message belongs to.
    pub contact: String,
    /// Message text.
    pub body: String,
    /// Message time (the DAO in this file treats it as Unix seconds).
    pub timestamp: i64,
    /// Presumably true when the message was sent rather than received — TODO confirm.
    pub is_sent: bool,
    /// Embedding vector; the SQLite DAO in this file only accepts 768 values.
    pub embedding: Vec<f32>,
    /// Creation time of the embedding record (unit not shown here — TODO confirm).
    pub created_at: i64,
    /// Identifier of the model version that produced the embedding.
    pub model_version: String,
}
|
||||
|
||||
pub trait EmbeddingDao: Sync + Send {
|
||||
/// Store a message with its embedding vector
|
||||
fn store_message_embedding(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
message: InsertMessageEmbedding,
|
||||
) -> Result<MessageEmbedding, DbError>;
|
||||
|
||||
/// Store multiple messages with embeddings in a single transaction
|
||||
/// Returns the number of successfully stored messages
|
||||
fn store_message_embeddings_batch(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
messages: Vec<InsertMessageEmbedding>,
|
||||
) -> Result<usize, DbError>;
|
||||
|
||||
/// Find semantically similar messages using vector similarity search
|
||||
/// Returns the top `limit` most similar messages
|
||||
/// If contact_filter is provided, only return messages from that contact
|
||||
/// Otherwise, search across all contacts for cross-perspective context
|
||||
fn find_similar_messages(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
contact_filter: Option<&str>,
|
||||
) -> Result<Vec<MessageEmbedding>, DbError>;
|
||||
|
||||
/// Get the count of embedded messages for a specific contact
|
||||
fn get_message_count(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
) -> Result<i64, DbError>;
|
||||
|
||||
/// Check if embeddings exist for a contact (idempotency check)
|
||||
fn has_embeddings_for_contact(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
) -> Result<bool, DbError>;
|
||||
|
||||
/// Check if a specific message already has an embedding
|
||||
fn message_exists(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
body: &str,
|
||||
timestamp: i64,
|
||||
) -> Result<bool, DbError>;
|
||||
}
|
||||
|
||||
pub struct SqliteEmbeddingDao {
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
}
|
||||
|
||||
impl Default for SqliteEmbeddingDao {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SqliteEmbeddingDao {
|
||||
pub fn new() -> Self {
|
||||
SqliteEmbeddingDao {
|
||||
connection: Arc::new(Mutex::new(connect())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize an `f32` slice to bytes for BLOB storage.
///
/// Every value is written as exactly four little-endian bytes, which is the
/// layout `deserialize_vector` decodes with `f32::from_le_bytes`. Encoding
/// the byte order explicitly (rather than dumping the in-memory
/// representation) keeps the stored format identical on every target.
fn serialize_vector(vec: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(vec.len() * 4);
    for value in vec {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
|
||||
|
||||
/// Deserialize bytes from BLOB back to f32 vector
|
||||
fn deserialize_vector(bytes: &[u8]) -> Result<Vec<f32>, DbError> {
|
||||
if !bytes.len().is_multiple_of(4) {
|
||||
return Err(DbError::new(DbErrorKind::QueryError));
|
||||
}
|
||||
|
||||
let count = bytes.len() / 4;
|
||||
let mut vec = Vec::with_capacity(count);
|
||||
|
||||
for chunk in bytes.chunks_exact(4) {
|
||||
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
||||
vec.push(float);
|
||||
}
|
||||
|
||||
Ok(vec)
|
||||
}
|
||||
|
||||
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the lengths differ or when either vector has zero
/// magnitude; otherwise the dot product divided by the product of the
/// magnitudes.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean norm of a slice.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // Guard against division by zero for all-zero (or empty) vectors.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
|
||||
}
|
||||
|
||||
impl EmbeddingDao for SqliteEmbeddingDao {
|
||||
fn store_message_embedding(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
message: InsertMessageEmbedding,
|
||||
) -> Result<MessageEmbedding, DbError> {
|
||||
trace_db_call(context, "insert", "store_message_embedding", |_span| {
|
||||
let mut conn = self.connection.lock().expect("Unable to get EmbeddingDao");
|
||||
|
||||
// Validate embedding dimensions
|
||||
if message.embedding.len() != 768 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid embedding dimensions: {} (expected 768)",
|
||||
message.embedding.len()
|
||||
));
|
||||
}
|
||||
|
||||
// Serialize embedding to bytes
|
||||
let embedding_bytes = Self::serialize_vector(&message.embedding);
|
||||
|
||||
// Insert into message_embeddings table with BLOB
|
||||
// Use INSERT OR IGNORE to skip duplicates (based on UNIQUE constraint)
|
||||
let insert_result = diesel::sql_query(
|
||||
"INSERT OR IGNORE INTO message_embeddings (contact, body, timestamp, is_sent, embedding, created_at, model_version)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.contact)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.body)
|
||||
.bind::<diesel::sql_types::BigInt, _>(message.timestamp)
|
||||
.bind::<diesel::sql_types::Bool, _>(message.is_sent)
|
||||
.bind::<diesel::sql_types::Binary, _>(&embedding_bytes)
|
||||
.bind::<diesel::sql_types::BigInt, _>(message.created_at)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.model_version)
|
||||
.execute(conn.deref_mut())
|
||||
.map_err(|e| anyhow::anyhow!("Insert error: {:?}", e))?;
|
||||
|
||||
// If INSERT OR IGNORE skipped (duplicate), find the existing record
|
||||
let row_id: i32 = if insert_result == 0 {
|
||||
// Duplicate - find the existing record
|
||||
diesel::sql_query(
|
||||
"SELECT id FROM message_embeddings WHERE contact = ?1 AND body = ?2 AND timestamp = ?3"
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.contact)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.body)
|
||||
.bind::<diesel::sql_types::BigInt, _>(message.timestamp)
|
||||
.get_result::<LastInsertRowId>(conn.deref_mut())
|
||||
.map(|r| r.id as i32)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to find existing record: {:?}", e))?
|
||||
} else {
|
||||
// New insert - get the last inserted row ID
|
||||
diesel::sql_query("SELECT last_insert_rowid() as id")
|
||||
.get_result::<LastInsertRowId>(conn.deref_mut())
|
||||
.map(|r| r.id as i32)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get last insert ID: {:?}", e))?
|
||||
};
|
||||
|
||||
// Return the stored message
|
||||
Ok(MessageEmbedding {
|
||||
id: row_id,
|
||||
contact: message.contact,
|
||||
body: message.body,
|
||||
timestamp: message.timestamp,
|
||||
is_sent: message.is_sent,
|
||||
created_at: message.created_at,
|
||||
model_version: message.model_version,
|
||||
})
|
||||
})
|
||||
.map_err(|_| DbError::new(DbErrorKind::InsertError))
|
||||
}
|
||||
|
||||
fn store_message_embeddings_batch(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
messages: Vec<InsertMessageEmbedding>,
|
||||
) -> Result<usize, DbError> {
|
||||
trace_db_call(context, "insert", "store_message_embeddings_batch", |_span| {
|
||||
let mut conn = self.connection.lock().expect("Unable to get EmbeddingDao");
|
||||
|
||||
// Start transaction
|
||||
conn.transaction::<_, anyhow::Error, _>(|conn| {
|
||||
let mut stored_count = 0;
|
||||
|
||||
for message in messages {
|
||||
// Validate embedding dimensions
|
||||
if message.embedding.len() != 768 {
|
||||
log::warn!(
|
||||
"Invalid embedding dimensions: {} (expected 768), skipping",
|
||||
message.embedding.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Serialize embedding to bytes
|
||||
let embedding_bytes = Self::serialize_vector(&message.embedding);
|
||||
|
||||
// Insert into message_embeddings table with BLOB
|
||||
// Use INSERT OR IGNORE to skip duplicates (based on UNIQUE constraint)
|
||||
match diesel::sql_query(
|
||||
"INSERT OR IGNORE INTO message_embeddings (contact, body, timestamp, is_sent, embedding, created_at, model_version)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.contact)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.body)
|
||||
.bind::<diesel::sql_types::BigInt, _>(message.timestamp)
|
||||
.bind::<diesel::sql_types::Bool, _>(message.is_sent)
|
||||
.bind::<diesel::sql_types::Binary, _>(&embedding_bytes)
|
||||
.bind::<diesel::sql_types::BigInt, _>(message.created_at)
|
||||
.bind::<diesel::sql_types::Text, _>(&message.model_version)
|
||||
.execute(conn)
|
||||
{
|
||||
Ok(rows) if rows > 0 => stored_count += 1,
|
||||
Ok(_) => {
|
||||
// INSERT OR IGNORE skipped (duplicate)
|
||||
log::debug!("Skipped duplicate message: {:?}", message.body.chars().take(50).collect::<String>());
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to insert message in batch: {:?}", e);
|
||||
// Continue with other messages instead of failing entire batch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stored_count)
|
||||
})
|
||||
.map_err(|e| anyhow::anyhow!("Transaction error: {:?}", e))
|
||||
})
|
||||
.map_err(|_| DbError::new(DbErrorKind::InsertError))
|
||||
}
|
||||
|
||||
fn find_similar_messages(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
contact_filter: Option<&str>,
|
||||
) -> Result<Vec<MessageEmbedding>, DbError> {
|
||||
trace_db_call(context, "query", "find_similar_messages", |_span| {
|
||||
let mut conn = self.connection.lock().expect("Unable to get EmbeddingDao");
|
||||
|
||||
// Validate embedding dimensions
|
||||
if query_embedding.len() != 768 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid query embedding dimensions: {} (expected 768)",
|
||||
query_embedding.len()
|
||||
));
|
||||
}
|
||||
|
||||
// Load messages with optional contact filter
|
||||
let results = if let Some(contact) = contact_filter {
|
||||
log::debug!("RAG search filtered to contact: {}", contact);
|
||||
diesel::sql_query(
|
||||
"SELECT id, contact, body, timestamp, is_sent, embedding, created_at, model_version
|
||||
FROM message_embeddings WHERE contact = ?1"
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(contact)
|
||||
.load::<MessageEmbeddingWithVectorRow>(conn.deref_mut())
|
||||
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?
|
||||
} else {
|
||||
log::debug!("RAG search across ALL contacts (cross-perspective)");
|
||||
diesel::sql_query(
|
||||
"SELECT id, contact, body, timestamp, is_sent, embedding, created_at, model_version
|
||||
FROM message_embeddings"
|
||||
)
|
||||
.load::<MessageEmbeddingWithVectorRow>(conn.deref_mut())
|
||||
.map_err(|e| anyhow::anyhow!("Query error: {:?}", e))?
|
||||
};
|
||||
|
||||
log::debug!("Loaded {} messages for similarity comparison", results.len());
|
||||
|
||||
// Compute similarity for each message
|
||||
let mut scored_messages: Vec<(f32, MessageEmbedding)> = results
|
||||
.into_iter()
|
||||
.filter_map(|row| {
|
||||
// Deserialize the embedding BLOB
|
||||
match Self::deserialize_vector(&row.embedding) {
|
||||
Ok(embedding) => {
|
||||
// Compute cosine similarity
|
||||
let similarity = Self::cosine_similarity(query_embedding, &embedding);
|
||||
Some((
|
||||
similarity,
|
||||
MessageEmbedding {
|
||||
id: row.id,
|
||||
contact: row.contact,
|
||||
body: row.body,
|
||||
timestamp: row.timestamp,
|
||||
is_sent: row.is_sent,
|
||||
created_at: row.created_at,
|
||||
model_version: row.model_version,
|
||||
},
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to deserialize embedding for message {}: {:?}", row.id, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by similarity (highest first)
|
||||
scored_messages.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Log similarity score distribution
|
||||
if !scored_messages.is_empty() {
|
||||
log::info!(
|
||||
"Similarity score distribution - Top: {:.3}, Median: {:.3}, Bottom: {:.3}",
|
||||
scored_messages.first().map(|(s, _)| *s).unwrap_or(0.0),
|
||||
scored_messages.get(scored_messages.len() / 2).map(|(s, _)| *s).unwrap_or(0.0),
|
||||
scored_messages.last().map(|(s, _)| *s).unwrap_or(0.0)
|
||||
);
|
||||
}
|
||||
|
||||
// Apply minimum similarity threshold
|
||||
// With single-contact embeddings, scores tend to be higher due to writing style similarity
|
||||
// Using 0.65 to get only truly semantically relevant messages
|
||||
let min_similarity = 0.65;
|
||||
let filtered_messages: Vec<(f32, MessageEmbedding)> = scored_messages
|
||||
.into_iter()
|
||||
.filter(|(similarity, _)| *similarity >= min_similarity)
|
||||
.collect();
|
||||
|
||||
log::info!(
|
||||
"After similarity filtering (min_similarity={}): {} messages passed threshold",
|
||||
min_similarity,
|
||||
filtered_messages.len()
|
||||
);
|
||||
|
||||
// Filter out short/generic messages (under 30 characters)
|
||||
// This removes conversational closings like "Thanks for talking" that dominate results
|
||||
let min_message_length = 30;
|
||||
|
||||
// Common closing phrases that should be excluded from RAG results
|
||||
let stop_phrases = [
|
||||
"thanks for talking",
|
||||
"thank you for talking",
|
||||
"good talking",
|
||||
"nice talking",
|
||||
"good night",
|
||||
"good morning",
|
||||
"love you",
|
||||
];
|
||||
|
||||
let filtered_messages: Vec<(f32, MessageEmbedding)> = filtered_messages
|
||||
.into_iter()
|
||||
.filter(|(_, message)| {
|
||||
// Filter by length
|
||||
if message.body.len() < min_message_length {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Filter out messages that are primarily generic closings
|
||||
let body_lower = message.body.to_lowercase();
|
||||
for phrase in &stop_phrases {
|
||||
// If the message contains this phrase and is short, it's likely just a closing
|
||||
if body_lower.contains(phrase) && message.body.len() < 100 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
log::info!(
|
||||
"After length filtering (min {} chars): {} messages remain",
|
||||
min_message_length,
|
||||
filtered_messages.len()
|
||||
);
|
||||
|
||||
// Apply temporal diversity filter - don't return too many messages from the same day
|
||||
// This prevents RAG from returning clusters of messages from one conversation
|
||||
let mut filtered_with_diversity = Vec::new();
|
||||
let mut dates_seen: std::collections::HashMap<chrono::NaiveDate, usize> = std::collections::HashMap::new();
|
||||
let max_per_day = 3; // Maximum 3 messages from any single day
|
||||
|
||||
for (similarity, message) in filtered_messages.into_iter() {
|
||||
let date = chrono::DateTime::from_timestamp(message.timestamp, 0)
|
||||
.map(|dt| dt.date_naive())
|
||||
.unwrap_or_else(|| chrono::Utc::now().date_naive());
|
||||
|
||||
let count = dates_seen.entry(date).or_insert(0);
|
||||
if *count < max_per_day {
|
||||
*count += 1;
|
||||
filtered_with_diversity.push((similarity, message));
|
||||
}
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"After temporal diversity filtering (max {} per day): {} messages remain",
|
||||
max_per_day,
|
||||
filtered_with_diversity.len()
|
||||
);
|
||||
|
||||
// Take top N results from diversity-filtered messages
|
||||
let top_results: Vec<MessageEmbedding> = filtered_with_diversity
|
||||
.into_iter()
|
||||
.take(limit)
|
||||
.map(|(similarity, message)| {
|
||||
let time = chrono::DateTime::from_timestamp(message.timestamp, 0)
|
||||
.map(|dt| dt.format("%Y-%m-%d").to_string())
|
||||
.unwrap_or_default();
|
||||
log::info!(
|
||||
"RAG Match: similarity={:.3}, date={}, contact={}, body=\"{}\"",
|
||||
similarity,
|
||||
time,
|
||||
message.contact,
|
||||
&message.body.chars().take(80).collect::<String>()
|
||||
);
|
||||
message
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(top_results)
|
||||
})
|
||||
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
||||
}
|
||||
|
||||
fn get_message_count(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
) -> Result<i64, DbError> {
|
||||
trace_db_call(context, "query", "get_message_count", |_span| {
|
||||
let mut conn = self.connection.lock().expect("Unable to get EmbeddingDao");
|
||||
|
||||
let count = diesel::sql_query(
|
||||
"SELECT COUNT(*) as count FROM message_embeddings WHERE contact = ?1",
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(contact)
|
||||
.get_result::<CountResult>(conn.deref_mut())
|
||||
.map(|r| r.count)
|
||||
.map_err(|e| anyhow::anyhow!("Count query error: {:?}", e))?;
|
||||
|
||||
Ok(count)
|
||||
})
|
||||
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
||||
}
|
||||
|
||||
fn has_embeddings_for_contact(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
) -> Result<bool, DbError> {
|
||||
self.get_message_count(context, contact)
|
||||
.map(|count| count > 0)
|
||||
}
|
||||
|
||||
fn message_exists(
|
||||
&mut self,
|
||||
context: &opentelemetry::Context,
|
||||
contact: &str,
|
||||
body: &str,
|
||||
timestamp: i64,
|
||||
) -> Result<bool, DbError> {
|
||||
trace_db_call(context, "query", "message_exists", |_span| {
|
||||
let mut conn = self.connection.lock().expect("Unable to get EmbeddingDao");
|
||||
|
||||
let count = diesel::sql_query(
|
||||
"SELECT COUNT(*) as count FROM message_embeddings
|
||||
WHERE contact = ?1 AND body = ?2 AND timestamp = ?3",
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(contact)
|
||||
.bind::<diesel::sql_types::Text, _>(body)
|
||||
.bind::<diesel::sql_types::BigInt, _>(timestamp)
|
||||
.get_result::<CountResult>(conn.deref_mut())
|
||||
.map(|r| r.count)
|
||||
.map_err(|e| anyhow::anyhow!("Count query error: {:?}", e))?;
|
||||
|
||||
Ok(count > 0)
|
||||
})
|
||||
.map_err(|_| DbError::new(DbErrorKind::QueryError))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper structs for raw SQL queries
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct LastInsertRowId {
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
id: i64,
|
||||
}
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct MessageEmbeddingRow {
|
||||
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||||
id: i32,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
contact: String,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
body: String,
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
timestamp: i64,
|
||||
#[diesel(sql_type = diesel::sql_types::Bool)]
|
||||
is_sent: bool,
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
created_at: i64,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
model_version: String,
|
||||
}
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct MessageEmbeddingWithVectorRow {
|
||||
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||||
id: i32,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
contact: String,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
body: String,
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
timestamp: i64,
|
||||
#[diesel(sql_type = diesel::sql_types::Bool)]
|
||||
is_sent: bool,
|
||||
#[diesel(sql_type = diesel::sql_types::Binary)]
|
||||
embedding: Vec<u8>,
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
created_at: i64,
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
model_version: String,
|
||||
}
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct CountResult {
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
count: i64,
|
||||
}
|
||||
@@ -214,12 +214,13 @@ impl LocationHistoryDao for SqliteLocationHistoryDao {
|
||||
|
||||
// Validate embedding dimensions if provided (rare for location data)
|
||||
if let Some(ref emb) = location.embedding
|
||||
&& emb.len() != 768 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid embedding dimensions: {} (expected 768)",
|
||||
emb.len()
|
||||
));
|
||||
}
|
||||
&& emb.len() != 768
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid embedding dimensions: {} (expected 768)",
|
||||
emb.len()
|
||||
));
|
||||
}
|
||||
|
||||
let embedding_bytes = location
|
||||
.embedding
|
||||
@@ -289,13 +290,14 @@ impl LocationHistoryDao for SqliteLocationHistoryDao {
|
||||
for location in locations {
|
||||
// Validate embedding if provided (rare)
|
||||
if let Some(ref emb) = location.embedding
|
||||
&& emb.len() != 768 {
|
||||
log::warn!(
|
||||
"Skipping location with invalid embedding dimensions: {}",
|
||||
emb.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
&& emb.len() != 768
|
||||
{
|
||||
log::warn!(
|
||||
"Skipping location with invalid embedding dimensions: {}",
|
||||
emb.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let embedding_bytes = location
|
||||
.embedding
|
||||
|
||||
@@ -11,7 +11,6 @@ use crate::otel::trace_db_call;
|
||||
|
||||
pub mod calendar_dao;
|
||||
pub mod daily_summary_dao;
|
||||
pub mod embeddings_dao;
|
||||
pub mod insights_dao;
|
||||
pub mod location_dao;
|
||||
pub mod models;
|
||||
@@ -20,7 +19,6 @@ pub mod search_dao;
|
||||
|
||||
pub use calendar_dao::{CalendarEventDao, SqliteCalendarEventDao};
|
||||
pub use daily_summary_dao::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao};
|
||||
pub use embeddings_dao::{EmbeddingDao, InsertMessageEmbedding};
|
||||
pub use insights_dao::{InsightDao, SqliteInsightDao};
|
||||
pub use location_dao::{LocationHistoryDao, SqliteLocationHistoryDao};
|
||||
pub use search_dao::{SearchHistoryDao, SqliteSearchHistoryDao};
|
||||
|
||||
Reference in New Issue
Block a user