Files
ImageApi/src/bin/test_daily_summary.rs
Cameron d54419e779 style: cargo fmt drift
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 17:19:59 -04:00

279 lines
9.4 KiB
Rust

use anyhow::Result;
use chrono::NaiveDate;
use clap::Parser;
use image_api::ai::{
EMBEDDING_MODEL, OllamaClient, SmsApiClient, build_daily_summary_prompt,
strip_summary_boilerplate, user_display_name,
};
use image_api::database::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao};
use std::env;
use std::sync::{Arc, Mutex};
// Command-line options for the daily-summary test tool, parsed via clap's derive API.
//
// NOTE: the `///` comments on the fields below are doc comments that clap's derive
// macro uses as each option's `--help` text, so editing them changes user-visible
// output. Keep maintainer-only notes in plain `//` comments like this one.
//
// `contact`, `start` and `end` are plain `String`s and therefore required; every
// `Option<_>` field may be omitted. The date strings are parsed in `main`.
#[derive(Parser, Debug)]
#[command(author, version, about = "Test daily summary generation with different models and prompts", long_about = None)]
struct Args {
/// Contact name to generate summaries for
#[arg(short, long)]
contact: String,
/// Start date (YYYY-MM-DD)
#[arg(short, long)]
start: String,
/// End date (YYYY-MM-DD)
#[arg(short, long)]
end: String,
/// Optional: Override the model to use (e.g., "qwen2.5:32b", "llama3.1:30b")
// When absent, `main` falls back to OLLAMA_PRIMARY_MODEL, then OLLAMA_MODEL,
// then a hard-coded default.
#[arg(short, long)]
model: Option<String>,
/// Context window size passed as Ollama `num_ctx`. Omit for server default.
#[arg(long)]
num_ctx: Option<i32>,
// The four sampling knobs below are forwarded together to
// `OllamaClient::set_sampling_params` if at least one of them is given.
/// Sampling temperature. Omit for server default.
#[arg(long)]
temperature: Option<f32>,
/// Top-p (nucleus) sampling. Omit for server default.
#[arg(long)]
top_p: Option<f32>,
/// Top-k sampling. Omit for server default.
#[arg(long)]
top_k: Option<i32>,
/// Min-p sampling. Omit for server default.
#[arg(long)]
min_p: Option<f32>,
/// Test mode: Generate but don't save to database (shows output only)
// Explicit short flag `-t`; in this mode `main` skips embedding generation
// and the database write.
#[arg(short = 't', long, default_value_t = false)]
test_mode: bool,
/// Show message count and preview
#[arg(short, long, default_value_t = false)]
verbose: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
// Load .env file
dotenv::dotenv().ok();
// Initialize logging
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
let args = Args::parse();
// Parse dates
let start_date = NaiveDate::parse_from_str(&args.start, "%Y-%m-%d")
.expect("Invalid start date format. Use YYYY-MM-DD");
let end_date = NaiveDate::parse_from_str(&args.end, "%Y-%m-%d")
.expect("Invalid end date format. Use YYYY-MM-DD");
println!("========================================");
println!("Daily Summary Generation Test Tool");
println!("========================================");
println!("Contact: {}", args.contact);
println!("Date range: {} to {}", start_date, end_date);
println!("Days: {}", (end_date - start_date).num_days() + 1);
if let Some(ref model) = args.model {
println!("Model: {}", model);
} else {
println!(
"Model: {} (from env)",
env::var("OLLAMA_PRIMARY_MODEL")
.or_else(|_| env::var("OLLAMA_MODEL"))
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
);
}
if args.test_mode {
println!("⚠ TEST MODE: Results will NOT be saved to database");
}
println!("========================================");
println!();
// Initialize AI clients
let ollama_primary_url = env::var("OLLAMA_PRIMARY_URL")
.or_else(|_| env::var("OLLAMA_URL"))
.unwrap_or_else(|_| "http://localhost:11434".to_string());
let ollama_fallback_url = env::var("OLLAMA_FALLBACK_URL").ok();
// Use provided model or fallback to env
let model_to_use = args.model.clone().unwrap_or_else(|| {
env::var("OLLAMA_PRIMARY_MODEL")
.or_else(|_| env::var("OLLAMA_MODEL"))
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
});
let mut ollama = OllamaClient::new(
ollama_primary_url,
ollama_fallback_url.clone(),
model_to_use.clone(),
Some(model_to_use), // Use same model for fallback
);
if let Some(ctx) = args.num_ctx {
ollama.set_num_ctx(Some(ctx));
}
if args.temperature.is_some()
|| args.top_p.is_some()
|| args.top_k.is_some()
|| args.min_p.is_some()
{
ollama.set_sampling_params(args.temperature, args.top_p, args.top_k, args.min_p);
}
// Surface what's actually configured so comparison runs are auditable.
println!(
"num_ctx={:?} temperature={:?} top_p={:?} top_k={:?} min_p={:?}",
args.num_ctx, args.temperature, args.top_p, args.top_k, args.min_p
);
let sms_api_url =
env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string());
let sms_api_token = env::var("SMS_API_TOKEN").ok();
let sms_client = SmsApiClient::new(sms_api_url, sms_api_token);
// Initialize DAO
let summary_dao: Arc<Mutex<Box<dyn DailySummaryDao>>> =
Arc::new(Mutex::new(Box::new(SqliteDailySummaryDao::new())));
// Fetch messages for contact
println!("Fetching messages for {}...", args.contact);
let all_messages = sms_client
.fetch_all_messages_for_contact(&args.contact)
.await?;
println!(
"Found {} total messages for {}",
all_messages.len(),
args.contact
);
println!();
// Filter to date range and group by date
let mut messages_by_date = std::collections::HashMap::new();
for msg in all_messages {
if let Some(dt) = chrono::DateTime::from_timestamp(msg.timestamp, 0) {
let date = dt.date_naive();
if date >= start_date && date <= end_date {
messages_by_date
.entry(date)
.or_insert_with(Vec::new)
.push(msg);
}
}
}
if messages_by_date.is_empty() {
println!("⚠ No messages found in date range");
return Ok(());
}
println!("Found {} days with messages", messages_by_date.len());
println!();
// Sort dates
let mut dates: Vec<NaiveDate> = messages_by_date.keys().cloned().collect();
dates.sort();
// Process each day
for (idx, date) in dates.iter().enumerate() {
let messages = messages_by_date.get(date).unwrap();
let date_str = date.format("%Y-%m-%d").to_string();
let weekday = date.format("%A");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!(
"Day {}/{}: {} ({}) - {} messages",
idx + 1,
dates.len(),
date_str,
weekday,
messages.len()
);
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
if args.verbose {
let user_name = user_display_name();
println!("\nMessage preview:");
for (i, msg) in messages.iter().take(3).enumerate() {
let sender: &str = if msg.is_sent {
&user_name
} else {
&msg.contact
};
let preview = msg.body.chars().take(60).collect::<String>();
println!(" {}. {}: {}...", i + 1, sender, preview);
}
if messages.len() > 3 {
println!(" ... and {} more", messages.len() - 3);
}
println!();
}
let (prompt, system_prompt) = build_daily_summary_prompt(&args.contact, date, messages);
println!("Generating summary...");
let summary = ollama.generate(&prompt, Some(system_prompt)).await?;
println!("\n📝 GENERATED SUMMARY:");
println!("─────────────────────────────────────────");
println!("{}", summary.trim());
println!("─────────────────────────────────────────");
if !args.test_mode {
println!("\nStripping boilerplate for embedding...");
let stripped = strip_summary_boilerplate(&summary);
println!(
"Stripped: {}...",
stripped.chars().take(80).collect::<String>()
);
println!("\nGenerating embedding...");
let embedding = ollama.generate_embedding(&stripped).await?;
println!("✓ Embedding generated ({} dimensions)", embedding.len());
println!("Saving to database...");
let insert = InsertDailySummary {
date: date_str.clone(),
contact: args.contact.clone(),
summary: summary.trim().to_string(),
message_count: messages.len() as i32,
embedding,
created_at: chrono::Utc::now().timestamp(),
model_version: EMBEDDING_MODEL.to_string(),
};
let mut dao = summary_dao.lock().expect("Unable to lock DailySummaryDao");
let context = opentelemetry::Context::new();
match dao.store_summary(&context, insert) {
Ok(_) => println!("✓ Saved to database"),
Err(e) => println!("✗ Database error: {:?}", e),
}
} else {
println!("\n⚠ TEST MODE: Not saved to database");
}
println!();
// Rate limiting between days
if idx < dates.len() - 1 {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
}
println!("========================================");
println!("✓ Complete!");
println!("Processed {} days", dates.len());
println!("========================================");
Ok(())
}