Daily Summary Embedding Testing

Author: Cameron
Date: 2026-01-08 13:41:32 -05:00
parent 61e10f7678
commit 084994e0b5
8 changed files with 1000 additions and 106 deletions


@@ -0,0 +1,282 @@
use anyhow::Result;
use clap::Parser;
use diesel::prelude::*;
use diesel::sql_query;
use diesel::sqlite::SqliteConnection;
use std::env;
#[derive(Parser, Debug)]
#[command(author, version, about = "Diagnose embedding distribution and identify problematic summaries", long_about = None)]
struct Args {
/// Show detailed per-summary statistics (accepted but not yet used by this tool)
#[arg(short, long, default_value_t = false)]
verbose: bool,
/// Number of top "central" summaries to show (ones that match everything)
#[arg(short, long, default_value_t = 10)]
top: usize,
/// Test a specific query to see what matches (accepted but not yet used by this tool)
#[arg(short, long)]
query: Option<String>,
}
#[derive(QueryableByName, Debug)]
struct EmbeddingRow {
#[diesel(sql_type = diesel::sql_types::Integer)]
id: i32,
#[diesel(sql_type = diesel::sql_types::Text)]
date: String,
#[diesel(sql_type = diesel::sql_types::Text)]
contact: String,
#[diesel(sql_type = diesel::sql_types::Text)]
summary: String,
#[diesel(sql_type = diesel::sql_types::Binary)]
embedding: Vec<u8>,
}
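/// Decode an embedding stored as a little-endian f32 byte blob (4 bytes per value).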
fn deserialize_embedding(bytes: &[u8]) -> Result<Vec<f32>> {
if bytes.len() % 4 != 0 {
return Err(anyhow::anyhow!("Invalid embedding byte length"));
}
let count = bytes.len() / 4;
let mut vec = Vec::with_capacity(count);
for chunk in bytes.chunks_exact(4) {
let float = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
vec.push(float);
}
Ok(vec)
}
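/// Cosine similarity in [-1.0, 1.0]; returns 0.0 for mismatched lengths or zero-magnitude inputs.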
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude_a == 0.0 || magnitude_b == 0.0 {
return 0.0;
}
dot_product / (magnitude_a * magnitude_b)
}
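// A small optimization sketch, not called by this tool as committed: since the
// pairwise loop in main below is O(n^2), normalizing every embedding once up front
// would let cosine similarity reduce to a plain dot product. Hypothetical helper only.
#[allow(dead_code)]
fn normalize(v: &[f32]) -> Vec<f32> {
let mag = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag == 0.0 {
return v.to_vec();
}
v.iter().map(|x| x / mag).collect()
}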
fn main() -> Result<()> {
dotenv::dotenv().ok();
let args = Args::parse();
let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string());
println!("Connecting to database: {}", database_url);
let mut conn = SqliteConnection::establish(&database_url)?;
// Load all embeddings
println!("\nLoading embeddings from daily_conversation_summaries...");
let rows: Vec<EmbeddingRow> = sql_query(
"SELECT id, date, contact, summary, embedding FROM daily_conversation_summaries ORDER BY date"
)
.load(&mut conn)?;
println!("Found {} summaries with embeddings\n", rows.len());
if rows.is_empty() {
println!("No summaries found!");
return Ok(());
}
// Parse all embeddings
let mut embeddings: Vec<(i32, String, String, String, Vec<f32>)> = Vec::new();
for row in &rows {
match deserialize_embedding(&row.embedding) {
Ok(emb) => {
embeddings.push((
row.id,
row.date.clone(),
row.contact.clone(),
row.summary.clone(),
emb,
));
}
Err(e) => {
println!("Warning: Failed to parse embedding for id {}: {}", row.id, e);
}
}
}
println!("Successfully parsed {} embeddings\n", embeddings.len());
// Compute embedding statistics
println!("========================================");
println!("EMBEDDING STATISTICS");
println!("========================================\n");
// Check embedding variance (are values clustered or spread out?)
let first_emb = &embeddings[0].4;
let dim = first_emb.len();
println!("Embedding dimensions: {}", dim);
// Calculate the per-dimension mean and (population) standard deviation in two passes
let mut dim_means: Vec<f32> = vec![0.0; dim];
let mut dim_vars: Vec<f32> = vec![0.0; dim];
for (_, _, _, _, emb) in &embeddings {
for (i, &val) in emb.iter().enumerate() {
dim_means[i] += val;
}
}
for m in &mut dim_means {
*m /= embeddings.len() as f32;
}
for (_, _, _, _, emb) in &embeddings {
for (i, &val) in emb.iter().enumerate() {
let diff = val - dim_means[i];
dim_vars[i] += diff * diff;
}
}
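// Convert the accumulated squared deviations into per-dimension standard deviations.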
for v in &mut dim_vars {
*v = (*v / embeddings.len() as f32).sqrt();
}
let avg_std_dev: f32 = dim_vars.iter().sum::<f32>() / dim as f32;
let min_std_dev: f32 = dim_vars.iter().cloned().fold(f32::INFINITY, f32::min);
let max_std_dev: f32 = dim_vars.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
println!("Per-dimension standard deviation:");
println!(" Average: {:.6}", avg_std_dev);
println!(" Min: {:.6}", min_std_dev);
println!(" Max: {:.6}", max_std_dev);
println!();
// Compute pairwise similarities
println!("Computing pairwise similarities (this may take a moment)...\n");
let mut all_similarities: Vec<f32> = Vec::new();
let mut per_embedding_avg: Vec<(usize, f32)> = Vec::new();
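// Each unordered pair is visited twice, as (i, j) and (j, i); the duplicates do not
// skew the percentiles below because cosine similarity is symmetric.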
for i in 0..embeddings.len() {
let mut sum = 0.0;
let mut count = 0;
for j in 0..embeddings.len() {
if i != j {
let sim = cosine_similarity(&embeddings[i].4, &embeddings[j].4);
all_similarities.push(sim);
sum += sim;
count += 1;
}
}
// Guard the single-summary case, where count == 0 would yield NaN.
per_embedding_avg.push((i, if count > 0 { sum / count as f32 } else { 0.0 }));
}
// Sort similarities for percentile analysis
all_similarities.sort_by(|a, b| a.total_cmp(b)); // total order, no NaN panic
let min_sim = all_similarities.first().copied().unwrap_or(0.0);
let max_sim = all_similarities.last().copied().unwrap_or(0.0);
let median_sim = all_similarities[all_similarities.len() / 2];
let p25 = all_similarities[all_similarities.len() / 4];
let p75 = all_similarities[3 * all_similarities.len() / 4];
let mean_sim: f32 = all_similarities.iter().sum::<f32>() / all_similarities.len() as f32;
println!("========================================");
println!("PAIRWISE SIMILARITY DISTRIBUTION");
println!("========================================\n");
println!("Total pairs analyzed: {}", all_similarities.len());
println!();
println!("Min similarity: {:.4}", min_sim);
println!("25th percentile: {:.4}", p25);
println!("Median similarity: {:.4}", median_sim);
println!("Mean similarity: {:.4}", mean_sim);
println!("75th percentile: {:.4}", p75);
println!("Max similarity: {:.4}", max_sim);
println!();
// Analyze distribution
let count_above_08 = all_similarities.iter().filter(|&&s| s > 0.8).count();
let count_above_07 = all_similarities.iter().filter(|&&s| s > 0.7).count();
let count_above_06 = all_similarities.iter().filter(|&&s| s > 0.6).count();
let count_above_05 = all_similarities.iter().filter(|&&s| s > 0.5).count();
let count_below_03 = all_similarities.iter().filter(|&&s| s < 0.3).count();
println!("Similarity distribution:");
println!(" > 0.8: {} ({:.1}%)", count_above_08, 100.0 * count_above_08 as f32 / all_similarities.len() as f32);
println!(" > 0.7: {} ({:.1}%)", count_above_07, 100.0 * count_above_07 as f32 / all_similarities.len() as f32);
println!(" > 0.6: {} ({:.1}%)", count_above_06, 100.0 * count_above_06 as f32 / all_similarities.len() as f32);
println!(" > 0.5: {} ({:.1}%)", count_above_05, 100.0 * count_above_05 as f32 / all_similarities.len() as f32);
println!(" < 0.3: {} ({:.1}%)", count_below_03, 100.0 * count_below_03 as f32 / all_similarities.len() as f32);
println!();
// Identify "central" embeddings (high average similarity to all others)
per_embedding_avg.sort_by(|a, b| b.1.total_cmp(&a.1)); // descending, NaN-safe
println!("========================================");
println!("TOP {} MOST 'CENTRAL' SUMMARIES", args.top);
println!("(These match everything with high similarity)");
println!("========================================\n");
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().take(args.top).enumerate() {
let (id, date, contact, summary, _) = &embeddings[*idx];
let preview: String = summary.chars().take(80).collect();
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
println!(" Date: {}, Contact: {}", date, contact);
println!(" Preview: {}...", preview.replace('\n', " "));
println!();
}
// Also show the least central (most unique)
println!("========================================");
println!("TOP {} MOST UNIQUE SUMMARIES", args.top);
println!("(These are most different from others)");
println!("========================================\n");
for (rank, (idx, avg_sim)) in per_embedding_avg.iter().rev().take(args.top).enumerate() {
let (id, date, contact, summary, _) = &embeddings[*idx];
let preview: String = summary.chars().take(80).collect();
println!("{}. [id={}, avg_sim={:.4}]", rank + 1, id, avg_sim);
println!(" Date: {}, Contact: {}", date, contact);
println!(" Preview: {}...", preview.replace('\n', " "));
println!();
}
// Diagnosis
println!("========================================");
println!("DIAGNOSIS");
println!("========================================\n");
if mean_sim > 0.7 {
println!("⚠️ HIGH AVERAGE SIMILARITY ({:.4})", mean_sim);
println!(" All embeddings are very similar to each other.");
println!(" This explains why the same summaries always match.");
println!();
println!(" Possible causes:");
println!(" 1. Summaries have similar structure/phrasing (e.g., all start with 'Summary:')");
println!(" 2. Embedding model isn't capturing semantic differences well");
println!(" 3. Daily conversations have similar topics (e.g., 'good morning', plans)");
println!();
println!(" Recommendations:");
println!(" 1. Try a different embedding model (mxbai-embed-large, bge-large)");
println!(" 2. Improve summary diversity by varying the prompt");
println!(" 3. Extract and embed only keywords/entities, not full summaries");
} else if mean_sim > 0.5 {
println!("⚡ MODERATE AVERAGE SIMILARITY ({:.4})", mean_sim);
println!(" Some clustering in embeddings, but some differentiation exists.");
println!();
println!(" The 'central' summaries above are likely dominating search results.");
println!(" Consider:");
println!(" 1. Filtering out summaries with very high centrality");
println!(" 2. Adding time-based weighting to prefer recent/relevant dates");
println!(" 3. Increasing the similarity threshold from 0.3 to 0.5");
} else {
println!("✅ GOOD EMBEDDING DIVERSITY ({:.4})", mean_sim);
println!(" Embeddings are well-differentiated.");
println!(" If same results keep appearing, the issue may be elsewhere.");
}
Ok(())
}
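// Hedged sketches of the recommendations printed above; illustrative only and not
// called anywhere in this tool. `max_avg_sim` and `half_life_days` are hypothetical
// parameters, not values the server currently uses.
#[allow(dead_code)]
fn filter_out_central(per_embedding_avg: &[(usize, f32)], max_avg_sim: f32) -> Vec<usize> {
// Keep only indices whose average similarity to the corpus stays below the cap.
per_embedding_avg
.iter()
.filter(|(_, avg)| *avg <= max_avg_sim)
.map(|(idx, _)| *idx)
.collect()
}
#[allow(dead_code)]
fn recency_weighted(similarity: f32, days_old: f32, half_life_days: f32) -> f32 {
// Exponential decay: a match loses half its weight every `half_life_days`.
similarity * 0.5f32.powf(days_old / half_life_days)
}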


@@ -0,0 +1,285 @@
use anyhow::Result;
use chrono::NaiveDate;
use clap::Parser;
use image_api::ai::{strip_summary_boilerplate, OllamaClient, SmsApiClient};
use image_api::database::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao};
use std::env;
use std::sync::{Arc, Mutex};
#[derive(Parser, Debug)]
#[command(author, version, about = "Test daily summary generation with different models and prompts", long_about = None)]
struct Args {
/// Contact name to generate summaries for
#[arg(short, long)]
contact: String,
/// Start date (YYYY-MM-DD)
#[arg(short, long)]
start: String,
/// End date (YYYY-MM-DD)
#[arg(short, long)]
end: String,
/// Optional: Override the model to use (e.g., "qwen2.5:32b", "llama3.1:30b")
#[arg(short, long)]
model: Option<String>,
/// Test mode: Generate but don't save to database (shows output only)
#[arg(short = 't', long, default_value_t = false)]
test_mode: bool,
/// Show message count and preview
#[arg(short, long, default_value_t = false)]
verbose: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
// Load .env file
dotenv::dotenv().ok();
// Initialize logging
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
let args = Args::parse();
// Parse dates
let start_date = NaiveDate::parse_from_str(&args.start, "%Y-%m-%d")
.expect("Invalid start date format. Use YYYY-MM-DD");
let end_date = NaiveDate::parse_from_str(&args.end, "%Y-%m-%d")
.expect("Invalid end date format. Use YYYY-MM-DD");
println!("========================================");
println!("Daily Summary Generation Test Tool");
println!("========================================");
println!("Contact: {}", args.contact);
println!("Date range: {} to {}", start_date, end_date);
println!("Days: {}", (end_date - start_date).num_days() + 1);
if let Some(ref model) = args.model {
println!("Model: {}", model);
} else {
println!(
"Model: {} (from env)",
env::var("OLLAMA_PRIMARY_MODEL")
.or_else(|_| env::var("OLLAMA_MODEL"))
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
);
}
if args.test_mode {
println!("⚠ TEST MODE: Results will NOT be saved to database");
}
println!("========================================");
println!();
// Initialize AI clients
let ollama_primary_url = env::var("OLLAMA_PRIMARY_URL")
.or_else(|_| env::var("OLLAMA_URL"))
.unwrap_or_else(|_| "http://localhost:11434".to_string());
let ollama_fallback_url = env::var("OLLAMA_FALLBACK_URL").ok();
// Use provided model or fallback to env
let model_to_use = args.model.clone().unwrap_or_else(|| {
env::var("OLLAMA_PRIMARY_MODEL")
.or_else(|_| env::var("OLLAMA_MODEL"))
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
});
let ollama = OllamaClient::new(
ollama_primary_url,
ollama_fallback_url.clone(),
model_to_use.clone(),
Some(model_to_use), // Use same model for fallback
);
let sms_api_url =
env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string());
let sms_api_token = env::var("SMS_API_TOKEN").ok();
let sms_client = SmsApiClient::new(sms_api_url, sms_api_token);
// Initialize DAO
let summary_dao: Arc<Mutex<Box<dyn DailySummaryDao>>> =
Arc::new(Mutex::new(Box::new(SqliteDailySummaryDao::new())));
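// Arc<Mutex<Box<dyn ...>>> presumably mirrors how the server shares this DAO across
// tasks; it is more machinery than this single-threaded tool strictly needs.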
// Fetch messages for contact
println!("Fetching messages for {}...", args.contact);
let all_messages = sms_client
.fetch_all_messages_for_contact(&args.contact)
.await?;
println!(
"Found {} total messages for {}",
all_messages.len(),
args.contact
);
println!();
// Filter to date range and group by date
let mut messages_by_date = std::collections::HashMap::new();
for msg in all_messages {
if let Some(dt) = chrono::DateTime::from_timestamp(msg.timestamp, 0) {
let date = dt.date_naive();
if date >= start_date && date <= end_date {
messages_by_date
.entry(date)
.or_insert_with(Vec::new)
.push(msg);
}
}
}
if messages_by_date.is_empty() {
println!("⚠ No messages found in date range");
return Ok(());
}
println!("Found {} days with messages", messages_by_date.len());
println!();
// Sort dates
let mut dates: Vec<NaiveDate> = messages_by_date.keys().cloned().collect();
dates.sort();
// Process each day
for (idx, date) in dates.iter().enumerate() {
let messages = messages_by_date.get(date).unwrap();
let date_str = date.format("%Y-%m-%d").to_string();
let weekday = date.format("%A");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!(
"Day {}/{}: {} ({}) - {} messages",
idx + 1,
dates.len(),
date_str,
weekday,
messages.len()
);
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
if args.verbose {
println!("\nMessage preview:");
for (i, msg) in messages.iter().take(3).enumerate() {
let sender = if msg.is_sent { "Me" } else { &msg.contact };
let preview = msg.body.chars().take(60).collect::<String>();
println!(" {}. {}: {}...", i + 1, sender, preview);
}
if messages.len() > 3 {
println!(" ... and {} more", messages.len() - 3);
}
println!();
}
// Format messages for the LLM prompt, capped at the first 200 to bound prompt size
let messages_text: String = messages
.iter()
.take(200)
.map(|m| {
if m.is_sent {
format!("Me: {}", m.body)
} else {
format!("{}: {}", m.contact, m.body)
}
})
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
r#"Summarize this day's conversation between me and {}.
CRITICAL FORMAT RULES:
- Do NOT start with "Based on the conversation..." or "Here is a summary..." or similar preambles
- Do NOT repeat the date at the beginning
- Start DIRECTLY with the content - begin with a person's name or action
- Write in past tense, as if recording what happened
NARRATIVE (3-5 sentences):
- What specific topics, activities, or events were discussed?
- What places, people, or organizations were mentioned?
- What plans were made or decisions discussed?
- Clearly distinguish between what "I" did versus what {} did
KEYWORDS (comma-separated):
5-10 specific keywords that capture this conversation's unique content:
- Proper nouns (people, places, brands)
- Specific activities ("drum corps audition" not just "music")
- Distinctive terms that make this day unique
Date: {} ({})
Messages:
{}
YOUR RESPONSE (follow this format EXACTLY):
Summary: [Start directly with content, NO preamble]
Keywords: [specific, unique terms]"#,
args.contact,
args.contact,
date.format("%B %d, %Y"),
weekday,
messages_text
);
println!("Generating summary...");
let summary = ollama
.generate(
&prompt,
Some("You are a conversation summarizer. Create clear, factual summaries with precise subject attribution AND extract distinctive keywords. Focus on specific, unique terms that differentiate this conversation from others."),
)
.await?;
println!("\n📝 GENERATED SUMMARY:");
println!("─────────────────────────────────────────");
println!("{}", summary.trim());
println!("─────────────────────────────────────────");
if !args.test_mode {
println!("\nStripping boilerplate for embedding...");
let stripped = strip_summary_boilerplate(&summary);
println!("Stripped: {}...", stripped.chars().take(80).collect::<String>());
println!("\nGenerating embedding...");
let embedding = ollama.generate_embedding(&stripped).await?;
println!("✓ Embedding generated ({} dimensions)", embedding.len());
println!("Saving to database...");
let insert = InsertDailySummary {
date: date_str.clone(),
contact: args.contact.clone(),
summary: summary.trim().to_string(),
message_count: messages.len() as i32,
embedding,
created_at: chrono::Utc::now().timestamp(),
// model_version: "nomic-embed-text:v1.5".to_string(),
model_version: "mxbai-embed-large:335m".to_string(),
};
let mut dao = summary_dao.lock().expect("Unable to lock DailySummaryDao");
let context = opentelemetry::Context::new();
match dao.store_summary(&context, insert) {
Ok(_) => println!("✓ Saved to database"),
Err(e) => println!("✗ Database error: {:?}", e),
}
} else {
println!("\n⚠ TEST MODE: Not saved to database");
}
println!();
// Rate limiting between days
if idx < dates.len() - 1 {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
}
println!("========================================");
println!("✓ Complete!");
println!("Processed {} days", dates.len());
println!("========================================");
Ok(())
}
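// Example invocations, assuming these files are wired up as (hypothetically named)
// [[bin]] targets in Cargo.toml; the flags come from the clap derives above:
//   cargo run --bin diagnose_embeddings -- --top 5 --verbose
//   cargo run --bin test_summary_generation -- \
//       --contact "Alice" --start 2026-01-01 --end 2026-01-07 --test-mode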