diff --git a/.gitignore b/.gitignore index ae33b4c..1437451 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ database/target .idea/dataSources.local.xml # Editor-based HTTP Client requests .idea/httpRequests/ +/.claude/settings.local.json diff --git a/Cargo.lock b/Cargo.lock index b235d7c..3d8173f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -646,9 +646,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.35" +version = "1.2.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" +checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" dependencies = [ "find-msvc-tools", "jobserver", @@ -694,7 +694,7 @@ dependencies = [ "js-sys", "num-traits", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -783,6 +783,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1137,9 +1147,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.0" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" +checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" [[package]] name = "flate2" @@ -1163,6 +1173,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1469,6 +1494,22 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.3.1", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-timeout" version = "0.5.2" @@ -1482,6 +1523,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.16" @@ -1501,9 +1558,11 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2 0.6.0", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -1698,10 +1757,12 @@ dependencies = [ "rand 0.8.5", "rayon", "regex", + "reqwest", "serde", "serde_json", "tempfile", "tokio", + "urlencoding", "walkdir", ] @@ -2070,6 +2131,23 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13d2233c9842d08cfe13f9eac96e207ca6a2ea10b80259ebe8ad0268be27d2af" +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -2181,6 +2259,50 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "opentelemetry" version = "0.31.0" @@ -2744,23 +2866,31 @@ checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ "base64", "bytes", + "encoding_rs", "futures-channel", "futures-core", "futures-util", + "h2 0.4.12", "http 1.3.1", "http-body", "http-body-util", "hyper", + "hyper-rustls", + "hyper-tls", "hyper-util", "js-sys", "log", + "mime", + "native-tls", "percent-encoding", "pin-project-lite", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tower", "tower-http", "tower-service", @@ -2818,6 +2948,39 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -2839,12 +3002,44 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.26" @@ -3077,6 +3272,27 @@ dependencies = [ "syn", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "system-deps" version = "6.2.2" @@ -3208,6 +3424,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -3363,9 +3599,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags", "bytes", @@ -3477,6 +3713,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3688,7 +3930,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", + "windows-link 0.1.3", "windows-result", "windows-strings", ] @@ -3721,13 +3963,30 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e" +dependencies = [ + "windows-link 0.1.3", + "windows-result", + "windows-strings", +] + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3736,7 +3995,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3766,6 +4025,15 @@ dependencies = [ "windows-targets 0.53.3", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -3788,7 +4056,7 @@ version = "0.53.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" dependencies = [ - "windows-link", + "windows-link 0.1.3", "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", "windows_i686_gnu 0.53.0", diff --git a/Cargo.toml b/Cargo.toml index f31b722..35043cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,3 +49,5 @@ opentelemetry-appender-log = "0.31.0" tempfile = "3.20.0" regex = "1.11.1" exif = { package = "kamadak-exif", version = "0.6.1" } +reqwest = { version = "0.12", features = ["json"] } +urlencoding = "2.1" diff --git a/migrations/2025-12-31-000000_add_ai_insights/down.sql b/migrations/2025-12-31-000000_add_ai_insights/down.sql new file mode 100644 index 0000000..6064840 --- /dev/null +++ b/migrations/2025-12-31-000000_add_ai_insights/down.sql @@ -0,0 +1,3 @@ +-- Rollback AI insights table +DROP INDEX IF EXISTS idx_photo_insights_path; +DROP TABLE IF EXISTS photo_insights; diff --git a/migrations/2025-12-31-000000_add_ai_insights/up.sql b/migrations/2025-12-31-000000_add_ai_insights/up.sql new file mode 100644 index 0000000..81d4849 --- /dev/null +++ b/migrations/2025-12-31-000000_add_ai_insights/up.sql @@ -0,0 +1,11 @@ +-- AI-generated insights for individual photos +CREATE TABLE IF NOT EXISTS photo_insights ( + id INTEGER PRIMARY KEY NOT NULL, + file_path TEXT NOT NULL UNIQUE, -- Full path to the photo + title TEXT NOT NULL, -- "At the beach with Sarah" + summary TEXT NOT NULL, -- 2-3 sentence description + generated_at BIGINT NOT NULL, + model_version TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_photo_insights_path ON photo_insights(file_path); diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs new file mode 100644 index 0000000..3b74a49 --- /dev/null +++ b/src/ai/handlers.rs @@ -0,0 +1,154 @@ +use actix_web::{HttpResponse, Responder, delete, get, post, web}; +use serde::{Deserialize, Serialize}; + +use crate::ai::InsightGenerator; +use crate::data::Claims; +use crate::database::InsightDao; + +#[derive(Debug, Deserialize)] +pub struct GeneratePhotoInsightRequest { + pub file_path: String, +} + +#[derive(Debug, Deserialize)] +pub struct GetPhotoInsightQuery { + pub path: String, +} + +#[derive(Debug, Serialize)] +pub struct PhotoInsightResponse { + pub id: i32, + pub file_path: String, + pub title: String, + pub summary: String, + pub generated_at: i64, + pub model_version: String, +} + +/// POST /insights/generate - Generate insight for a specific photo +#[post("/insights/generate")] +pub async fn generate_insight_handler( + _claims: Claims, + request: web::Json, + insight_generator: web::Data, +) -> impl Responder { + log::info!( + "Manual insight generation triggered for photo: {}", + request.file_path + ); + + // Generate insight + match insight_generator + .generate_insight_for_photo(&request.file_path) + .await + { + Ok(()) => HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": "Insight generated successfully" + })), + Err(e) => { + log::error!("Failed to generate insight: {:?}", e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to generate insight: {:?}", e) + })) + } + } +} + +/// GET /insights?path=/path/to/photo.jpg - Fetch insight for specific photo +#[get("/insights")] +pub async fn get_insight_handler( + _claims: Claims, + query: web::Query, + insight_dao: web::Data>>, +) -> impl Responder { + log::debug!("Fetching insight for {}", query.path); + + let otel_context = opentelemetry::Context::new(); + let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); + + match dao.get_insight(&otel_context, &query.path) { + Ok(Some(insight)) => { + let response = PhotoInsightResponse { + id: insight.id, + file_path: insight.file_path, + title: insight.title, + summary: insight.summary, + generated_at: insight.generated_at, + model_version: insight.model_version, + }; + HttpResponse::Ok().json(response) + } + Ok(None) => HttpResponse::NotFound().json(serde_json::json!({ + "error": "Insight not found" + })), + Err(e) => { + log::error!("Failed to fetch insight ({}): {:?}", &query.path, e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to fetch insight: {:?}", e) + })) + } + } +} + +/// DELETE /insights?path=/path/to/photo.jpg - Remove insight (will regenerate on next request) +#[delete("/insights")] +pub async fn delete_insight_handler( + _claims: Claims, + query: web::Query, + insight_dao: web::Data>>, +) -> impl Responder { + log::info!("Deleting insight for {}", query.path); + + let otel_context = opentelemetry::Context::new(); + let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); + + match dao.delete_insight(&otel_context, &query.path) { + Ok(()) => HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": "Insight deleted successfully" + })), + Err(e) => { + log::error!("Failed to delete insight: {:?}", e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to delete insight: {:?}", e) + })) + } + } +} + +/// GET /insights/all - Get all insights +#[get("/insights/all")] +pub async fn get_all_insights_handler( + _claims: Claims, + insight_dao: web::Data>>, +) -> impl Responder { + log::debug!("Fetching all insights"); + + let otel_context = opentelemetry::Context::new(); + let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); + + match dao.get_all_insights(&otel_context) { + Ok(insights) => { + let responses: Vec = insights + .into_iter() + .map(|insight| PhotoInsightResponse { + id: insight.id, + file_path: insight.file_path, + title: insight.title, + summary: insight.summary, + generated_at: insight.generated_at, + model_version: insight.model_version, + }) + .collect(); + + HttpResponse::Ok().json(responses) + } + Err(e) => { + log::error!("Failed to fetch all insights: {:?}", e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to fetch insights: {:?}", e) + })) + } + } +} diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs new file mode 100644 index 0000000..9c1cac9 --- /dev/null +++ b/src/ai/insight_generator.rs @@ -0,0 +1,239 @@ +use anyhow::Result; +use chrono::Utc; +use serde::Deserialize; +use std::sync::{Arc, Mutex}; + +use crate::ai::ollama::OllamaClient; +use crate::ai::sms_client::SmsApiClient; +use crate::database::models::InsertPhotoInsight; +use crate::database::{ExifDao, InsightDao}; +use crate::memories::extract_date_from_filename; + +#[derive(Deserialize)] +struct NominatimResponse { + display_name: Option, + address: Option, +} + +#[derive(Deserialize)] +struct NominatimAddress { + city: Option, + town: Option, + village: Option, + county: Option, + state: Option, + country: Option, +} + +#[derive(Clone)] +pub struct InsightGenerator { + ollama: OllamaClient, + sms_client: SmsApiClient, + insight_dao: Arc>>, + exif_dao: Arc>>, +} + +impl InsightGenerator { + pub fn new( + ollama: OllamaClient, + sms_client: SmsApiClient, + insight_dao: Arc>>, + exif_dao: Arc>>, + ) -> Self { + Self { + ollama, + sms_client, + insight_dao, + exif_dao, + } + } + + /// Extract contact name from file path + /// e.g., "Sarah/img.jpeg" -> Some("Sarah") + /// e.g., "img.jpeg" -> None + fn extract_contact_from_path(file_path: &str) -> Option { + use std::path::Path; + + let path = Path::new(file_path); + let components: Vec<_> = path.components().collect(); + + // If path has at least 2 components (directory + file), extract first directory + if components.len() >= 2 { + if let Some(component) = components.first() { + if let Some(os_str) = component.as_os_str().to_str() { + return Some(os_str.to_string()); + } + } + } + + None + } + + /// Generate AI insight for a single photo + pub async fn generate_insight_for_photo(&self, file_path: &str) -> Result<()> { + log::info!("Generating insight for photo: {}", file_path); + + // 1. Get EXIF data for the photo + let otel_context = opentelemetry::Context::new(); + let exif = { + let mut exif_dao = self.exif_dao.lock().expect("Unable to lock ExifDao"); + exif_dao + .get_exif(&otel_context, file_path) + .map_err(|e| anyhow::anyhow!("Failed to get EXIF: {:?}", e))? + }; + + // Get full timestamp for proximity-based message filtering + let timestamp = if let Some(ts) = exif.as_ref().and_then(|e| e.date_taken) { + ts + } else { + log::warn!("No date_taken in EXIF for {}, trying filename", file_path); + + extract_date_from_filename(file_path) + .map(|dt| dt.timestamp()) + .unwrap_or_else(|| Utc::now().timestamp()) + }; + + let date_taken = chrono::DateTime::from_timestamp(timestamp, 0) + .map(|dt| dt.date_naive()) + .unwrap_or_else(|| Utc::now().date_naive()); + + // 3. Extract contact name from file path + let contact = Self::extract_contact_from_path(file_path); + log::info!("Extracted contact from path: {:?}", contact); + + // 4. Fetch SMS messages for the contact (±1 day) + // Pass the full timestamp for proximity sorting + let sms_messages = self + .sms_client + .fetch_messages_for_contact(contact.as_deref(), timestamp) + .await + .unwrap_or_else(|e| { + log::error!("Failed to fetch SMS messages: {}", e); + Vec::new() + }); + + log::info!( + "Fetched {} SMS messages closest to {}", + sms_messages.len(), + chrono::DateTime::from_timestamp(timestamp, 0) + .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string()) + .unwrap_or_else(|| "unknown time".to_string()) + ); + + // 5. Summarize SMS context + let sms_summary = if !sms_messages.is_empty() { + match self + .sms_client + .summarize_context(&sms_messages, &self.ollama) + .await + { + Ok(summary) => Some(summary), + Err(e) => { + log::warn!("Failed to summarize SMS context: {}", e); + None + } + } + } else { + None + }; + + // 6. Get location name from GPS coordinates + let location = match exif { + Some(exif) => { + if let (Some(lat), Some(lon)) = (exif.gps_latitude, exif.gps_longitude) { + self.reverse_geocode(lat, lon).await + } else { + None + } + } + None => None, + }; + + log::info!( + "Photo context: date={}, location={:?}, sms_messages={}", + date_taken, + location, + sms_messages.len() + ); + + // 7. Generate title and summary with Ollama + let title = self + .ollama + .generate_photo_title(date_taken, location.as_deref(), sms_summary.as_deref()) + .await?; + + let summary = self + .ollama + .generate_photo_summary(date_taken, location.as_deref(), sms_summary.as_deref()) + .await?; + + log::info!("Generated title: {}", title); + log::info!("Generated summary: {}", summary); + + // 8. Store in database + let insight = InsertPhotoInsight { + file_path: file_path.to_string(), + title, + summary, + generated_at: Utc::now().timestamp(), + model_version: self.ollama.model.clone(), + }; + + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.store_insight(&otel_context, insight) + .map_err(|e| anyhow::anyhow!("Failed to store insight: {:?}", e))?; + + log::info!("Successfully stored insight for {}", file_path); + Ok(()) + } + + /// Reverse geocode GPS coordinates to human-readable place names + async fn reverse_geocode(&self, lat: f64, lon: f64) -> Option { + let url = format!( + "https://nominatim.openstreetmap.org/reverse?format=json&lat={}&lon={}", + lat, lon + ); + + let client = reqwest::Client::new(); + let response = client + .get(&url) + .header("User-Agent", "ImageAPI/1.0") // Nominatim requires User-Agent + .send() + .await + .ok()?; + + if !response.status().is_success() { + log::warn!( + "Geocoding failed for {}, {}: {}", + lat, + lon, + response.status() + ); + return None; + } + + let data: NominatimResponse = response.json().await.ok()?; + + // Try to build a concise location name + if let Some(addr) = data.address { + let mut parts = Vec::new(); + + // Prefer city/town/village + if let Some(city) = addr.city.or(addr.town).or(addr.village) { + parts.push(city); + } + + // Add state if available + if let Some(state) = addr.state { + parts.push(state); + } + + if !parts.is_empty() { + return Some(parts.join(", ")); + } + } + + // Fallback to display_name if structured address not available + data.display_name + } +} diff --git a/src/ai/mod.rs b/src/ai/mod.rs new file mode 100644 index 0000000..be1fb05 --- /dev/null +++ b/src/ai/mod.rs @@ -0,0 +1,11 @@ +pub mod handlers; +pub mod insight_generator; +pub mod ollama; +pub mod sms_client; + +pub use handlers::{ + delete_insight_handler, generate_insight_handler, get_all_insights_handler, get_insight_handler, +}; +pub use insight_generator::InsightGenerator; +pub use ollama::OllamaClient; +pub use sms_client::SmsApiClient; diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs new file mode 100644 index 0000000..0c81028 --- /dev/null +++ b/src/ai/ollama.rs @@ -0,0 +1,173 @@ +use anyhow::Result; +use chrono::NaiveDate; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use crate::memories::MemoryItem; + +#[derive(Clone)] +pub struct OllamaClient { + client: Client, + pub base_url: String, + pub model: String, +} + +impl OllamaClient { + pub fn new(base_url: String, model: String) -> Self { + Self { + client: Client::new(), + base_url, + model, + } + } + + /// Extract final answer from thinking model output + /// Handles ... tags and takes everything after + fn extract_final_answer(&self, response: &str) -> String { + let response = response.trim(); + + // Look for tag and take everything after it + if let Some(pos) = response.find("") { + let answer = response[pos + 8..].trim(); + if !answer.is_empty() { + return answer.to_string(); + } + } + + // Fallback: return the whole response trimmed + response.to_string() + } + + pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result { + log::debug!("=== Ollama Request ==="); + log::debug!("Model: {}", self.model); + if let Some(sys) = system { + log::debug!("System: {}", sys); + } + log::debug!("Prompt:\n{}", prompt); + log::debug!("====================="); + + let request = OllamaRequest { + model: self.model.clone(), + prompt: prompt.to_string(), + stream: false, + system: system.map(|s| s.to_string()), + }; + + let response = self + .client + .post(&format!("{}/api/generate", self.base_url)) + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_body = response.text().await.unwrap_or_default(); + log::error!("Ollama request failed: {} - {}", status, error_body); + return Err(anyhow::anyhow!( + "Ollama request failed: {} - {}", + status, + error_body + )); + } + + let result: OllamaResponse = response.json().await?; + + log::debug!("=== Ollama Response ==="); + log::debug!("Raw response: {}", result.response.trim()); + log::debug!("======================="); + + // Extract final answer from thinking model output + let cleaned = self.extract_final_answer(&result.response); + + log::debug!("=== Cleaned Response ==="); + log::debug!("Final answer: {}", cleaned); + log::debug!("========================"); + + Ok(cleaned) + } + + /// Generate a title for a single photo based on its context + pub async fn generate_photo_title( + &self, + date: NaiveDate, + location: Option<&str>, + sms_summary: Option<&str>, + ) -> Result { + let location_str = location.unwrap_or("Unknown location"); + let sms_str = sms_summary.unwrap_or("No messages"); + + let prompt = format!( + r#"Create a short title (maximum 8 words) for this photo: + +Date: {} +Location: {} +Messages: {} + +Use specific details from the context above. If no specific details are available, use a simple descriptive title. + +Return ONLY the title, nothing else."#, + date.format("%B %d, %Y"), + location_str, + sms_str + ); + + let system = + "You are a memory assistant. Use only the information provided. Do not invent details."; + + let title = self.generate(&prompt, Some(system)).await?; + Ok(title.trim().trim_matches('"').to_string()) + } + + /// Generate a summary for a single photo based on its context + pub async fn generate_photo_summary( + &self, + date: NaiveDate, + location: Option<&str>, + sms_summary: Option<&str>, + ) -> Result { + let location_str = location.unwrap_or("somewhere"); + let sms_str = sms_summary.unwrap_or("No messages"); + + let prompt = format!( + r#"Write a brief 1-2 paragraph description of this moment based on the available information: + +Date: {} +Location: {} +Messages: {} + +Use only the specific details provided above. Mention people's names, places, or activities if they appear in the context. Write in first person as Cam in a casual but fluent tone. If limited information is available, keep it simple and factual. If the location is unknown omit it"#, + date.format("%B %d, %Y"), + location_str, + sms_str + ); + + let system = "You are a memory refreshing assistant. Use only the information provided. Do not invent details. Help me remember this day."; + + self.generate(&prompt, Some(system)).await + } + +} + +pub struct MemoryContext { + pub date: NaiveDate, + pub photos: Vec, + pub sms_summary: Option, + pub locations: Vec, + pub cameras: Vec, +} + +#[derive(Serialize)] +struct OllamaRequest { + model: String, + prompt: String, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, +} + +#[derive(Deserialize)] +struct OllamaResponse { + response: String, +} diff --git a/src/ai/sms_client.rs b/src/ai/sms_client.rs new file mode 100644 index 0000000..154dabc --- /dev/null +++ b/src/ai/sms_client.rs @@ -0,0 +1,220 @@ +use anyhow::Result; +use chrono::NaiveDate; +use reqwest::Client; +use serde::Deserialize; + +use super::ollama::OllamaClient; + +#[derive(Clone)] +pub struct SmsApiClient { + client: Client, + base_url: String, + token: Option, +} + +impl SmsApiClient { + pub fn new(base_url: String, token: Option) -> Self { + Self { + client: Client::new(), + base_url, + token, + } + } + + pub async fn fetch_messages_for_date(&self, date: NaiveDate) -> Result> { + // Calculate date range (midnight to midnight in local time) + let start = date + .and_hms_opt(0, 0, 0) + .ok_or_else(|| anyhow::anyhow!("Invalid start time"))?; + let end = date + .and_hms_opt(23, 59, 59) + .ok_or_else(|| anyhow::anyhow!("Invalid end time"))?; + + let start_ts = start.and_utc().timestamp(); + let end_ts = end.and_utc().timestamp(); + + self.fetch_messages(start_ts, end_ts, None, None).await + } + + /// Fetch messages for a specific contact within ±1 day of the given timestamp + /// Falls back to all contacts if no messages found for the specific contact + /// Messages are sorted by proximity to the center timestamp + pub async fn fetch_messages_for_contact( + &self, + contact: Option<&str>, + center_timestamp: i64, + ) -> Result> { + use chrono::Duration; + + // Calculate ±1 day range around the center timestamp + let center_dt = chrono::DateTime::from_timestamp(center_timestamp, 0) + .ok_or_else(|| anyhow::anyhow!("Invalid timestamp"))?; + + let start_dt = center_dt - Duration::days(1); + let end_dt = center_dt + Duration::days(1); + + let start_ts = start_dt.timestamp(); + let end_ts = end_dt.timestamp(); + + // If contact specified, try fetching for that contact first + if let Some(contact_name) = contact { + log::info!( + "Fetching SMS for contact: {} (±1 day from {})", + contact_name, + center_dt.format("%Y-%m-%d %H:%M:%S") + ); + let messages = self + .fetch_messages(start_ts, end_ts, Some(contact_name), Some(center_timestamp)) + .await?; + + if !messages.is_empty() { + log::info!( + "Found {} messages for contact {}", + messages.len(), + contact_name + ); + return Ok(messages); + } + + log::info!( + "No messages found for contact {}, falling back to all contacts", + contact_name + ); + } + + // Fallback to all contacts + log::info!( + "Fetching all SMS messages (±1 day from {})", + center_dt.format("%Y-%m-%d %H:%M:%S") + ); + self.fetch_messages(start_ts, end_ts, None, Some(center_timestamp)) + .await + } + + /// Internal method to fetch messages with optional contact filter and timestamp sorting + async fn fetch_messages( + &self, + start_ts: i64, + end_ts: i64, + contact: Option<&str>, + center_timestamp: Option, + ) -> Result> { + // Call Django endpoint + let mut url = format!( + "{}/api/messages/by-date-range/?start_date={}&end_date={}", + self.base_url, start_ts, end_ts + ); + + // Add contact filter if provided + if let Some(contact_name) = contact { + url.push_str(&format!("&contact={}", urlencoding::encode(contact_name))); + } + + // Add timestamp for proximity sorting if provided + if let Some(ts) = center_timestamp { + url.push_str(&format!("×tamp={}", ts)); + } + + log::debug!("Fetching SMS messages from: {}", url); + + let mut request = self.client.get(&url); + + // Add authorization header if token exists + if let Some(token) = &self.token { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + + log::debug!("SMS API response status: {}", response.status()); + + if !response.status().is_success() { + let status = response.status(); + let error_body = response.text().await.unwrap_or_default(); + log::error!("SMS API request failed: {} - {}", status, error_body); + return Err(anyhow::anyhow!( + "SMS API request failed: {} - {}", + status, + error_body + )); + } + + let data: SmsApiResponse = response.json().await?; + + // Convert to internal format + Ok(data + .messages + .into_iter() + .map(|m| SmsMessage { + contact: m.contact_name, + body: m.body, + timestamp: m.date, + is_sent: m.type_ == 2, // type 2 = sent + }) + .collect()) + } + + pub async fn summarize_context( + &self, + messages: &[SmsMessage], + ollama: &OllamaClient, + ) -> Result { + if messages.is_empty() { + return Ok(String::from("No messages on this day")); + } + + // Create prompt for Ollama with sender/receiver distinction + let messages_text: String = messages + .iter() + .take(60) // Limit to avoid token overflow + .map(|m| { + if m.is_sent { + format!("Me: {}", m.body) + } else { + format!("{}: {}", m.contact, m.body) + } + }) + .collect::>() + .join("\n"); + + let prompt = format!( + r#"Summarize these messages in up to 4-5 sentences. Focus on key topics, places, people mentioned, and the overall context of the conversations. + +Messages: +{} + +Summary:"#, + messages_text + ); + + ollama + .generate( + &prompt, + // Some("You are a summarizer for the purposes of jogging my memory and highlighting events and situations."), + Some("You are the keeper of memories, ingest the context and give me a casual summary of the moment."), + ) + .await + } +} + +#[derive(Debug, Clone)] +pub struct SmsMessage { + pub contact: String, + pub body: String, + pub timestamp: i64, + pub is_sent: bool, +} + +#[derive(Deserialize)] +struct SmsApiResponse { + messages: Vec, +} + +#[derive(Deserialize)] +struct SmsApiMessage { + contact_name: String, + body: String, + date: i64, + #[serde(rename = "type")] + type_: i32, +} diff --git a/src/database/insights_dao.rs b/src/database/insights_dao.rs new file mode 100644 index 0000000..1efa9f3 --- /dev/null +++ b/src/database/insights_dao.rs @@ -0,0 +1,133 @@ +use diesel::prelude::*; +use diesel::sqlite::SqliteConnection; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex}; + +use crate::database::models::{InsertPhotoInsight, PhotoInsight}; +use crate::database::schema; +use crate::database::{DbError, DbErrorKind, connect}; +use crate::otel::trace_db_call; + +pub trait InsightDao: Sync + Send { + fn store_insight( + &mut self, + context: &opentelemetry::Context, + insight: InsertPhotoInsight, + ) -> Result; + + fn get_insight( + &mut self, + context: &opentelemetry::Context, + file_path: &str, + ) -> Result, DbError>; + + fn delete_insight( + &mut self, + context: &opentelemetry::Context, + file_path: &str, + ) -> Result<(), DbError>; + + fn get_all_insights( + &mut self, + context: &opentelemetry::Context, + ) -> Result, DbError>; +} + +pub struct SqliteInsightDao { + connection: Arc>, +} + +impl Default for SqliteInsightDao { + fn default() -> Self { + Self::new() + } +} + +impl SqliteInsightDao { + pub fn new() -> Self { + SqliteInsightDao { + connection: Arc::new(Mutex::new(connect())), + } + } +} + +impl InsightDao for SqliteInsightDao { + fn store_insight( + &mut self, + context: &opentelemetry::Context, + insight: InsertPhotoInsight, + ) -> Result { + trace_db_call(context, "insert", "store_insight", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + // Insert or replace on conflict (UNIQUE constraint on file_path) + diesel::replace_into(photo_insights) + .values(&insight) + .execute(connection.deref_mut()) + .map_err(|_| anyhow::anyhow!("Insert error"))?; + + // Retrieve the inserted record + photo_insights + .filter(file_path.eq(&insight.file_path)) + .first::(connection.deref_mut()) + .map_err(|_| anyhow::anyhow!("Query error")) + }) + .map_err(|_| DbError::new(DbErrorKind::InsertError)) + } + + fn get_insight( + &mut self, + context: &opentelemetry::Context, + path: &str, + ) -> Result, DbError> { + trace_db_call(context, "query", "get_insight", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + photo_insights + .filter(file_path.eq(path)) + .first::(connection.deref_mut()) + .optional() + .map_err(|_| anyhow::anyhow!("Query error")) + }) + .map_err(|_| DbError::new(DbErrorKind::QueryError)) + } + + fn delete_insight( + &mut self, + context: &opentelemetry::Context, + path: &str, + ) -> Result<(), DbError> { + trace_db_call(context, "delete", "delete_insight", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + diesel::delete(photo_insights.filter(file_path.eq(path))) + .execute(connection.deref_mut()) + .map(|_| ()) + .map_err(|_| anyhow::anyhow!("Delete error")) + }) + .map_err(|_| DbError::new(DbErrorKind::QueryError)) + } + + fn get_all_insights( + &mut self, + context: &opentelemetry::Context, + ) -> Result, DbError> { + trace_db_call(context, "query", "get_all_insights", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + photo_insights + .order(generated_at.desc()) + .load::(connection.deref_mut()) + .map_err(|_| anyhow::anyhow!("Query error")) + }) + .map_err(|_| DbError::new(DbErrorKind::QueryError)) + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index f71d885..759d5f4 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -9,9 +9,12 @@ use crate::database::models::{ }; use crate::otel::trace_db_call; +pub mod insights_dao; pub mod models; pub mod schema; +pub use insights_dao::{InsightDao, SqliteInsightDao}; + pub trait UserDao { fn create_user(&mut self, user: &str, password: &str) -> Option; fn get_user(&mut self, user: &str, password: &str) -> Option; diff --git a/src/database/models.rs b/src/database/models.rs index 1d36206..9cee59b 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -1,4 +1,4 @@ -use crate::database::schema::{favorites, image_exif, users}; +use crate::database::schema::{favorites, image_exif, photo_insights, users}; use serde::Serialize; #[derive(Insertable)] @@ -73,3 +73,23 @@ pub struct ImageExif { pub created_time: i64, pub last_modified: i64, } + +#[derive(Insertable)] +#[diesel(table_name = photo_insights)] +pub struct InsertPhotoInsight { + pub file_path: String, + pub title: String, + pub summary: String, + pub generated_at: i64, + pub model_version: String, +} + +#[derive(Serialize, Queryable, Clone, Debug)] +pub struct PhotoInsight { + pub id: i32, + pub file_path: String, + pub title: String, + pub summary: String, + pub generated_at: i64, + pub model_version: String, +} diff --git a/src/database/schema.rs b/src/database/schema.rs index c0ca44c..aa9a93e 100644 --- a/src/database/schema.rs +++ b/src/database/schema.rs @@ -46,6 +46,17 @@ table! { } } +table! { + photo_insights (id) { + id -> Integer, + file_path -> Text, + title -> Text, + summary -> Text, + generated_at -> BigInt, + model_version -> Text, + } +} + table! { users (id) { id -> Integer, @@ -56,4 +67,11 @@ table! { joinable!(tagged_photo -> tags (tag_id)); -allow_tables_to_appear_in_same_query!(favorites, image_exif, tagged_photo, tags, users,); +allow_tables_to_appear_in_same_query!( + favorites, + image_exif, + photo_insights, + tagged_photo, + tags, + users, +); diff --git a/src/files.rs b/src/files.rs index b75b09a..4d8c86c 100644 --- a/src/files.rs +++ b/src/files.rs @@ -16,6 +16,7 @@ use crate::file_types; use crate::geo::{gps_bounding_box, haversine_distance}; use crate::memories::extract_date_from_filename; use crate::{AppState, create_thumbnails}; +use actix_web::dev::ResourcePath; use actix_web::web::Data; use actix_web::{ HttpRequest, HttpResponse, @@ -383,7 +384,14 @@ pub async fn list_photos( ) }) .map(|path: &PathBuf| { - let relative = path.strip_prefix(&app_state.base_path).unwrap(); + let relative = path.strip_prefix(&app_state.base_path).expect( + format!( + "Unable to strip base path {} from file path {}", + &app_state.base_path.path(), + path.display() + ) + .as_str(), + ); relative.to_path_buf() }) .map(|f| f.to_str().unwrap().to_string()) @@ -1018,10 +1026,11 @@ mod tests { let request: Query = Query::from_query("path=").unwrap(); + // Create AppState with the same base_path as RealFileSystem + let test_state = AppState::test_state(); + // Create a dedicated test directory to avoid interference from other files in system temp - let mut test_base = env::temp_dir(); - test_base.push("image_api_test_list_photos"); - fs::create_dir_all(&test_base).unwrap(); + let test_base = PathBuf::from(test_state.base_path.clone()); let mut test_dir = test_base.clone(); test_dir.push("test-dir"); @@ -1031,17 +1040,6 @@ mod tests { photo_path.push("photo.jpg"); File::create(&photo_path).unwrap(); - // Create AppState with the same base_path as RealFileSystem - use actix::Actor; - let test_state = AppState::new( - std::sync::Arc::new(crate::video::actors::StreamActor {}.start()), - test_base.to_str().unwrap().to_string(), - test_base.join("thumbnails").to_str().unwrap().to_string(), - test_base.join("videos").to_str().unwrap().to_string(), - test_base.join("gifs").to_str().unwrap().to_string(), - Vec::new(), - ); - let response: HttpResponse = list_photos( claims, TestRequest::default().to_http_request(), @@ -1049,9 +1047,7 @@ mod tests { Data::new(test_state), Data::new(RealFileSystem::new(test_base.to_str().unwrap().to_string())), Data::new(Mutex::new(SqliteTagDao::default())), - Data::new(Mutex::new( - Box::new(MockExifDao) as Box - )), + Data::new(Mutex::new(Box::new(MockExifDao) as Box)), ) .await; let status = response.status(); diff --git a/src/lib.rs b/src/lib.rs index 03760e2..61f1387 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #[macro_use] extern crate diesel; +pub mod ai; pub mod auth; pub mod cleanup; pub mod data; diff --git a/src/main.rs b/src/main.rs index 2d720e0..f90bdfc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,6 +31,7 @@ use chrono::Utc; use diesel::sqlite::Sqlite; use rayon::prelude::*; +use crate::ai::InsightGenerator; use crate::auth::login; use crate::data::*; use crate::database::models::InsertImageExif; @@ -50,6 +51,7 @@ use log::{debug, error, info, trace, warn}; use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer}; use opentelemetry::{KeyValue, global}; +mod ai; mod auth; mod data; mod database; @@ -715,7 +717,7 @@ fn main() -> std::io::Result<()> { } create_thumbnails(); - generate_video_gifs().await; + // generate_video_gifs().await; let app_data = Data::new(AppState::default()); @@ -744,6 +746,7 @@ fn main() -> std::io::Result<()> { let favorites_dao = SqliteFavoriteDao::new(); let tag_dao = SqliteTagDao::default(); let exif_dao = SqliteExifDao::new(); + let insight_dao = SqliteInsightDao::new(); let cors = Cors::default() .allowed_origin_fn(|origin, _req_head| { // Allow all origins in development, or check against CORS_ALLOWED_ORIGINS env var @@ -795,6 +798,10 @@ fn main() -> std::io::Result<()> { .service(delete_favorite) .service(get_file_metadata) .service(memories::list_memories) + .service(ai::generate_insight_handler) + .service(ai::get_insight_handler) + .service(ai::delete_insight_handler) + .service(ai::get_all_insights_handler) .add_feature(add_tag_services::<_, SqliteTagDao>) .app_data(app_data.clone()) .app_data::>(Data::new(RealFileSystem::new( @@ -808,6 +815,10 @@ fn main() -> std::io::Result<()> { .app_data::>>>(Data::new(Mutex::new(Box::new( exif_dao, )))) + .app_data::>>>(Data::new(Mutex::new(Box::new( + insight_dao, + )))) + .app_data::>(Data::new(app_data.insight_generator.clone())) .wrap(prometheus.clone()) }) .bind(dotenv::var("BIND_URL").unwrap())? diff --git a/src/state.rs b/src/state.rs index 8be3e73..5f7753f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,9 @@ +use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient}; +use crate::database::{ExifDao, InsightDao, SqliteExifDao, SqliteInsightDao}; use crate::video::actors::{PlaylistGenerator, StreamActor, VideoPlaylistManager}; use actix::{Actor, Addr}; -use std::{env, sync::Arc}; +use std::env; +use std::sync::{Arc, Mutex}; pub struct AppState { pub stream_manager: Arc>, @@ -10,6 +13,10 @@ pub struct AppState { pub video_path: String, pub gif_path: String, pub excluded_dirs: Vec, + pub ollama: OllamaClient, + pub sms_client: SmsApiClient, + pub insight_generator: InsightGenerator, + pub insight_dao: Arc>>, } impl AppState { @@ -20,6 +27,10 @@ impl AppState { video_path: String, gif_path: String, excluded_dirs: Vec, + ollama: OllamaClient, + sms_client: SmsApiClient, + insight_generator: InsightGenerator, + insight_dao: Arc>>, ) -> Self { let playlist_generator = PlaylistGenerator::new(); let video_playlist_manager = @@ -33,6 +44,10 @@ impl AppState { video_path, gif_path, excluded_dirs, + ollama, + sms_client, + insight_generator, + insight_dao, } } @@ -49,6 +64,31 @@ impl AppState { impl Default for AppState { fn default() -> Self { + // Initialize AI clients + let ollama_url = + env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".to_string()); + let ollama_model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3.2".to_string()); + let ollama = OllamaClient::new(ollama_url, ollama_model); + + let sms_api_url = + env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string()); + let sms_api_token = env::var("SMS_API_TOKEN").ok(); + let sms_client = SmsApiClient::new(sms_api_url, sms_api_token); + + // Initialize DAOs + let insight_dao: Arc>> = + Arc::new(Mutex::new(Box::new(SqliteInsightDao::new()))); + let exif_dao: Arc>> = + Arc::new(Mutex::new(Box::new(SqliteExifDao::new()))); + + // Initialize InsightGenerator + let insight_generator = InsightGenerator::new( + ollama.clone(), + sms_client.clone(), + insight_dao.clone(), + exif_dao.clone(), + ); + Self::new( Arc::new(StreamActor {}.start()), env::var("BASE_PATH").expect("BASE_PATH was not set in the env"), @@ -56,6 +96,10 @@ impl Default for AppState { env::var("VIDEO_PATH").expect("VIDEO_PATH was not set in the env"), env::var("GIFS_DIRECTORY").expect("GIFS_DIRECTORY was not set in the env"), Self::parse_excluded_dirs(), + ollama, + sms_client, + insight_generator, + insight_dao, ) } } @@ -74,14 +118,37 @@ impl AppState { let video_path = create_test_subdir(&base_path, "videos"); let gif_path = create_test_subdir(&base_path, "gifs"); + // Initialize test AI clients + let ollama = + OllamaClient::new("http://localhost:11434".to_string(), "llama3.2".to_string()); + let sms_client = SmsApiClient::new("http://localhost:8000".to_string(), None); + + // Initialize test DAOs + let insight_dao: Arc>> = + Arc::new(Mutex::new(Box::new(SqliteInsightDao::new()))); + let exif_dao: Arc>> = + Arc::new(Mutex::new(Box::new(SqliteExifDao::new()))); + + // Initialize test InsightGenerator + let insight_generator = InsightGenerator::new( + ollama.clone(), + sms_client.clone(), + insight_dao.clone(), + exif_dao.clone(), + ); + // Create the AppState with the temporary paths AppState::new( - std::sync::Arc::new(crate::video::actors::StreamActor {}.start()), + Arc::new(StreamActor {}.start()), base_path.to_string_lossy().to_string(), thumbnail_path.to_string_lossy().to_string(), video_path.to_string_lossy().to_string(), gif_path.to_string_lossy().to_string(), Vec::new(), // No excluded directories for test state + ollama, + sms_client, + insight_generator, + insight_dao, ) } }