//! Diagnose embedding distribution and identify problematic summaries.
use anyhow::Result;
use clap::Parser;
use diesel::prelude::*;
use diesel::sql_query;
use diesel::sqlite::SqliteConnection;
use std::env;
/// Command-line options for the embedding-diagnostics tool, parsed by clap.
#[derive(Parser, Debug)]
#[command(author, version, about = "Diagnose embedding distribution and identify problematic summaries", long_about = None)]
struct Args {
    /// Show detailed per-summary statistics
    // NOTE(review): not currently read anywhere in this file — confirm intended use.
    #[arg(short, long, default_value_t = false)]
    verbose: bool,

    /// Number of top "central" summaries to show (ones that match everything)
    #[arg(short, long, default_value_t = 10)]
    top: usize,

    /// Test a specific query to see what matches
    // NOTE(review): not currently read anywhere in this file — confirm intended use.
    #[arg(short, long)]
    query: Option<String>,
}
/// One row of `daily_conversation_summaries`, loaded via raw SQL
/// (`QueryableByName` maps columns by the annotated SQL types).
#[derive(QueryableByName, Debug)]
struct EmbeddingRow {
    #[diesel(sql_type = diesel::sql_types::Integer)]
    id: i32,
    #[diesel(sql_type = diesel::sql_types::Text)]
    date: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    contact: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    summary: String,
    // Raw BLOB: consecutive little-endian f32 values; decoded by
    // `deserialize_embedding`, so the length must be a multiple of 4.
    #[diesel(sql_type = diesel::sql_types::Binary)]
    embedding: Vec<u8>,
}
fn deserialize_embedding(bytes: &[u8]) -> Result<Vec<f32>> {
|
|
if !bytes.len().is_multiple_of(4) {
|
|
return Err(anyhow::anyhow!("Invalid embedding byte length"));
|
|
}
|
|
|
|
let count = bytes.len() / 4;
|
|
let mut vec = Vec::with_capacity(count);
|
|
|
|
for chunk in bytes.chunks_exact(4) {
|
|
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
|
vec.push(float);
|
|
}
|
|
|
|
Ok(vec)
|
|
}
|
|
|
|
/// Cosine similarity between two vectors, in `[-1, 1]`.
///
/// Returns `0.0` for mismatched lengths or when either vector has zero
/// magnitude, so degenerate inputs never yield NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and both squared norms in a single pass; the
    // left-to-right accumulation order matches summing each separately.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let magnitude_a = norm_a_sq.sqrt();
    let magnitude_b = norm_b_sq.sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }

    dot / (magnitude_a * magnitude_b)
}
fn main() -> Result<()> {
|
|
dotenv::dotenv().ok();
|
|
let args = Args::parse();
|
|
|
|
let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string());
|
|
println!("Connecting to database: {}", database_url);
|
|
|
|
let mut conn = SqliteConnection::establish(&database_url)?;
|
|
|
|
// Load all embeddings
|
|
println!("\nLoading embeddings from daily_conversation_summaries...");
|
|
let rows: Vec<EmbeddingRow> = sql_query(
|
|
"SELECT id, date, contact, summary, embedding FROM daily_conversation_summaries ORDER BY date"
|
|
)
|
|
.load(&mut conn)?;
|
|
|
|
println!("Found {} summaries with embeddings\n", rows.len());
|
|
|
|
if rows.is_empty() {
|
|
println!("No summaries found!");
|
|
return Ok(());
|
|
}
|
|
|
|
// Parse all embeddings
|
|
let mut embeddings: Vec<(i32, String, String, String, Vec<f32>)> = Vec::new();
|
|
for row in &rows {
|
|
match deserialize_embedding(&row.embedding) {
|
|
Ok(emb) => {
|
|
embeddings.push((
|
|
row.id,
|
|
row.date.clone(),
|
|
row.contact.clone(),
|
|
row.summary.clone(),
|
|
emb,
|
|
));
|
|
}
|
|
Err(e) => {
|
|
println!(
|
|
"Warning: Failed to parse embedding for id {}: {}",
|
|
row.id, e
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
println!("Successfully parsed {} embeddings\n", embeddings.len());
|
|
|
|
// Compute embedding statistics
|
|
println!("========================================");
|
|
println!("EMBEDDING STATISTICS");
|
|
println!("========================================\n");
|
|
|
|
// Check embedding variance (are values clustered or spread out?)
|
|
let first_emb = &embeddings[0].4;
|
|
let dim = first_emb.len();
|
|
println!("Embedding dimensions: {}", dim);
|
|
|
|
// Calculate mean and std dev per dimension
|
|
let mut dim_means: Vec<f32> = vec![0.0; dim];
|
|
let mut dim_vars: Vec<f32> = vec![0.0; dim];
|
|
|
|
for (_, _, _, _, emb) in &embeddings {
|
|
for (i, &val) in emb.iter().enumerate() {
|
|
dim_means[i] += val;
|
|
}
|
|
}
|
|
for m in &mut dim_means {
|
|
*m /= embeddings.len() as f32;
|
|
}
|
|
|
|
for (_, _, _, _, emb) in &embeddings {
|
|
for (i, &val) in emb.iter().enumerate() {
|
|
let diff = val - dim_means[i];
|
|
dim_vars[i] += diff * diff;
|
|
}
|
|
}
|
|
for v in &mut dim_vars {
|
|
*v = (*v / embeddings.len() as f32).sqrt();
|
|
}
|
|
|
|
let avg_std_dev: f32 = dim_vars.iter().sum::<f32>() / dim as f32;
|
|
let min_std_dev: f32 = dim_vars.iter().cloned().fold(f32::INFINITY, f32::min);
|
|
let max_std_dev: f32 = dim_vars.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
|
|
|
println!("Per-dimension standard deviation:");
|
|
println!(" Average: {:.6}", avg_std_dev);
|
|
println!(" Min: {:.6}", min_std_dev);
|
|
println!(" Max: {:.6}", max_std_dev);
|
|
println!();
|
|
|
|
// Compute pairwise similarities
|
|
println!("Computing pairwise similarities (this may take a moment)...\n");
|
|
|
|
let mut all_similarities: Vec<f32> = Vec::new();
|
|
let mut per_embedding_avg: Vec<(usize, f32)> = Vec::new();
|
|
|
|
for i in 0..embeddings.len() {
|
|
let mut sum = 0.0;
|
|
let mut count = 0;
|
|
for j in 0..embeddings.len() {
|
|
if i != j {
|
|
let sim = cosine_similarity(&embeddings[i].4, &embeddings[j].4);
|
|
all_similarities.push(sim);
|
|
sum += sim;
|
|
count += 1;
|
|
}
|
|
}
|
|
per_embedding_avg.push((i, sum / count as f32));
|
|
}
|
|
|
|
// Sort similarities for percentile analysis
|
|
all_similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
|
|
let min_sim = all_similarities.first().copied().unwrap_or(0.0);
|
|
let max_sim = all_similarities.last().copied().unwrap_or(0.0);
|
|
let median_sim = all_similarities[all_similarities.len() / 2];
|
|
let p25 = all_similarities[all_similarities.len() / 4];
|
|
let p75 = all_similarities[3 * all_similarities.len() / 4];
|
|
let mean_sim: f32 = all_similarities.iter().sum::<f32>() / all_similarities.len() as f32;
|
|
|
|
println!("========================================");
|
|
println!("PAIRWISE SIMILARITY DISTRIBUTION");
|
|
println!("========================================\n");
|
|
println!("Total pairs analyzed: {}", all_similarities.len());
|
|
println!();
|
|
println!("Min similarity: {:.4}", min_sim);
|
|
println!("25th percentile: {:.4}", p25);
|
|
println!("Median similarity: {:.4}", median_sim);
|
|
println!("Mean similarity: {:.4}", mean_sim);
|
|
println!("75th percentile: {:.4}", p75);
|
|
println!("Max similarity: {:.4}", max_sim);
|
|
println!();
|
|
|
|
// Analyze distribution
|
|
let count_above_08 = all_similarities.iter().filter(|&&s| s > 0.8).count();
|
|
let count_above_07 = all_similarities.iter().filter(|&&s| s > 0.7).count();
|
|
let count_above_06 = all_similarities.iter().filter(|&&s| s > 0.6).count();
|
|
let count_above_05 = all_similarities.iter().filter(|&&s| s > 0.5).count();
|
|
let count_below_03 = all_similarities.iter().filter(|&&s| s < 0.3).count();
|
|
|
|
println!("Similarity distribution:");
|
|
println!(
|
|
" > 0.8: {} ({:.1}%)",
|
|
count_above_08,
|
|
100.0 * count_above_08 as f32 / all_similarities.len() as f32
|
|
);
|
|
println!(
|
|
" > 0.7: {} ({:.1}%)",
|
|
count_above_07,
|
|
100.0 * count_above_07 as f32 / all_similarities.len() as f32
|
|
);
|
|
println!(
|
|
" > 0.6: {} ({:.1}%)",
|
|
count_above_06,
|
|
100.0 * count_above_06 as f32 / all_similarities.len() as f32
|
|
);
|
|
println!(
|
|
" > 0.5: {} ({:.1}%)",
|
|
count_above_05,
|
|
100.0 * count_above_05 as f32 / all_similarities.len() as f32
|
|
);
|
|
println!(
|
|
" < 0.3: {} ({:.1}%)",
|
|
count_below_03,
|
|
100.0 * count_below_03 as f32 / all_similarities.len() as f32
|
|
);
|
|
println!();
|
|
|
|
// Identify "central" embeddings (high average similarity to all others)
|
|
per_embedding_avg.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
|
|
|
println!("========================================");
|
|
println!("TOP {} MOST 'CENTRAL' SUMMARIES", args.top);
|
|
println!("(These match everything with high similarity)");
|
|
println!("========================================\n");
|
|
|
|
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().take(args.top).enumerate() {
|
|
let (id, date, contact, summary, _) = &embeddings[*idx];
|
|
let preview: String = summary.chars().take(80).collect();
|
|
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
|
|
println!(" Date: {}, Contact: {}", date, contact);
|
|
println!(" Preview: {}...", preview.replace('\n', " "));
|
|
println!();
|
|
}
|
|
|
|
// Also show the least central (most unique)
|
|
println!("========================================");
|
|
println!("TOP {} MOST UNIQUE SUMMARIES", args.top);
|
|
println!("(These are most different from others)");
|
|
println!("========================================\n");
|
|
|
|
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().rev().take(args.top).enumerate() {
|
|
let (id, date, contact, summary, _) = &embeddings[*idx];
|
|
let preview: String = summary.chars().take(80).collect();
|
|
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
|
|
println!(" Date: {}, Contact: {}", date, contact);
|
|
println!(" Preview: {}...", preview.replace('\n', " "));
|
|
println!();
|
|
}
|
|
|
|
// Diagnosis
|
|
println!("========================================");
|
|
println!("DIAGNOSIS");
|
|
println!("========================================\n");
|
|
|
|
if mean_sim > 0.7 {
|
|
println!("⚠️ HIGH AVERAGE SIMILARITY ({:.4})", mean_sim);
|
|
println!(" All embeddings are very similar to each other.");
|
|
println!(" This explains why the same summaries always match.");
|
|
println!();
|
|
println!(" Possible causes:");
|
|
println!(
|
|
" 1. Summaries have similar structure/phrasing (e.g., all start with 'Summary:')"
|
|
);
|
|
println!(" 2. Embedding model isn't capturing semantic differences well");
|
|
println!(" 3. Daily conversations have similar topics (e.g., 'good morning', plans)");
|
|
println!();
|
|
println!(" Recommendations:");
|
|
println!(" 1. Try a different embedding model (mxbai-embed-large, bge-large)");
|
|
println!(" 2. Improve summary diversity by varying the prompt");
|
|
println!(" 3. Extract and embed only keywords/entities, not full summaries");
|
|
} else if mean_sim > 0.5 {
|
|
println!("⚡ MODERATE AVERAGE SIMILARITY ({:.4})", mean_sim);
|
|
println!(" Some clustering in embeddings, but some differentiation exists.");
|
|
println!();
|
|
println!(" The 'central' summaries above are likely dominating search results.");
|
|
println!(" Consider:");
|
|
println!(" 1. Filtering out summaries with very high centrality");
|
|
println!(" 2. Adding time-based weighting to prefer recent/relevant dates");
|
|
println!(" 3. Increasing the similarity threshold from 0.3 to 0.5");
|
|
} else {
|
|
println!("✅ GOOD EMBEDDING DIVERSITY ({:.4})", mean_sim);
|
|
println!(" Embeddings are well-differentiated.");
|
|
println!(" If same results keep appearing, the issue may be elsewhere.");
|
|
}
|
|
|
|
Ok(())
|
|
}
|