Daily Summary Embedding Testing

This commit is contained in:
Cameron
2026-01-08 13:41:32 -05:00
parent 61e10f7678
commit 084994e0b5
8 changed files with 1000 additions and 106 deletions

View File

@@ -0,0 +1,282 @@
use anyhow::Result;
use clap::Parser;
use diesel::prelude::*;
use diesel::sql_query;
use diesel::sqlite::SqliteConnection;
use std::env;
#[derive(Parser, Debug)]
#[command(author, version, about = "Diagnose embedding distribution and identify problematic summaries", long_about = None)]
struct Args {
/// Show detailed per-summary statistics
#[arg(short, long, default_value_t = false)]
verbose: bool,
/// Number of top "central" summaries to show (ones that match everything)
#[arg(short, long, default_value_t = 10)]
top: usize,
/// Test a specific query to see what matches
#[arg(short, long)]
query: Option<String>,
}
#[derive(QueryableByName, Debug)]
struct EmbeddingRow {
#[diesel(sql_type = diesel::sql_types::Integer)]
id: i32,
#[diesel(sql_type = diesel::sql_types::Text)]
date: String,
#[diesel(sql_type = diesel::sql_types::Text)]
contact: String,
#[diesel(sql_type = diesel::sql_types::Text)]
summary: String,
#[diesel(sql_type = diesel::sql_types::Binary)]
embedding: Vec<u8>,
}
fn deserialize_embedding(bytes: &[u8]) -> Result<Vec<f32>> {
if bytes.len() % 4 != 0 {
return Err(anyhow::anyhow!("Invalid embedding byte length"));
}
let count = bytes.len() / 4;
let mut vec = Vec::with_capacity(count);
for chunk in bytes.chunks_exact(4) {
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
vec.push(float);
}
Ok(vec)
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude_a == 0.0 || magnitude_b == 0.0 {
return 0.0;
}
dot_product / (magnitude_a * magnitude_b)
}
fn main() -> Result<()> {
dotenv::dotenv().ok();
let args = Args::parse();
let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string());
println!("Connecting to database: {}", database_url);
let mut conn = SqliteConnection::establish(&database_url)?;
// Load all embeddings
println!("\nLoading embeddings from daily_conversation_summaries...");
let rows: Vec<EmbeddingRow> = sql_query(
"SELECT id, date, contact, summary, embedding FROM daily_conversation_summaries ORDER BY date"
)
.load(&mut conn)?;
println!("Found {} summaries with embeddings\n", rows.len());
if rows.is_empty() {
println!("No summaries found!");
return Ok(());
}
// Parse all embeddings
let mut embeddings: Vec<(i32, String, String, String, Vec<f32>)> = Vec::new();
for row in &rows {
match deserialize_embedding(&row.embedding) {
Ok(emb) => {
embeddings.push((
row.id,
row.date.clone(),
row.contact.clone(),
row.summary.clone(),
emb,
));
}
Err(e) => {
println!("Warning: Failed to parse embedding for id {}: {}", row.id, e);
}
}
}
println!("Successfully parsed {} embeddings\n", embeddings.len());
// Compute embedding statistics
println!("========================================");
println!("EMBEDDING STATISTICS");
println!("========================================\n");
// Check embedding variance (are values clustered or spread out?)
let first_emb = &embeddings[0].4;
let dim = first_emb.len();
println!("Embedding dimensions: {}", dim);
// Calculate mean and std dev per dimension
let mut dim_means: Vec<f32> = vec![0.0; dim];
let mut dim_vars: Vec<f32> = vec![0.0; dim];
for (_, _, _, _, emb) in &embeddings {
for (i, &val) in emb.iter().enumerate() {
dim_means[i] += val;
}
}
for m in &mut dim_means {
*m /= embeddings.len() as f32;
}
for (_, _, _, _, emb) in &embeddings {
for (i, &val) in emb.iter().enumerate() {
let diff = val - dim_means[i];
dim_vars[i] += diff * diff;
}
}
for v in &mut dim_vars {
*v = (*v / embeddings.len() as f32).sqrt();
}
let avg_std_dev: f32 = dim_vars.iter().sum::<f32>() / dim as f32;
let min_std_dev: f32 = dim_vars.iter().cloned().fold(f32::INFINITY, f32::min);
let max_std_dev: f32 = dim_vars.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
println!("Per-dimension standard deviation:");
println!(" Average: {:.6}", avg_std_dev);
println!(" Min: {:.6}", min_std_dev);
println!(" Max: {:.6}", max_std_dev);
println!();
// Compute pairwise similarities
println!("Computing pairwise similarities (this may take a moment)...\n");
let mut all_similarities: Vec<f32> = Vec::new();
let mut per_embedding_avg: Vec<(usize, f32)> = Vec::new();
for i in 0..embeddings.len() {
let mut sum = 0.0;
let mut count = 0;
for j in 0..embeddings.len() {
if i != j {
let sim = cosine_similarity(&embeddings[i].4, &embeddings[j].4);
all_similarities.push(sim);
sum += sim;
count += 1;
}
}
per_embedding_avg.push((i, sum / count as f32));
}
// Sort similarities for percentile analysis
all_similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
let min_sim = all_similarities.first().copied().unwrap_or(0.0);
let max_sim = all_similarities.last().copied().unwrap_or(0.0);
let median_sim = all_similarities[all_similarities.len() / 2];
let p25 = all_similarities[all_similarities.len() / 4];
let p75 = all_similarities[3 * all_similarities.len() / 4];
let mean_sim: f32 = all_similarities.iter().sum::<f32>() / all_similarities.len() as f32;
println!("========================================");
println!("PAIRWISE SIMILARITY DISTRIBUTION");
println!("========================================\n");
println!("Total pairs analyzed: {}", all_similarities.len());
println!();
println!("Min similarity: {:.4}", min_sim);
println!("25th percentile: {:.4}", p25);
println!("Median similarity: {:.4}", median_sim);
println!("Mean similarity: {:.4}", mean_sim);
println!("75th percentile: {:.4}", p75);
println!("Max similarity: {:.4}", max_sim);
println!();
// Analyze distribution
let count_above_08 = all_similarities.iter().filter(|&&s| s > 0.8).count();
let count_above_07 = all_similarities.iter().filter(|&&s| s > 0.7).count();
let count_above_06 = all_similarities.iter().filter(|&&s| s > 0.6).count();
let count_above_05 = all_similarities.iter().filter(|&&s| s > 0.5).count();
let count_below_03 = all_similarities.iter().filter(|&&s| s < 0.3).count();
println!("Similarity distribution:");
println!(" > 0.8: {} ({:.1}%)", count_above_08, 100.0 * count_above_08 as f32 / all_similarities.len() as f32);
println!(" > 0.7: {} ({:.1}%)", count_above_07, 100.0 * count_above_07 as f32 / all_similarities.len() as f32);
println!(" > 0.6: {} ({:.1}%)", count_above_06, 100.0 * count_above_06 as f32 / all_similarities.len() as f32);
println!(" > 0.5: {} ({:.1}%)", count_above_05, 100.0 * count_above_05 as f32 / all_similarities.len() as f32);
println!(" < 0.3: {} ({:.1}%)", count_below_03, 100.0 * count_below_03 as f32 / all_similarities.len() as f32);
println!();
// Identify "central" embeddings (high average similarity to all others)
per_embedding_avg.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
println!("========================================");
println!("TOP {} MOST 'CENTRAL' SUMMARIES", args.top);
println!("(These match everything with high similarity)");
println!("========================================\n");
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().take(args.top).enumerate() {
let (id, date, contact, summary, _) = &embeddings[*idx];
let preview: String = summary.chars().take(80).collect();
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
println!(" Date: {}, Contact: {}", date, contact);
println!(" Preview: {}...", preview.replace('\n', " "));
println!();
}
// Also show the least central (most unique)
println!("========================================");
println!("TOP {} MOST UNIQUE SUMMARIES", args.top);
println!("(These are most different from others)");
println!("========================================\n");
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().rev().take(args.top).enumerate() {
let (id, date, contact, summary, _) = &embeddings[*idx];
let preview: String = summary.chars().take(80).collect();
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
println!(" Date: {}, Contact: {}", date, contact);
println!(" Preview: {}...", preview.replace('\n', " "));
println!();
}
// Diagnosis
println!("========================================");
println!("DIAGNOSIS");
println!("========================================\n");
if mean_sim > 0.7 {
println!("⚠️ HIGH AVERAGE SIMILARITY ({:.4})", mean_sim);
println!(" All embeddings are very similar to each other.");
println!(" This explains why the same summaries always match.");
println!();
println!(" Possible causes:");
println!(" 1. Summaries have similar structure/phrasing (e.g., all start with 'Summary:')");
println!(" 2. Embedding model isn't capturing semantic differences well");
println!(" 3. Daily conversations have similar topics (e.g., 'good morning', plans)");
println!();
println!(" Recommendations:");
println!(" 1. Try a different embedding model (mxbai-embed-large, bge-large)");
println!(" 2. Improve summary diversity by varying the prompt");
println!(" 3. Extract and embed only keywords/entities, not full summaries");
} else if mean_sim > 0.5 {
println!("⚡ MODERATE AVERAGE SIMILARITY ({:.4})", mean_sim);
println!(" Some clustering in embeddings, but some differentiation exists.");
println!();
println!(" The 'central' summaries above are likely dominating search results.");
println!(" Consider:");
println!(" 1. Filtering out summaries with very high centrality");
println!(" 2. Adding time-based weighting to prefer recent/relevant dates");
println!(" 3. Increasing the similarity threshold from 0.3 to 0.5");
} else {
println!("✅ GOOD EMBEDDING DIVERSITY ({:.4})", mean_sim);
println!(" Embeddings are well-differentiated.");
println!(" If same results keep appearing, the issue may be elsewhere.");
}
Ok(())
}