279 lines
9.4 KiB
Rust
279 lines
9.4 KiB
Rust
use anyhow::{Context as _, Result};
use chrono::NaiveDate;
use clap::Parser;
use image_api::ai::{
    EMBEDDING_MODEL, OllamaClient, SmsApiClient, build_daily_summary_prompt,
    strip_summary_boilerplate, user_display_name,
};
use image_api::database::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao};
use std::env;
use std::sync::{Arc, Mutex};
|
|
|
|
#[derive(Parser, Debug)]
|
|
#[command(author, version, about = "Test daily summary generation with different models and prompts", long_about = None)]
|
|
struct Args {
|
|
/// Contact name to generate summaries for
|
|
#[arg(short, long)]
|
|
contact: String,
|
|
|
|
/// Start date (YYYY-MM-DD)
|
|
#[arg(short, long)]
|
|
start: String,
|
|
|
|
/// End date (YYYY-MM-DD)
|
|
#[arg(short, long)]
|
|
end: String,
|
|
|
|
/// Optional: Override the model to use (e.g., "qwen2.5:32b", "llama3.1:30b")
|
|
#[arg(short, long)]
|
|
model: Option<String>,
|
|
|
|
/// Context window size passed as Ollama `num_ctx`. Omit for server default.
|
|
#[arg(long)]
|
|
num_ctx: Option<i32>,
|
|
|
|
/// Sampling temperature. Omit for server default.
|
|
#[arg(long)]
|
|
temperature: Option<f32>,
|
|
|
|
/// Top-p (nucleus) sampling. Omit for server default.
|
|
#[arg(long)]
|
|
top_p: Option<f32>,
|
|
|
|
/// Top-k sampling. Omit for server default.
|
|
#[arg(long)]
|
|
top_k: Option<i32>,
|
|
|
|
/// Min-p sampling. Omit for server default.
|
|
#[arg(long)]
|
|
min_p: Option<f32>,
|
|
|
|
/// Test mode: Generate but don't save to database (shows output only)
|
|
#[arg(short = 't', long, default_value_t = false)]
|
|
test_mode: bool,
|
|
|
|
/// Show message count and preview
|
|
#[arg(short, long, default_value_t = false)]
|
|
verbose: bool,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Load .env file
|
|
dotenv::dotenv().ok();
|
|
|
|
// Initialize logging
|
|
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
|
|
|
|
let args = Args::parse();
|
|
|
|
// Parse dates
|
|
let start_date = NaiveDate::parse_from_str(&args.start, "%Y-%m-%d")
|
|
.expect("Invalid start date format. Use YYYY-MM-DD");
|
|
let end_date = NaiveDate::parse_from_str(&args.end, "%Y-%m-%d")
|
|
.expect("Invalid end date format. Use YYYY-MM-DD");
|
|
|
|
println!("========================================");
|
|
println!("Daily Summary Generation Test Tool");
|
|
println!("========================================");
|
|
println!("Contact: {}", args.contact);
|
|
println!("Date range: {} to {}", start_date, end_date);
|
|
println!("Days: {}", (end_date - start_date).num_days() + 1);
|
|
if let Some(ref model) = args.model {
|
|
println!("Model: {}", model);
|
|
} else {
|
|
println!(
|
|
"Model: {} (from env)",
|
|
env::var("OLLAMA_PRIMARY_MODEL")
|
|
.or_else(|_| env::var("OLLAMA_MODEL"))
|
|
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
|
|
);
|
|
}
|
|
if args.test_mode {
|
|
println!("⚠ TEST MODE: Results will NOT be saved to database");
|
|
}
|
|
println!("========================================");
|
|
println!();
|
|
|
|
// Initialize AI clients
|
|
let ollama_primary_url = env::var("OLLAMA_PRIMARY_URL")
|
|
.or_else(|_| env::var("OLLAMA_URL"))
|
|
.unwrap_or_else(|_| "http://localhost:11434".to_string());
|
|
|
|
let ollama_fallback_url = env::var("OLLAMA_FALLBACK_URL").ok();
|
|
|
|
// Use provided model or fallback to env
|
|
let model_to_use = args.model.clone().unwrap_or_else(|| {
|
|
env::var("OLLAMA_PRIMARY_MODEL")
|
|
.or_else(|_| env::var("OLLAMA_MODEL"))
|
|
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string())
|
|
});
|
|
|
|
let mut ollama = OllamaClient::new(
|
|
ollama_primary_url,
|
|
ollama_fallback_url.clone(),
|
|
model_to_use.clone(),
|
|
Some(model_to_use), // Use same model for fallback
|
|
);
|
|
if let Some(ctx) = args.num_ctx {
|
|
ollama.set_num_ctx(Some(ctx));
|
|
}
|
|
if args.temperature.is_some()
|
|
|| args.top_p.is_some()
|
|
|| args.top_k.is_some()
|
|
|| args.min_p.is_some()
|
|
{
|
|
ollama.set_sampling_params(args.temperature, args.top_p, args.top_k, args.min_p);
|
|
}
|
|
|
|
// Surface what's actually configured so comparison runs are auditable.
|
|
println!(
|
|
"num_ctx={:?} temperature={:?} top_p={:?} top_k={:?} min_p={:?}",
|
|
args.num_ctx, args.temperature, args.top_p, args.top_k, args.min_p
|
|
);
|
|
|
|
let sms_api_url =
|
|
env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string());
|
|
let sms_api_token = env::var("SMS_API_TOKEN").ok();
|
|
let sms_client = SmsApiClient::new(sms_api_url, sms_api_token);
|
|
|
|
// Initialize DAO
|
|
let summary_dao: Arc<Mutex<Box<dyn DailySummaryDao>>> =
|
|
Arc::new(Mutex::new(Box::new(SqliteDailySummaryDao::new())));
|
|
|
|
// Fetch messages for contact
|
|
println!("Fetching messages for {}...", args.contact);
|
|
let all_messages = sms_client
|
|
.fetch_all_messages_for_contact(&args.contact)
|
|
.await?;
|
|
|
|
println!(
|
|
"Found {} total messages for {}",
|
|
all_messages.len(),
|
|
args.contact
|
|
);
|
|
println!();
|
|
|
|
// Filter to date range and group by date
|
|
let mut messages_by_date = std::collections::HashMap::new();
|
|
|
|
for msg in all_messages {
|
|
if let Some(dt) = chrono::DateTime::from_timestamp(msg.timestamp, 0) {
|
|
let date = dt.date_naive();
|
|
if date >= start_date && date <= end_date {
|
|
messages_by_date
|
|
.entry(date)
|
|
.or_insert_with(Vec::new)
|
|
.push(msg);
|
|
}
|
|
}
|
|
}
|
|
|
|
if messages_by_date.is_empty() {
|
|
println!("⚠ No messages found in date range");
|
|
return Ok(());
|
|
}
|
|
|
|
println!("Found {} days with messages", messages_by_date.len());
|
|
println!();
|
|
|
|
// Sort dates
|
|
let mut dates: Vec<NaiveDate> = messages_by_date.keys().cloned().collect();
|
|
dates.sort();
|
|
|
|
// Process each day
|
|
for (idx, date) in dates.iter().enumerate() {
|
|
let messages = messages_by_date.get(date).unwrap();
|
|
let date_str = date.format("%Y-%m-%d").to_string();
|
|
let weekday = date.format("%A");
|
|
|
|
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
println!(
|
|
"Day {}/{}: {} ({}) - {} messages",
|
|
idx + 1,
|
|
dates.len(),
|
|
date_str,
|
|
weekday,
|
|
messages.len()
|
|
);
|
|
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
|
|
if args.verbose {
|
|
let user_name = user_display_name();
|
|
println!("\nMessage preview:");
|
|
for (i, msg) in messages.iter().take(3).enumerate() {
|
|
let sender: &str = if msg.is_sent {
|
|
&user_name
|
|
} else {
|
|
&msg.contact
|
|
};
|
|
let preview = msg.body.chars().take(60).collect::<String>();
|
|
println!(" {}. {}: {}...", i + 1, sender, preview);
|
|
}
|
|
if messages.len() > 3 {
|
|
println!(" ... and {} more", messages.len() - 3);
|
|
}
|
|
println!();
|
|
}
|
|
|
|
let (prompt, system_prompt) = build_daily_summary_prompt(&args.contact, date, messages);
|
|
|
|
println!("Generating summary...");
|
|
|
|
let summary = ollama.generate(&prompt, Some(system_prompt)).await?;
|
|
|
|
println!("\n📝 GENERATED SUMMARY:");
|
|
println!("─────────────────────────────────────────");
|
|
println!("{}", summary.trim());
|
|
println!("─────────────────────────────────────────");
|
|
|
|
if !args.test_mode {
|
|
println!("\nStripping boilerplate for embedding...");
|
|
let stripped = strip_summary_boilerplate(&summary);
|
|
println!(
|
|
"Stripped: {}...",
|
|
stripped.chars().take(80).collect::<String>()
|
|
);
|
|
|
|
println!("\nGenerating embedding...");
|
|
let embedding = ollama.generate_embedding(&stripped).await?;
|
|
println!("✓ Embedding generated ({} dimensions)", embedding.len());
|
|
|
|
println!("Saving to database...");
|
|
let insert = InsertDailySummary {
|
|
date: date_str.clone(),
|
|
contact: args.contact.clone(),
|
|
summary: summary.trim().to_string(),
|
|
message_count: messages.len() as i32,
|
|
embedding,
|
|
created_at: chrono::Utc::now().timestamp(),
|
|
model_version: EMBEDDING_MODEL.to_string(),
|
|
};
|
|
|
|
let mut dao = summary_dao.lock().expect("Unable to lock DailySummaryDao");
|
|
let context = opentelemetry::Context::new();
|
|
|
|
match dao.store_summary(&context, insert) {
|
|
Ok(_) => println!("✓ Saved to database"),
|
|
Err(e) => println!("✗ Database error: {:?}", e),
|
|
}
|
|
} else {
|
|
println!("\n⚠ TEST MODE: Not saved to database");
|
|
}
|
|
|
|
println!();
|
|
|
|
// Rate limiting between days
|
|
if idx < dates.len() - 1 {
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
|
}
|
|
}
|
|
|
|
println!("========================================");
|
|
println!("✓ Complete!");
|
|
println!("Processed {} days", dates.len());
|
|
println!("========================================");
|
|
|
|
Ok(())
|
|
}
|