use anyhow::Result;
use clap::Parser;
use diesel::prelude::*;
use diesel::sql_query;
use diesel::sqlite::SqliteConnection;
use std::env;

#[derive(Parser, Debug)]
#[command(author, version, about = "Diagnose embedding distribution and identify problematic summaries", long_about = None)]
struct Args {
    /// Show detailed per-summary statistics
    #[arg(short, long, default_value_t = false)]
    verbose: bool,

    /// Number of top "central" summaries to show (ones that match everything)
    #[arg(short, long, default_value_t = 10)]
    top: usize,

    /// Test a specific query to see what matches
    #[arg(short, long)]
    query: Option<String>,
}

#[derive(QueryableByName, Debug)]
struct EmbeddingRow {
    #[diesel(sql_type = diesel::sql_types::Integer)]
    id: i32,
    #[diesel(sql_type = diesel::sql_types::Text)]
    date: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    contact: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    summary: String,
    #[diesel(sql_type = diesel::sql_types::Binary)]
    embedding: Vec<u8>,
}

/// Decode an embedding stored as a flat sequence of little-endian f32 values.
fn deserialize_embedding(bytes: &[u8]) -> Result<Vec<f32>> {
    if !bytes.len().is_multiple_of(4) {
        return Err(anyhow::anyhow!("Invalid embedding byte length"));
    }
    let count = bytes.len() / 4;
    let mut vec = Vec::with_capacity(count);
    for chunk in bytes.chunks_exact(4) {
        let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        vec.push(float);
    }
    Ok(vec)
}

/// Cosine similarity of two vectors; returns 0.0 on a length mismatch or a
/// zero-magnitude input rather than propagating an error.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }
    dot_product / (magnitude_a * magnitude_b)
}

fn main() -> Result<()> {
    dotenv::dotenv().ok();
    let args = Args::parse();
    // NOTE: `verbose` and `query` are parsed but not yet used below; only
    // `top` currently affects the output.

    let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string());
    println!("Connecting to database: {}", database_url);
    let mut conn = SqliteConnection::establish(&database_url)?;

    // Load all embeddings
    println!("\nLoading embeddings from daily_conversation_summaries...");
    let rows: Vec<EmbeddingRow> = sql_query(
        "SELECT id, date, contact, summary, embedding FROM daily_conversation_summaries ORDER BY date",
    )
    .load(&mut conn)?;

    println!("Found {} summaries with embeddings\n", rows.len());
    if rows.is_empty() {
        println!("No summaries found!");
        return Ok(());
    }

    // Parse all embeddings
    let mut embeddings: Vec<(i32, String, String, String, Vec<f32>)> = Vec::new();
    for row in &rows {
        match deserialize_embedding(&row.embedding) {
            Ok(emb) => {
                embeddings.push((
                    row.id,
                    row.date.clone(),
                    row.contact.clone(),
                    row.summary.clone(),
                    emb,
                ));
            }
            Err(e) => {
                println!(
                    "Warning: Failed to parse embedding for id {}: {}",
                    row.id, e
                );
            }
        }
    }
    println!("Successfully parsed {} embeddings\n", embeddings.len());

    // The pairwise statistics below need at least two embeddings; bail out
    // early instead of dividing by zero or indexing an empty list.
    if embeddings.len() < 2 {
        println!("Need at least two parsed embeddings to analyze.");
        return Ok(());
    }

    // Compute embedding statistics
    println!("========================================");
    println!("EMBEDDING STATISTICS");
    println!("========================================\n");

    // Check embedding variance (are values clustered or spread out?)
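    // If the average per-dimension standard deviation comes out near zero, the
    // model is mapping every summary to nearly the same point, and cosine
    // similarity can no longer separate relevant rows from irrelevant ones.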
    let first_emb = &embeddings[0].4;
    let dim = first_emb.len();
    println!("Embedding dimensions: {}", dim);

    // Calculate mean and standard deviation per dimension
    let mut dim_means: Vec<f32> = vec![0.0; dim];
    let mut dim_std_devs: Vec<f32> = vec![0.0; dim];

    for (_, _, _, _, emb) in &embeddings {
        for (i, &val) in emb.iter().enumerate() {
            dim_means[i] += val;
        }
    }
    for m in &mut dim_means {
        *m /= embeddings.len() as f32;
    }

    // Accumulate squared deviations, then convert them to standard deviations.
    for (_, _, _, _, emb) in &embeddings {
        for (i, &val) in emb.iter().enumerate() {
            let diff = val - dim_means[i];
            dim_std_devs[i] += diff * diff;
        }
    }
    for v in &mut dim_std_devs {
        *v = (*v / embeddings.len() as f32).sqrt();
    }

    let avg_std_dev: f32 = dim_std_devs.iter().sum::<f32>() / dim as f32;
    let min_std_dev: f32 = dim_std_devs.iter().cloned().fold(f32::INFINITY, f32::min);
    let max_std_dev: f32 = dim_std_devs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

    println!("Per-dimension standard deviation:");
    println!("  Average: {:.6}", avg_std_dev);
    println!("  Min:     {:.6}", min_std_dev);
    println!("  Max:     {:.6}", max_std_dev);
    println!();

    // Compute pairwise similarities
    println!("Computing pairwise similarities (this may take a moment)...\n");

    let mut all_similarities: Vec<f32> = Vec::new();
    let mut per_embedding_avg: Vec<(usize, f32)> = Vec::new();

    for i in 0..embeddings.len() {
        let mut sum = 0.0;
        let mut count = 0;
        for j in 0..embeddings.len() {
            if i != j {
                let sim = cosine_similarity(&embeddings[i].4, &embeddings[j].4);
                all_similarities.push(sim);
                sum += sim;
                count += 1;
            }
        }
        // `count` is at least 1 here thanks to the early-exit guard above.
        per_embedding_avg.push((i, sum / count as f32));
    }

    // Sort similarities for percentile analysis
    all_similarities.sort_by(|a, b| a.total_cmp(b));

    let min_sim = all_similarities.first().copied().unwrap_or(0.0);
    let max_sim = all_similarities.last().copied().unwrap_or(0.0);
    let median_sim = all_similarities[all_similarities.len() / 2];
    let p25 = all_similarities[all_similarities.len() / 4];
    let p75 = all_similarities[3 * all_similarities.len() / 4];
    let mean_sim: f32 = all_similarities.iter().sum::<f32>() / all_similarities.len() as f32;

    println!("========================================");
    println!("PAIRWISE SIMILARITY DISTRIBUTION");
    println!("========================================\n");
    println!("Total pairs analyzed: {}", all_similarities.len());
    println!();
    println!("Min similarity:    {:.4}", min_sim);
    println!("25th percentile:   {:.4}", p25);
    println!("Median similarity: {:.4}", median_sim);
    println!("Mean similarity:   {:.4}", mean_sim);
    println!("75th percentile:   {:.4}", p75);
    println!("Max similarity:    {:.4}", max_sim);
    println!();

    // Analyze distribution
    let count_above_08 = all_similarities.iter().filter(|&&s| s > 0.8).count();
    let count_above_07 = all_similarities.iter().filter(|&&s| s > 0.7).count();
    let count_above_06 = all_similarities.iter().filter(|&&s| s > 0.6).count();
    let count_above_05 = all_similarities.iter().filter(|&&s| s > 0.5).count();
    let count_below_03 = all_similarities.iter().filter(|&&s| s < 0.3).count();

    println!("Similarity distribution:");
    println!(
        "  > 0.8: {} ({:.1}%)",
        count_above_08,
        100.0 * count_above_08 as f32 / all_similarities.len() as f32
    );
    println!(
        "  > 0.7: {} ({:.1}%)",
        count_above_07,
        100.0 * count_above_07 as f32 / all_similarities.len() as f32
    );
    println!(
        "  > 0.6: {} ({:.1}%)",
        count_above_06,
        100.0 * count_above_06 as f32 / all_similarities.len() as f32
    );
    println!(
        "  > 0.5: {} ({:.1}%)",
        count_above_05,
        100.0 * count_above_05 as f32 / all_similarities.len() as f32
    );
    println!(
        "  < 0.3: {} ({:.1}%)",
        count_below_03,
        100.0 * count_below_03 as f32 / all_similarities.len() as f32
    );
    println!();
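    // Both the pairwise pass above and the ranking below are O(n^2) in the
    // number of summaries. That is fine for a one-off diagnostic, but worth
    // keeping in mind before pointing this tool at a much larger table.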
    // Identify "central" embeddings (high average similarity to all others)
    per_embedding_avg.sort_by(|a, b| b.1.total_cmp(&a.1));

    println!("========================================");
    println!("TOP {} MOST 'CENTRAL' SUMMARIES", args.top);
    println!("(These match everything with high similarity)");
    println!("========================================\n");

    for (rank, (idx, avg_sim)) in per_embedding_avg.iter().take(args.top).enumerate() {
        let (id, date, contact, summary, _) = &embeddings[*idx];
        let preview: String = summary.chars().take(80).collect();
        println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
        println!("   Date: {}, Contact: {}", date, contact);
        println!("   Preview: {}...", preview.replace('\n', " "));
        println!();
    }

    // Also show the least central (most unique)
    println!("========================================");
    println!("TOP {} MOST UNIQUE SUMMARIES", args.top);
    println!("(These are most different from others)");
    println!("========================================\n");

    for (rank, (idx, avg_sim)) in per_embedding_avg.iter().rev().take(args.top).enumerate() {
        let (id, date, contact, summary, _) = &embeddings[*idx];
        let preview: String = summary.chars().take(80).collect();
        println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
        println!("   Date: {}, Contact: {}", date, contact);
        println!("   Preview: {}...", preview.replace('\n', " "));
        println!();
    }

    // Diagnosis
    println!("========================================");
    println!("DIAGNOSIS");
    println!("========================================\n");

    if mean_sim > 0.7 {
        println!("⚠️ HIGH AVERAGE SIMILARITY ({:.4})", mean_sim);
        println!("   All embeddings are very similar to each other.");
        println!("   This explains why the same summaries always match.");
        println!();
        println!("   Possible causes:");
        println!("   1. Summaries have similar structure/phrasing (e.g., all start with 'Summary:')");
        println!("   2. Embedding model isn't capturing semantic differences well");
        println!("   3. Daily conversations have similar topics (e.g., 'good morning', plans)");
        println!();
        println!("   Recommendations:");
        println!("   1. Try a different embedding model (mxbai-embed-large, bge-large)");
        println!("   2. Improve summary diversity by varying the prompt");
        println!("   3. Extract and embed only keywords/entities, not full summaries");
    } else if mean_sim > 0.5 {
        println!("⚡ MODERATE AVERAGE SIMILARITY ({:.4})", mean_sim);
        println!("   Some clustering in embeddings, but some differentiation exists.");
        println!();
        println!("   The 'central' summaries above are likely dominating search results.");
        println!("   Consider:");
        println!("   1. Filtering out summaries with very high centrality");
        println!("   2. Adding time-based weighting to prefer recent/relevant dates");
        println!("   3. Increasing the similarity threshold from 0.3 to 0.5");
    } else {
        println!("✅ GOOD EMBEDDING DIVERSITY ({:.4})", mean_sim);
        println!("   Embeddings are well-differentiated.");
        println!("   If the same results keep appearing, the issue may be elsewhere.");
    }

    Ok(())
}
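
// A minimal test sketch, not part of the original tool: it assumes only that
// the two pure helpers above keep their current signatures. Expected values
// follow directly from the definitions of little-endian f32 encoding and
// cosine similarity, so the asserts can compare floats exactly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserialize_round_trips_little_endian_f32() {
        let floats = [1.0f32, -0.5, 0.25];
        // Encode the floats exactly the way the database column stores them.
        let bytes: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
        assert_eq!(deserialize_embedding(&bytes).unwrap(), floats);
    }

    #[test]
    fn deserialize_rejects_misaligned_input() {
        // 5 bytes is not a whole number of f32 values, so this must fail.
        assert!(deserialize_embedding(&[0u8; 5]).is_err());
    }

    #[test]
    fn cosine_similarity_basic_cases() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert_eq!(cosine_similarity(&a, &a), 1.0); // identical vectors => 1
        assert_eq!(cosine_similarity(&a, &b), 0.0); // orthogonal vectors => 0
        assert_eq!(cosine_similarity(&a, &[]), 0.0); // length-mismatch guard
    }
}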