diff --git a/.gitignore b/.gitignore index 1437451..112d1e3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,15 @@ /target database/target *.db +*.db.bak .env /tmp # Default ignored files .idea/shelf/ .idea/workspace.xml +.idea/inspectionProfiles/ +.idea/markdown.xml # Datasource local storage ignored files .idea/dataSources* .idea/dataSources.local.xml diff --git a/CLAUDE.md b/CLAUDE.md index 5da2612..23bddf5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -169,6 +169,20 @@ POST /image/tags/batch (bulk tag updates) // Memories (week-based grouping) GET /memories?path=...&recursive=true + +// AI Insights +POST /insights/generate (non-agentic single-shot) +POST /insights/generate/agentic (tool-calling loop; body: { file_path, backend?, model?, ... }) +GET /insights?path=...&library=... +GET /insights/models (local Ollama models + capabilities) +GET /insights/openrouter/models (curated OpenRouter allowlist) +POST /insights/rate (thumbs up/down for training data) + +// Insight Chat Continuation +POST /insights/chat (single-turn reply, non-streaming) +POST /insights/chat/stream (SSE: text / tool_call / tool_result / truncated / done) +GET /insights/chat/history?path=... (rendered transcript with tool invocations) +POST /insights/chat/rewind (truncate transcript at a rendered index) ``` **Request Types:** @@ -256,8 +270,23 @@ OLLAMA_PRIMARY_URL=http://desktop:11434 # Primary Ollama server (e.g., de OLLAMA_FALLBACK_URL=http://server:11434 # Fallback Ollama server (optional, always-on) OLLAMA_PRIMARY_MODEL=nemotron-3-nano:30b # Model for primary server (default: nemotron-3-nano:30b) OLLAMA_FALLBACK_MODEL=llama3.2:3b # Model for fallback server (optional, uses primary if not set) +OLLAMA_REQUEST_TIMEOUT_SECONDS=120 # Per-request generation timeout (default 120). Increase for slow CPU-offloaded models. SMS_API_URL=http://localhost:8000 # SMS message API endpoint (default: localhost:8000) SMS_API_TOKEN=your-api-token # SMS API authentication token (optional) + +# OpenRouter (Hybrid Backend) - keeps embeddings + vision local, routes chat to OpenRouter +OPENROUTER_API_KEY=sk-or-... # Required to enable hybrid backend +OPENROUTER_DEFAULT_MODEL=anthropic/claude-sonnet-4 # Used when client doesn't pick a model +OPENROUTER_ALLOWED_MODELS=openai/gpt-4o-mini,anthropic/claude-haiku-4-5,google/gemini-2.5-flash + # Curated allowlist exposed to clients via + # GET /insights/openrouter/models. Empty = no picker. +OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 # Override base URL (optional) +OPENROUTER_EMBEDDING_MODEL=openai/text-embedding-3-small # Optional, embeddings stay local today +OPENROUTER_HTTP_REFERER=https://your-site.example # Optional attribution header +OPENROUTER_APP_TITLE=ImageApi # Optional attribution header + +# Insight Chat Continuation +AGENTIC_CHAT_MAX_ITERATIONS=6 # Cap on tool-calling iterations per chat turn (default 6) ``` **AI Insights Fallback Behavior:** @@ -275,6 +304,67 @@ The `OllamaClient` provides methods to query available models: This allows runtime verification of model availability before generating insights. +**Hybrid Backend (OpenRouter):** +- Per-request opt-in via `backend=hybrid` on `POST /insights/generate/agentic`. +- Local Ollama still describes the image (vision); the description is inlined + into the chat prompt and the agentic loop runs on OpenRouter. +- `request.model` (if provided) overrides `OPENROUTER_DEFAULT_MODEL` for that + call. The mobile picker reads from `OPENROUTER_ALLOWED_MODELS`. +- No live capability precheck — the operator-curated allowlist is trusted. + A bad model id surfaces as a chat-call error. +- `GET /insights/openrouter/models` returns `{ models, default_model, configured }` + for client picker UIs. + +**Insight Chat Continuation:** + +After an agentic insight is generated, the full `Vec` transcript is +stored in `photo_insights.training_messages` and can be continued via the +chat endpoints. The `PhotoInsightResponse.has_training_messages` flag tells +clients whether chat is available for a given insight. + +- `POST /insights/chat` runs one turn of the agentic loop against the replayed + history. Body: `{ file_path, library?, user_message, model?, backend?, num_ctx?, + temperature?, top_p?, top_k?, min_p?, max_iterations?, amend? }`. +- `POST /insights/chat/stream` is the SSE variant — same request body, response + is `text/event-stream` with events: `iteration_start`, `text` (delta), `tool_call`, + `tool_result`, `truncated`, `done`, plus a server-emitted `error_message` on + failure. Preferred by the mobile client for live tool-chip updates. +- `GET /insights/chat/history?path=...&library=...` returns the rendered + transcript. Each assistant message carries a `tools: [{name, arguments, result, + result_truncated?}]` array with the tool invocations that led up to it. Tool + results over 2000 chars are truncated with `result_truncated: true`. +- `POST /insights/chat/rewind` truncates the transcript at a given rendered + index (drops that message + any tool-call scaffolding that preceded it + all + later turns). Index 0 is protected. Used for "try again from here" flows. + +Backend routing rules (matches agentic-insight generation): +- Stored `backend` on the insight row is authoritative by default. +- `request.backend` may override per-turn. `local -> hybrid` is rejected in + v1 (would require on-the-fly visual-description rewrite); `hybrid -> local` + replays verbatim since the description is already inlined as text. +- `request.model` overrides the chat model (an Ollama id in local mode, an + OpenRouter id in hybrid mode). + +Persistence: +- Append mode (default): re-serialize the full history and `UPDATE` the same + row's `training_messages`. +- Amend mode (`amend: true`): regenerate the title, insert a new insight row + via `store_insight` (auto-flips prior rows' `is_current=false`). Response + surfaces the new row's id as `amended_insight_id`. + +Per-`(library_id, file_path)` async mutex (`AppState.insight_chat.chat_locks`) +serialises concurrent turns on the same insight so the JSON blob doesn't race. + +Context management is a soft bound: if the serialized history exceeds +`num_ctx - 2048` tokens (cheap 4-byte/token heuristic), the oldest +assistant-tool_call + tool_result pairs are dropped until under budget. The +initial user message (with any images) and system prompt are always preserved. +The `truncated` event / flag is surfaced to the client when a drop occurred. + +Configurable env: +- `AGENTIC_CHAT_MAX_ITERATIONS` — cap on tool-calling iterations per turn + (default 6). Per-request `max_iterations` is clamped to this cap. + ## Dependencies of Note - **actix-web**: HTTP framework diff --git a/Cargo.lock b/Cargo.lock index 4f04521..d4a65c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -486,6 +486,28 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -886,6 +908,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1196,6 +1224,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fax" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05de7d48f37cd6730705cbca900770cab77a89f413d23e100ad7fad7795a0ab" +dependencies = [ + "fax_derive", +] + +[[package]] +name = "fax_derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "fdeflate" version = "0.3.7" @@ -1479,6 +1527,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1821,11 +1880,14 @@ checksum = "1c6a3ce16143778e24df6f95365f12ed105425b22abefd289dd88a64bab59605" dependencies = [ "bytemuck", "byteorder-lite", + "image-webp", "moxcms", "num-traits", "png", "ravif", "rayon", + "rgb", + "tiff", "zune-core", "zune-jpeg", ] @@ -1843,9 +1905,12 @@ dependencies = [ "actix-web", "actix-web-prom", "anyhow", + "async-stream", + "async-trait", "base64", "bcrypt", "blake3", + "bytes", "chrono", "clap", "diesel", @@ -1877,11 +1942,22 @@ dependencies = [ "serde_json", "tempfile", "tokio", + "tokio-util", "urlencoding", "walkdir", "zerocopy", ] +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error", +] + [[package]] name = "imgref" version = "1.11.0" @@ -3124,12 +3200,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", ] @@ -3684,6 +3762,20 @@ dependencies = [ "syn", ] +[[package]] +name = "tiff" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9605de7fee8d9551863fd692cce7637f548dbd9db9180fcc07ccc6d26c336f" +dependencies = [ + "fax", + "flate2", + "half", + "quick-error", + "weezl", + "zune-jpeg", +] + [[package]] name = "time" version = "0.3.42" @@ -4218,6 +4310,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -4228,6 +4333,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 1e606b0..847c9f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ chrono = "0.4" clap = { version = "4.5", features = ["derive"] } dotenv = "0.15" bcrypt = "0.17.1" -image = { version = "0.25.5", default-features = false, features = ["jpeg", "png", "rayon"] } +image = { version = "0.25.5", default-features = false, features = ["jpeg", "png", "rayon", "webp", "tiff", "avif"] } infer = "0.16" walkdir = "2.4.0" rayon = "1.5" @@ -49,10 +49,14 @@ opentelemetry-appender-log = "0.31.0" tempfile = "3.20.0" regex = "1.11.1" exif = { package = "kamadak-exif", version = "0.6.1" } -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", features = ["json", "stream"] } +async-stream = "0.3" +tokio-util = { version = "0.7", features = ["io"] } +bytes = "1" urlencoding = "2.1" zerocopy = "0.8" ical = "0.11" scraper = "0.20" base64 = "0.22" blake3 = "1.5" +async-trait = "0.1" diff --git a/README.md b/README.md index c8f1c69..31978d9 100644 --- a/README.md +++ b/README.md @@ -14,14 +14,43 @@ Upon first run it will generate thumbnails for all images and videos at `BASE_PA - **RAG-based Context Retrieval** - Semantic search over daily conversation summaries - **Automatic Daily Summaries** - LLM-generated summaries of daily conversations with embeddings +## External Dependencies + +### ffmpeg (required) +`ffmpeg` must be on `PATH`. It is used for: +- **HLS video streaming** — transcoding/segmenting source videos into `.m3u8` + `.ts` playlists +- **Video thumbnails** — extracting a frame at the 3-second mark +- **Video preview clips** — short looping previews for the Video Wall +- **HEIC / HEIF thumbnails** — decoding Apple's HEIC format (your ffmpeg build must include + `libheif`; most modern builds do) + +Builds used in development: the `gyan.dev` full build on Windows, and distro `ffmpeg` +packages on Linux work fine. If HEIC thumbnails silently fail, check +`ffmpeg -formats | grep heif` to confirm HEIF support. + +### RAW photo thumbnails (no extra dependency) +RAW formats (ARW, NEF, CR2, CR3, DNG, RAF, ORF, RW2, PEF, SRW, TIFF) are thumbnailed +by reading the embedded JPEG preview from the TIFF IFD1 using `kamadak-exif`. No +external RAW decoder (libraw / dcraw) is required. Files without an embedded preview +fall back to ffmpeg (works for most NEF files), and anything that still can't be +decoded is marked with a `.unsupported` sentinel in the thumbnail directory +so we don't retry it every scan. Delete those sentinels to force retries after a +tooling upgrade. + ## Environment There are a handful of required environment variables to have the API run. They should be defined where the binary is located or above it in an `.env` file. -You must have `ffmpeg` installed for streaming video and generating video thumbnails. - `DATABASE_URL` is a path or url to a database (currently only SQLite is tested) - `BASE_PATH` is the root from which you want to serve images and videos -- `THUMBNAILS` is a path where generated thumbnails should be stored +- `THUMBNAILS` is a path where generated thumbnails should be stored. Thumbnails + mirror the source tree under `BASE_PATH` and keep the source's original + extension (e.g. `foo.arw` or `bar.mp4`), though the file contents are always + JPEG bytes — browsers content-sniff. Files that can't be thumbnailed by the + `image` crate, ffmpeg, or an embedded RAW preview get a zero-byte + `.unsupported` sentinel in this directory so subsequent scans + skip them. Delete the `*.unsupported` files to force retries (for example + after upgrading ffmpeg or adding libheif) - `VIDEO_PATH` is a path where HLS playlists and video parts should be stored - `GIFS_DIRECTORY` is a path where generated video GIF thumbnails should be stored - `BIND_URL` is the url and port to bind to (typically your own IP address) @@ -50,6 +79,29 @@ The following environment variables configure AI-powered photo insights and dail - `OLLAMA_URL` - Used if `OLLAMA_PRIMARY_URL` not set - `OLLAMA_MODEL` - Used if `OLLAMA_PRIMARY_MODEL` not set +#### OpenRouter Configuration (Hybrid Backend) +The hybrid agentic backend keeps embeddings + vision local (Ollama) while routing +chat + tool-calling to OpenRouter. Enabled per-request when the client sends +`backend=hybrid`. + +- `OPENROUTER_API_KEY` - OpenRouter API key. Required to enable the hybrid backend. +- `OPENROUTER_DEFAULT_MODEL` - Model id used when the client doesn't specify one + [default: `anthropic/claude-sonnet-4`] + - Example: `openai/gpt-4o-mini`, `google/gemini-2.5-flash` +- `OPENROUTER_ALLOWED_MODELS` - Comma-separated curated allowlist exposed to + clients via `GET /insights/openrouter/models`. The mobile picker shows only + these. Empty/unset = no picker, server default is used. + - Example: `openai/gpt-4o-mini,anthropic/claude-haiku-4-5,google/gemini-2.5-flash` +- `OPENROUTER_BASE_URL` - Override base URL [default: `https://openrouter.ai/api/v1`] +- `OPENROUTER_EMBEDDING_MODEL` - Embedding model for OpenRouter + [default: `openai/text-embedding-3-small`]. Only used if/when embeddings are + routed through OpenRouter (currently embeddings stay local). +- `OPENROUTER_HTTP_REFERER` - Optional `HTTP-Referer` for OpenRouter attribution +- `OPENROUTER_APP_TITLE` - Optional `X-Title` for OpenRouter attribution + +Capability checks are skipped for the curated allowlist — bad model ids surface +as a 4xx from the chat call. Pick tool-capable models. + #### SMS API Configuration - `SMS_API_URL` - URL to SMS message API [default: `http://localhost:8000`] - Used to fetch conversation data for context in insights @@ -60,6 +112,24 @@ The following environment variables configure AI-powered photo insights and dail - Controls how many times the model can invoke tools before being forced to produce a final answer - Increase for more thorough context gathering; decrease to limit response time +#### Insight Chat Continuation +After an agentic insight is generated, the conversation can be continued. Endpoints: +- `POST /insights/chat` — single-turn reply (non-streaming) +- `POST /insights/chat/stream` — SSE variant with live `text` deltas and + `tool_call` / `tool_result` events. Mobile client uses this. +- `GET /insights/chat/history?path=...&library=...` — rendered transcript; + each assistant message carries a `tools: [{name, arguments, result}]` array +- `POST /insights/chat/rewind` — truncate transcript at a rendered index + (drops that message + any preceding tool scaffolding + later turns). Used + for "try again from here" flows. The initial user message is protected. + +Amend mode (`amend: true` in the chat request body) regenerates the insight's +title and inserts a new row instead of appending to the existing transcript, +so you can rewrite the saved summary from within chat. + +- `AGENTIC_CHAT_MAX_ITERATIONS` - Cap on tool-calling iterations per chat turn [default: `6`] + - Per-request `max_iterations` (when sent by the client) is clamped to this cap + #### Fallback Behavior - Primary server is tried first with 5-second connection timeout - On failure, automatically falls back to secondary server (if configured) diff --git a/migrations/2026-04-20-000000_add_backend_to_insights/down.sql b/migrations/2026-04-20-000000_add_backend_to_insights/down.sql new file mode 100644 index 0000000..cb8864d --- /dev/null +++ b/migrations/2026-04-20-000000_add_backend_to_insights/down.sql @@ -0,0 +1,23 @@ +-- SQLite can't DROP COLUMN cleanly on older versions; rebuild the table. +CREATE TABLE photo_insights_backup AS + SELECT id, library_id, rel_path, title, summary, generated_at, model_version, + is_current, training_messages, approved + FROM photo_insights; +DROP TABLE photo_insights; +CREATE TABLE photo_insights ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + library_id INTEGER NOT NULL REFERENCES libraries(id), + rel_path TEXT NOT NULL, + title TEXT NOT NULL, + summary TEXT NOT NULL, + generated_at BIGINT NOT NULL, + model_version TEXT NOT NULL, + is_current BOOLEAN NOT NULL DEFAULT TRUE, + training_messages TEXT, + approved BOOLEAN +); +INSERT INTO photo_insights + SELECT id, library_id, rel_path, title, summary, generated_at, model_version, + is_current, training_messages, approved + FROM photo_insights_backup; +DROP TABLE photo_insights_backup; diff --git a/migrations/2026-04-20-000000_add_backend_to_insights/up.sql b/migrations/2026-04-20-000000_add_backend_to_insights/up.sql new file mode 100644 index 0000000..520c209 --- /dev/null +++ b/migrations/2026-04-20-000000_add_backend_to_insights/up.sql @@ -0,0 +1 @@ +ALTER TABLE photo_insights ADD COLUMN backend TEXT NOT NULL DEFAULT 'local'; diff --git a/migrations/2026-04-24-000000_add_fewshot_source_to_insights/down.sql b/migrations/2026-04-24-000000_add_fewshot_source_to_insights/down.sql new file mode 100644 index 0000000..2702414 --- /dev/null +++ b/migrations/2026-04-24-000000_add_fewshot_source_to_insights/down.sql @@ -0,0 +1,24 @@ +-- SQLite can't DROP COLUMN cleanly on older versions; rebuild the table. +CREATE TABLE photo_insights_backup AS + SELECT id, library_id, rel_path, title, summary, generated_at, model_version, + is_current, training_messages, approved, backend + FROM photo_insights; +DROP TABLE photo_insights; +CREATE TABLE photo_insights ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + library_id INTEGER NOT NULL REFERENCES libraries(id), + rel_path TEXT NOT NULL, + title TEXT NOT NULL, + summary TEXT NOT NULL, + generated_at BIGINT NOT NULL, + model_version TEXT NOT NULL, + is_current BOOLEAN NOT NULL DEFAULT TRUE, + training_messages TEXT, + approved BOOLEAN, + backend TEXT NOT NULL DEFAULT 'local' +); +INSERT INTO photo_insights + SELECT id, library_id, rel_path, title, summary, generated_at, model_version, + is_current, training_messages, approved, backend + FROM photo_insights_backup; +DROP TABLE photo_insights_backup; diff --git a/migrations/2026-04-24-000000_add_fewshot_source_to_insights/up.sql b/migrations/2026-04-24-000000_add_fewshot_source_to_insights/up.sql new file mode 100644 index 0000000..f39340c --- /dev/null +++ b/migrations/2026-04-24-000000_add_fewshot_source_to_insights/up.sql @@ -0,0 +1 @@ +ALTER TABLE photo_insights ADD COLUMN fewshot_source_ids TEXT; diff --git a/src/ai/daily_summary_job.rs b/src/ai/daily_summary_job.rs index 9d9c9e0..3fede71 100644 --- a/src/ai/daily_summary_job.rs +++ b/src/ai/daily_summary_job.rs @@ -6,12 +6,83 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use tokio::time::sleep; -use crate::ai::{OllamaClient, SmsApiClient, SmsMessage}; +use crate::ai::{EMBEDDING_MODEL, OllamaClient, SmsApiClient, SmsMessage, user_display_name}; use crate::database::{DailySummaryDao, InsertDailySummary}; use crate::otel::global_tracer; /// Strip boilerplate prefixes and common phrases from summaries before embedding. /// This improves embedding diversity by removing structural similarity. +/// Maximum number of messages passed to the summarizer for a single day. +/// Tuned to avoid token overflow on typical chat models; shared between +/// the production job and the test binary so they can't drift. +pub const DAILY_SUMMARY_MESSAGE_LIMIT: usize = 300; + +/// System prompt used when generating daily conversation summaries. +pub const DAILY_SUMMARY_SYSTEM_PROMPT: &str = "You are a conversation summarizer. Create clear, factual summaries with \ + precise subject attribution AND extract distinctive keywords. Focus on \ + specific, unique terms that differentiate this conversation from others."; + +/// Build the prompt for a single day's conversation summary. Shared by the +/// production job and the test binary so prompt tweaks land in both places. +/// Returns `(prompt, system_prompt)`. +pub fn build_daily_summary_prompt( + contact: &str, + date: &NaiveDate, + messages: &[SmsMessage], +) -> (String, &'static str) { + let user_name = user_display_name(); + let messages_text: String = messages + .iter() + .take(DAILY_SUMMARY_MESSAGE_LIMIT) + .map(|m| { + if m.is_sent { + format!("{}: {}", user_name, m.body) + } else { + format!("{}: {}", m.contact, m.body) + } + }) + .collect::>() + .join("\n"); + + let prompt = format!( + r#"Summarize this day's conversation between {user_name} and {contact}. + +CRITICAL FORMAT RULES: +- Do NOT start with "Based on the conversation..." or "Here is a summary..." or similar preambles +- Do NOT repeat the date at the beginning +- Start DIRECTLY with the content - begin with a person's name or action +- Write in past tense, as if recording what happened + +NARRATIVE (4-8 sentences): +- What specific topics, activities, or events were discussed? +- What places, people, or organizations were mentioned? +- What plans were made or decisions discussed? +- Clearly distinguish between what {user_name} did versus what {contact} did + +KEYWORDS (comma-separated): +5-10 specific keywords that capture this conversation's unique content: +- Proper nouns (people, places, brands) +- Specific activities ("drum corps audition" not just "music") +- Distinctive terms that make this day unique + +Date: {month_day_year} ({weekday}) +Messages: +{messages_text} + +YOUR RESPONSE (follow this format EXACTLY): +Summary: [Start directly with content, NO preamble] + +Keywords: [specific, unique terms]"#, + user_name = user_name, + contact = contact, + month_day_year = date.format("%B %d, %Y"), + weekday = date.format("%A"), + messages_text = messages_text, + ); + + (prompt, DAILY_SUMMARY_SYSTEM_PROMPT) +} + pub fn strip_summary_boilerplate(summary: &str) -> String { let mut text = summary.trim().to_string(); @@ -290,65 +361,10 @@ async fn generate_and_store_daily_summary( span.set_attribute(KeyValue::new("contact", contact.to_string())); span.set_attribute(KeyValue::new("message_count", messages.len() as i64)); - // Format messages for LLM - let messages_text: String = messages - .iter() - .take(200) // Limit to 200 messages per day to avoid token overflow - .map(|m| { - if m.is_sent { - format!("Me: {}", m.body) - } else { - format!("{}: {}", m.contact, m.body) - } - }) - .collect::>() - .join("\n"); - - let weekday = date.format("%A"); - - let prompt = format!( - r#"Summarize this day's conversation between me and {}. - -CRITICAL FORMAT RULES: -- Do NOT start with "Based on the conversation..." or "Here is a summary..." or similar preambles -- Do NOT repeat the date at the beginning -- Start DIRECTLY with the content - begin with a person's name or action -- Write in past tense, as if recording what happened - -NARRATIVE (3-5 sentences): -- What specific topics, activities, or events were discussed? -- What places, people, or organizations were mentioned? -- What plans were made or decisions discussed? -- Clearly distinguish between what "I" did versus what {} did - -KEYWORDS (comma-separated): -5-10 specific keywords that capture this conversation's unique content: -- Proper nouns (people, places, brands) -- Specific activities ("drum corps audition" not just "music") -- Distinctive terms that make this day unique - -Date: {} ({}) -Messages: -{} - -YOUR RESPONSE (follow this format EXACTLY): -Summary: [Start directly with content, NO preamble] - -Keywords: [specific, unique terms]"#, - contact, - contact, - date.format("%B %d, %Y"), - weekday, - messages_text - ); + let (prompt, system_prompt) = build_daily_summary_prompt(contact, date, messages); // Generate summary with LLM - let summary = ollama - .generate( - &prompt, - Some("You are a conversation summarizer. Create clear, factual summaries with precise subject attribution AND extract distinctive keywords. Focus on specific, unique terms that differentiate this conversation from others."), - ) - .await?; + let summary = ollama.generate(&prompt, Some(system_prompt)).await?; log::debug!( "Generated summary for {}: {}", @@ -381,8 +397,7 @@ Keywords: [specific, unique terms]"#, message_count: messages.len() as i32, embedding, created_at: Utc::now().timestamp(), - // model_version: "nomic-embed-text:v1.5".to_string(), - model_version: "mxbai-embed-large:335m".to_string(), + model_version: EMBEDDING_MODEL.to_string(), }; // Create context from current span for DB operation diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs index abf2369..7ec6ab7 100644 --- a/src/ai/handlers.rs +++ b/src/ai/handlers.rs @@ -3,6 +3,8 @@ use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, Tracer}; use serde::{Deserialize, Serialize}; +use crate::ai::insight_chat::{ChatStreamEvent, ChatTurnRequest}; +use crate::ai::ollama::ChatMessage; use crate::ai::{InsightGenerator, ModelCapabilities, OllamaClient}; use crate::data::Claims; use crate::database::{ExifDao, InsightDao}; @@ -11,6 +13,14 @@ use crate::otel::{extract_context_from_request, global_tracer}; use crate::state::AppState; use crate::utils::normalize_path; +/// Hardcoded few-shot exemplars for the agentic endpoint. Populate with the +/// ids of approved insights whose `training_messages` should be compressed +/// into trajectory form and injected into the system prompt. Empty = no +/// change in behavior. Request-level `fewshot_insight_ids` overrides this +/// when non-empty. +// const DEFAULT_FEWSHOT_INSIGHT_IDS: &[i32] = &[2918, 2908]; +const DEFAULT_FEWSHOT_INSIGHT_IDS: &[i32] = &[]; + #[derive(Debug, Deserialize)] pub struct GeneratePhotoInsightRequest { pub file_path: String, @@ -28,6 +38,16 @@ pub struct GeneratePhotoInsightRequest { pub top_k: Option, #[serde(default)] pub min_p: Option, + /// `"local"` (default, Ollama with images) | `"hybrid"` (local vision + + /// OpenRouter chat). Only respected by the agentic endpoint. + #[serde(default)] + pub backend: Option, + /// Insight ids whose stored `training_messages` should be compressed + /// into few-shot trajectories and injected into the system prompt. + /// Silently truncated to the first 2. When absent/empty, the handler + /// falls back to `DEFAULT_FEWSHOT_INSIGHT_IDS`. + #[serde(default)] + pub fewshot_insight_ids: Option>, } #[derive(Debug, Deserialize)] @@ -65,6 +85,10 @@ pub struct PhotoInsightResponse { pub eval_count: Option, #[serde(skip_serializing_if = "Option::is_none")] pub approved: Option, + pub backend: String, + /// True when the insight was generated agentically and a chat + /// continuation can be started against it. Drives the mobile chat button. + pub has_training_messages: bool, } #[derive(Debug, Serialize)] @@ -187,6 +211,8 @@ pub async fn get_insight_handler( prompt_eval_count: None, eval_count: None, approved: insight.approved, + has_training_messages: insight.training_messages.is_some(), + backend: insight.backend, }; HttpResponse::Ok().json(response) } @@ -254,6 +280,8 @@ pub async fn get_all_insights_handler( prompt_eval_count: None, eval_count: None, approved: insight.approved, + has_training_messages: insight.training_messages.is_some(), + backend: insight.backend, }) .collect(); @@ -309,6 +337,45 @@ pub async fn generate_agentic_insight_handler( max_iterations ); + if let Some(ref b) = request.backend { + span.set_attribute(KeyValue::new("backend", b.clone())); + } + + // Resolve few-shot ids: request-provided ids take precedence when + // non-empty; otherwise fall back to the hardcoded defaults. + let fewshot_ids: Vec = match request.fewshot_insight_ids.as_deref() { + Some(ids) if !ids.is_empty() => ids.iter().take(2).copied().collect(), + _ => DEFAULT_FEWSHOT_INSIGHT_IDS + .iter() + .take(2) + .copied() + .collect(), + }; + span.set_attribute(KeyValue::new("fewshot_count", fewshot_ids.len() as i64)); + + let fewshot_examples: Vec> = { + let otel_context = opentelemetry::Context::new(); + let mut dao = insight_dao.lock().expect("Unable to lock InsightDao"); + fewshot_ids + .iter() + .filter_map(|id| { + let insight = dao.get_insight_by_id(&otel_context, *id).ok().flatten()?; + let json = insight.training_messages?; + match serde_json::from_str::>(&json) { + Ok(msgs) => Some(msgs), + Err(e) => { + log::warn!( + "Few-shot insight {} has malformed training_messages: {}", + id, + e + ); + None + } + } + }) + .collect() + }; + let result = insight_generator .generate_agentic_insight_for_photo( &normalized_path, @@ -320,6 +387,9 @@ pub async fn generate_agentic_insight_handler( request.top_k, request.min_p, max_iterations, + request.backend.clone(), + fewshot_examples, + fewshot_ids, ) .await; @@ -341,6 +411,8 @@ pub async fn generate_agentic_insight_handler( prompt_eval_count, eval_count, approved: insight.approved, + has_training_messages: insight.training_messages.is_some(), + backend: insight.backend, }; HttpResponse::Ok().json(response) } @@ -432,6 +504,34 @@ pub async fn get_available_models_handler( HttpResponse::Ok().json(response) } +#[derive(Debug, Serialize)] +pub struct OpenRouterModelsResponse { + pub models: Vec, + pub default_model: Option, + pub configured: bool, +} + +/// GET /insights/openrouter/models - Curated OpenRouter model ids exposed +/// to clients for the hybrid backend. Returned verbatim from +/// `OPENROUTER_ALLOWED_MODELS`; no live call to OpenRouter. +#[get("/insights/openrouter/models")] +pub async fn get_openrouter_models_handler( + _claims: Claims, + app_state: web::Data, +) -> impl Responder { + let configured = app_state.openrouter.is_some(); + let default_model = app_state + .openrouter + .as_ref() + .map(|c| c.primary_model.clone()); + let response = OpenRouterModelsResponse { + models: app_state.openrouter_allowed_models.clone(), + default_model, + configured, + }; + HttpResponse::Ok().json(response) +} + /// POST /insights/rate - Rate an insight (thumbs up/down for training data) #[post("/insights/rate")] pub async fn rate_insight_handler( @@ -517,3 +617,370 @@ pub async fn export_training_data_handler( } } } + +#[derive(Debug, Deserialize)] +pub struct ChatTurnHttpRequest { + pub file_path: String, + #[serde(default)] + pub library: Option, + pub user_message: String, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub backend: Option, + #[serde(default)] + pub num_ctx: Option, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub top_p: Option, + #[serde(default)] + pub top_k: Option, + #[serde(default)] + pub min_p: Option, + #[serde(default)] + pub max_iterations: Option, + #[serde(default)] + pub amend: bool, +} + +#[derive(Debug, Serialize)] +pub struct ChatTurnHttpResponse { + pub assistant_message: String, + pub tool_calls_made: usize, + pub iterations_used: usize, + pub truncated: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_eval_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub eval_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub amended_insight_id: Option, + pub backend: String, + pub model: String, +} + +/// POST /insights/chat — submit a follow-up turn against an existing insight. +#[post("/insights/chat")] +pub async fn chat_turn_handler( + http_request: HttpRequest, + _claims: Claims, + request: web::Json, + app_state: web::Data, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let tracer = global_tracer(); + let mut span = tracer.start_with_context("http.insights.chat", &parent_context); + span.set_attribute(KeyValue::new("file_path", request.file_path.clone())); + + let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) { + Ok(Some(lib)) => lib, + Ok(None) => app_state.primary_library(), + Err(e) => { + return HttpResponse::BadRequest().json(serde_json::json!({ + "error": format!("invalid library: {}", e) + })); + } + }; + + let chat_req = ChatTurnRequest { + library_id: library.id, + file_path: request.file_path.clone(), + user_message: request.user_message.clone(), + model: request.model.clone(), + backend: request.backend.clone(), + num_ctx: request.num_ctx, + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + min_p: request.min_p, + max_iterations: request.max_iterations, + amend: request.amend, + }; + + match app_state.insight_chat.chat_turn(chat_req).await { + Ok(result) => { + span.set_status(Status::Ok); + HttpResponse::Ok().json(ChatTurnHttpResponse { + assistant_message: result.assistant_message, + tool_calls_made: result.tool_calls_made, + iterations_used: result.iterations_used, + truncated: result.truncated, + prompt_eval_count: result.prompt_eval_count, + eval_count: result.eval_count, + amended_insight_id: result.amended_insight_id, + backend: result.backend_used, + model: result.model_used, + }) + } + Err(e) => { + let msg = format!("{}", e); + log::error!("Chat turn failed: {}", msg); + span.set_status(Status::error(msg.clone())); + + // Map well-known errors to client-facing 4xx codes. + if msg.contains("no insight found") { + HttpResponse::NotFound().json(serde_json::json!({ "error": msg })) + } else if msg.contains("no chat history") { + HttpResponse::Conflict().json(serde_json::json!({ "error": msg })) + } else if msg.contains("user_message") + || msg.contains("unknown backend") + || msg.contains("switching from local to hybrid") + || msg.contains("hybrid backend unavailable") + { + HttpResponse::BadRequest().json(serde_json::json!({ "error": msg })) + } else { + HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg })) + } + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ChatHistoryQuery { + pub path: String, + #[serde(default)] + pub library: Option, +} + +#[derive(Debug, Serialize)] +pub struct ChatHistoryHttpResponse { + pub messages: Vec, + pub turn_count: usize, + pub model_version: String, + pub backend: String, +} + +#[derive(Debug, Serialize)] +pub struct RenderedHistoryMessage { + pub role: String, + pub content: String, + pub is_initial: bool, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize)] +pub struct HistoryToolInvocation { + pub name: String, + pub arguments: serde_json::Value, + pub result: String, + #[serde(skip_serializing_if = "std::ops::Not::not")] + pub result_truncated: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ChatRewindHttpRequest { + pub file_path: String, + #[serde(default)] + pub library: Option, + /// 0-based index into the rendered transcript. The message at this + /// index, and everything after it, is discarded. Must be > 0 — the + /// initial user message is protected. + pub discard_from_rendered_index: usize, +} + +/// POST /insights/chat/rewind — truncate the stored conversation so the +/// rendered message at `discard_from_rendered_index` (and everything after) +/// is removed. Use when a user wants to retry a turn with a different +/// prompt without prior replies poisoning context. +#[post("/insights/chat/rewind")] +pub async fn chat_rewind_handler( + _claims: Claims, + request: web::Json, + app_state: web::Data, +) -> impl Responder { + let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) { + Ok(Some(lib)) => lib, + Ok(None) => app_state.primary_library(), + Err(e) => { + return HttpResponse::BadRequest().json(serde_json::json!({ + "error": format!("invalid library: {}", e) + })); + } + }; + + match app_state + .insight_chat + .rewind_history( + library.id, + &request.file_path, + request.discard_from_rendered_index, + ) + .await + { + Ok(()) => HttpResponse::Ok().json(serde_json::json!({ "success": true })), + Err(e) => { + let msg = format!("{}", e); + log::error!("Chat rewind failed: {}", msg); + if msg.contains("no insight found") { + HttpResponse::NotFound().json(serde_json::json!({ "error": msg })) + } else if msg.contains("no chat history") { + HttpResponse::Conflict().json(serde_json::json!({ "error": msg })) + } else if msg.contains("cannot discard the initial") || msg.contains("out of range") { + HttpResponse::BadRequest().json(serde_json::json!({ "error": msg })) + } else { + HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg })) + } + } + } +} + +/// GET /insights/chat/history — return the rendered transcript for a photo. +#[get("/insights/chat/history")] +pub async fn chat_history_handler( + _claims: Claims, + query: web::Query, + app_state: web::Data, +) -> impl Responder { + // library param parsed for parity with other insight endpoints, even + // though load_history currently keys on file_path alone (matches the + // existing get_insight DAO contract). + let _library = libraries::resolve_library_param(&app_state, query.library.as_deref()) + .ok() + .flatten() + .unwrap_or_else(|| app_state.primary_library()); + + match app_state.insight_chat.load_history(&query.path) { + Ok(view) => HttpResponse::Ok().json(ChatHistoryHttpResponse { + messages: view + .messages + .into_iter() + .map(|m| RenderedHistoryMessage { + role: m.role, + content: m.content, + is_initial: m.is_initial, + tools: m + .tools + .into_iter() + .map(|t| HistoryToolInvocation { + name: t.name, + arguments: t.arguments, + result: t.result, + result_truncated: t.result_truncated, + }) + .collect(), + }) + .collect(), + turn_count: view.turn_count, + model_version: view.model_version, + backend: view.backend, + }), + Err(e) => { + let msg = format!("{}", e); + if msg.contains("no insight found") { + HttpResponse::NotFound().json(serde_json::json!({ "error": msg })) + } else if msg.contains("no chat history") { + HttpResponse::Conflict().json(serde_json::json!({ "error": msg })) + } else { + HttpResponse::InternalServerError().json(serde_json::json!({ "error": msg })) + } + } + } +} + +/// POST /insights/chat/stream — streaming variant of /insights/chat. +/// Returns `text/event-stream` with one event per chat stream event. +#[post("/insights/chat/stream")] +pub async fn chat_stream_handler( + _claims: Claims, + request: web::Json, + app_state: web::Data, +) -> HttpResponse { + let library = match libraries::resolve_library_param(&app_state, request.library.as_deref()) { + Ok(Some(lib)) => lib, + Ok(None) => app_state.primary_library(), + Err(e) => { + return HttpResponse::BadRequest().json(serde_json::json!({ + "error": format!("invalid library: {}", e) + })); + } + }; + + let chat_req = ChatTurnRequest { + library_id: library.id, + file_path: request.file_path.clone(), + user_message: request.user_message.clone(), + model: request.model.clone(), + backend: request.backend.clone(), + num_ctx: request.num_ctx, + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + min_p: request.min_p, + max_iterations: request.max_iterations, + amend: request.amend, + }; + + let service = app_state.insight_chat.clone(); + let events = service.chat_turn_stream(chat_req); + + // Map ChatStreamEvent → SSE frame bytes. + let sse_stream = futures::stream::StreamExt::map(events, |ev| { + let frame = render_sse_frame(&ev); + Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame)) + }); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) // nginx: disable response buffering + .streaming(sse_stream) +} + +fn render_sse_frame(ev: &ChatStreamEvent) -> String { + let (event_name, payload) = match ev { + ChatStreamEvent::IterationStart { n, max } => { + ("iteration_start", serde_json::json!({ "n": n, "max": max })) + } + ChatStreamEvent::Truncated => ("truncated", serde_json::json!({})), + ChatStreamEvent::TextDelta(delta) => ("text", serde_json::json!({ "delta": delta })), + ChatStreamEvent::ToolCall { + index, + name, + arguments, + } => ( + "tool_call", + serde_json::json!({ "index": index, "name": name, "arguments": arguments }), + ), + ChatStreamEvent::ToolResult { + index, + name, + result, + result_truncated, + } => ( + "tool_result", + serde_json::json!({ + "index": index, + "name": name, + "result": result, + "result_truncated": result_truncated, + }), + ), + ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated, + prompt_eval_count, + eval_count, + amended_insight_id, + backend_used, + model_used, + } => ( + "done", + serde_json::json!({ + "tool_calls_made": tool_calls_made, + "iterations_used": iterations_used, + "truncated": truncated, + "prompt_eval_count": prompt_eval_count, + "eval_count": eval_count, + "amended_insight_id": amended_insight_id, + "backend": backend_used, + "model": model_used, + }), + ), + ChatStreamEvent::Error(msg) => ("error", serde_json::json!({ "message": msg })), + }; + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", event_name, data) +} diff --git a/src/ai/insight_chat.rs b/src/ai/insight_chat.rs new file mode 100644 index 0000000..1c20b45 --- /dev/null +++ b/src/ai/insight_chat.rs @@ -0,0 +1,1381 @@ +use anyhow::{Result, anyhow, bail}; +use chrono::Utc; +use opentelemetry::KeyValue; +use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use tokio::sync::Mutex as TokioMutex; + +use crate::ai::insight_generator::InsightGenerator; +use crate::ai::llm_client::{ChatMessage, LlmClient, LlmStreamEvent}; +use crate::ai::ollama::OllamaClient; +use crate::ai::openrouter::OpenRouterClient; +use crate::database::InsightDao; +use crate::database::models::InsertPhotoInsight; +use crate::otel::global_tracer; +use crate::utils::normalize_path; +use futures::stream::{BoxStream, StreamExt}; + +const DEFAULT_MAX_ITERATIONS: usize = 6; +const DEFAULT_NUM_CTX: i32 = 8192; +/// Headroom reserved for the model's response, deducted from the context +/// budget when deciding whether to truncate the replayed history. +const RESPONSE_HEADROOM_TOKENS: usize = 2048; +/// Cheap byte-to-token approximation used by the truncation pass. The real +/// tokenization is model-specific; this avoids carrying tiktoken just for a +/// soft bound. +const BYTES_PER_TOKEN: usize = 4; + +pub type ChatLockMap = Arc>>>>; + +#[derive(Debug)] +pub struct ChatTurnRequest { + pub library_id: i32, + pub file_path: String, + pub user_message: String, + /// Override the model id. Local mode: an Ollama model name. Hybrid: + /// an OpenRouter id. None defers to the stored insight's `model_version`. + pub model: Option, + /// Override the backend used for this turn. None defers to the stored + /// insight's `backend`. Switching `local -> hybrid` is rejected in v1. + pub backend: Option, + pub num_ctx: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub min_p: Option, + pub max_iterations: Option, + /// When true, write a new insight row (regenerating title) instead of + /// updating training_messages on the existing row. + pub amend: bool, +} + +#[derive(Debug)] +pub struct ChatTurnResult { + pub assistant_message: String, + pub tool_calls_made: usize, + pub iterations_used: usize, + pub truncated: bool, + pub prompt_eval_count: Option, + pub eval_count: Option, + /// Set when `amend=true` and the new insight row was inserted. + pub amended_insight_id: Option, + /// Backend used for this turn — useful when the client overrode the + /// stored value. + pub backend_used: String, + /// Model identifier the chat backend ran with. + pub model_used: String, +} + +#[derive(Clone)] +pub struct InsightChatService { + generator: Arc, + ollama: OllamaClient, + openrouter: Option>, + insight_dao: Arc>>, + chat_locks: ChatLockMap, +} + +impl InsightChatService { + pub fn new( + generator: Arc, + ollama: OllamaClient, + openrouter: Option>, + insight_dao: Arc>>, + chat_locks: ChatLockMap, + ) -> Self { + Self { + generator, + ollama, + openrouter, + insight_dao, + chat_locks, + } + } + + /// Load the rendered transcript for chat-UI display. Filters internal + /// scaffolding (system message, tool turns, tool-dispatch-only assistant + /// messages) and drops base64 images from user turns to keep payloads + /// small. The first remaining user message is flagged `is_initial`. + pub fn load_history(&self, file_path: &str) -> Result { + let normalized = normalize_path(file_path); + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let insight = dao + .get_insight(&cx, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + .ok_or_else(|| anyhow!("no insight found for path"))?; + + let raw = insight + .training_messages + .as_ref() + .ok_or_else(|| anyhow!("insight has no chat history (pre-agentic insight)"))?; + let messages: Vec = serde_json::from_str(raw) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + let mut rendered = Vec::new(); + let mut user_turns_seen = 0usize; + let mut assistant_turns_seen = 0usize; + + // Accumulate tool invocations seen since the last user turn. An + // invocation is: one assistant tool_call message (which may hold + // multiple calls) + the N following tool-role messages (one per call, + // in order). They attach to the next assistant-with-content, which + // is the "final" reply for the current turn. + // + // Wire shape from the model: + // assistant { tool_calls: [A, B], content: "" } + // tool { content: "result of A" } + // tool { content: "result of B" } + // assistant { content: "here's the answer" } ← rendered as final + let mut pending_tools: Vec = Vec::new(); + // Queue of (name, arguments) awaiting a tool_result to pair with. + let mut pending_calls: std::collections::VecDeque<(String, serde_json::Value)> = + std::collections::VecDeque::new(); + + for msg in &messages { + match msg.role.as_str() { + "system" => continue, + "tool" => { + if let Some((name, arguments)) = pending_calls.pop_front() { + let (result, result_truncated) = truncate_tool_result(&msg.content); + pending_tools.push(ToolInvocation { + name, + arguments, + result, + result_truncated, + }); + } + // If there's no pending call, the tool message is an + // orphan (shouldn't happen in practice) — skip silently. + } + "assistant" => { + let has_tool_calls = msg + .tool_calls + .as_ref() + .map(|c| !c.is_empty()) + .unwrap_or(false); + if has_tool_calls && msg.content.trim().is_empty() { + // Tool-dispatch turn: enqueue calls, wait for tool + // results on subsequent messages. + if let Some(ref tcs) = msg.tool_calls { + for tc in tcs { + pending_calls.push_back(( + tc.function.name.clone(), + tc.function.arguments.clone(), + )); + } + } + continue; + } + // Final assistant reply for this turn — drain accumulated + // tools into it. + assistant_turns_seen += 1; + let tools = std::mem::take(&mut pending_tools); + pending_calls.clear(); // any leftover unpaired calls are dropped + rendered.push(RenderedMessage { + role: "assistant".to_string(), + content: msg.content.clone(), + is_initial: false, + tools, + }); + } + "user" => { + let is_initial = user_turns_seen == 0; + user_turns_seen += 1; + // New user turn resets any in-flight tool state. + pending_tools.clear(); + pending_calls.clear(); + rendered.push(RenderedMessage { + role: "user".to_string(), + content: msg.content.clone(), + is_initial, + tools: Vec::new(), + }); + } + _ => continue, + } + } + + Ok(HistoryView { + messages: rendered, + turn_count: assistant_turns_seen, + model_version: insight.model_version, + backend: insight.backend, + }) + } + + pub async fn chat_turn(&self, req: ChatTurnRequest) -> Result { + let tracer = global_tracer(); + let parent_cx = opentelemetry::Context::new(); + let mut span = tracer.start_with_context("ai.insight.chat_turn", &parent_cx); + span.set_attribute(KeyValue::new("file_path", req.file_path.clone())); + span.set_attribute(KeyValue::new("library_id", req.library_id as i64)); + span.set_attribute(KeyValue::new("amend", req.amend)); + + if req.user_message.trim().is_empty() { + bail!("user_message must not be empty"); + } + if req.user_message.len() > 8192 { + bail!("user_message exceeds 8192 chars"); + } + + let normalized = normalize_path(&req.file_path); + + // 1. Acquire the per-(library, file) async mutex. Two concurrent + // chat turns on the same insight would race on the JSON blob — + // the lock serialises them. + let lock_key = (req.library_id, normalized.clone()); + let entry_lock = { + let mut locks = self.chat_locks.lock().await; + locks + .entry(lock_key.clone()) + .or_insert_with(|| Arc::new(TokioMutex::new(()))) + .clone() + }; + let _guard = entry_lock.lock().await; + + // 2. Load the current insight + history. + let insight = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.get_insight(&cx, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + .ok_or_else(|| anyhow!("no insight found for path"))? + }; + let raw_history = insight + .training_messages + .as_ref() + .ok_or_else(|| { + anyhow!("insight has no chat history; regenerate this insight in agentic mode") + })? + .clone(); + let mut messages: Vec = serde_json::from_str(&raw_history) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + // 3. Resolve effective backend. Reject the unsupported switch. + let stored_backend = insight.backend.clone(); + let effective_backend = req + .backend + .as_deref() + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| stored_backend.clone()); + if !matches!(effective_backend.as_str(), "local" | "hybrid") { + bail!( + "unknown backend '{}'; expected 'local' or 'hybrid'", + effective_backend + ); + } + if stored_backend == "local" && effective_backend == "hybrid" { + bail!( + "switching from local to hybrid mid-chat isn't supported yet; \ + regenerate the insight in hybrid mode if you want OpenRouter chat" + ); + } + let is_hybrid = effective_backend == "hybrid"; + span.set_attribute(KeyValue::new("backend", effective_backend.clone())); + + // 4. Build the chat backend client. Ollama in local mode, a freshly + // cloned OpenRouter client in hybrid mode (clone so per-request + // sampling/model overrides don't leak into shared state). + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + span.set_attribute(KeyValue::new("max_iterations", max_iterations as i64)); + + let stored_model = insight.model_version.clone(); + let custom_model = req + .model + .clone() + .or_else(|| Some(stored_model.clone())) + .filter(|m| !m.is_empty()); + + let mut ollama_client = self.ollama.clone(); + let mut openrouter_client: Option = None; + + if is_hybrid { + let arc = self.openrouter.as_ref().ok_or_else(|| { + anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured") + })?; + let mut c: OpenRouterClient = (**arc).clone(); + if let Some(ref m) = custom_model { + c.primary_model = m.clone(); + } + if req.temperature.is_some() + || req.top_p.is_some() + || req.top_k.is_some() + || req.min_p.is_some() + { + c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p); + } + if let Some(ctx) = req.num_ctx { + c.set_num_ctx(Some(ctx)); + } + openrouter_client = Some(c); + } else { + // Local-mode model swap. Build a new client when the chat model + // differs from the configured one (mirrors the agentic pattern). + if let Some(ref m) = custom_model + && m != &self.ollama.primary_model + { + ollama_client = OllamaClient::new( + self.ollama.primary_url.clone(), + self.ollama.fallback_url.clone(), + m.clone(), + Some(m.clone()), + ); + } + if req.temperature.is_some() + || req.top_p.is_some() + || req.top_k.is_some() + || req.min_p.is_some() + { + ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p); + } + if let Some(ctx) = req.num_ctx { + ollama_client.set_num_ctx(Some(ctx)); + } + } + + let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client { + c + } else { + &ollama_client + }; + let model_used = chat_backend.primary_model().to_string(); + span.set_attribute(KeyValue::new("model", model_used.clone())); + + // 5. Decide vision + tool set. In hybrid we always omit + // `describe_photo` (matches the original generation flow). In + // local we trust the stored history's first-user shape: if it + // carries `images`, the original model was vision-capable, and + // we keep `describe_photo` available. + let local_first_user_has_image = messages + .iter() + .find(|m| m.role == "user") + .and_then(|m| m.images.as_ref()) + .map(|imgs| !imgs.is_empty()) + .unwrap_or(false); + let offer_describe_tool = !is_hybrid && local_first_user_has_image; + let tools = InsightGenerator::build_tool_definitions(offer_describe_tool); + + // Image base64 only needed when describe_photo is on the menu. Load + // lazily to avoid disk IO when the loop never invokes it. + let image_base64: Option = if offer_describe_tool { + self.generator.load_image_as_base64(&normalized).ok() + } else { + None + }; + + // 6. Apply truncation budget. Drops oldest tool_call+tool pairs + // (preserves system + first user including any images). + let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize) + .saturating_sub(RESPONSE_HEADROOM_TOKENS); + let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN); + let truncated = apply_context_budget(&mut messages, budget_bytes); + if truncated { + span.set_attribute(KeyValue::new("history_truncated", true)); + } + + // 7. Append the new user turn. + messages.push(ChatMessage::user(req.user_message.clone())); + + // Temporarily annotate the system message with this turn's iteration + // budget so the model knows how many tool-calling rounds it has. We + // restore the original content before persistence so the note doesn't + // accumulate across turns. + let original_system_content = annotate_system_with_budget(&mut messages, max_iterations); + + let insight_cx = parent_cx.with_span(span); + + // 8. Agentic loop — same shape as insight_generator's, but capped + // tighter and dispatching tools through the shared executor. + let loop_span = tracer.start_with_context("ai.chat.loop", &insight_cx); + let loop_cx = insight_cx.with_span(loop_span); + let mut tool_calls_made = 0usize; + let mut iterations_used = 0usize; + let mut last_prompt_eval_count: Option = None; + let mut last_eval_count: Option = None; + let mut final_content = String::new(); + + for iteration in 0..max_iterations { + iterations_used = iteration + 1; + log::info!("Chat iteration {}/{}", iterations_used, max_iterations); + + let (response, prompt_tokens, eval_tokens) = chat_backend + .chat_with_tools(messages.clone(), tools.clone()) + .await?; + last_prompt_eval_count = prompt_tokens; + last_eval_count = eval_tokens; + + // Ollama rejects non-object tool-call arguments on replay. + let mut response = response; + if let Some(ref mut tcs) = response.tool_calls { + for tc in tcs.iter_mut() { + if !tc.function.arguments.is_object() { + tc.function.arguments = serde_json::Value::Object(Default::default()); + } + } + } + + messages.push(response.clone()); + + if let Some(ref tool_calls) = response.tool_calls + && !tool_calls.is_empty() + { + for tool_call in tool_calls { + tool_calls_made += 1; + log::info!( + "Chat tool call [{}]: {} {:?}", + iteration, + tool_call.function.name, + tool_call.function.arguments + ); + let result = self + .generator + .execute_tool( + &tool_call.function.name, + &tool_call.function.arguments, + &ollama_client, + &image_base64, + &normalized, + &loop_cx, + ) + .await; + messages.push(ChatMessage::tool_result(result)); + } + continue; + } + + final_content = response.content; + break; + } + + if final_content.is_empty() { + // The model never produced a final answer; ask once more without + // tools to force a textual reply. + log::info!( + "Chat loop exhausted after {} iterations, requesting final answer", + iterations_used + ); + messages.push(ChatMessage::user( + "Please write your final answer now without calling any more tools.", + )); + let (final_response, prompt_tokens, eval_tokens) = chat_backend + .chat_with_tools(messages.clone(), vec![]) + .await?; + last_prompt_eval_count = prompt_tokens; + last_eval_count = eval_tokens; + final_content = final_response.content.clone(); + messages.push(final_response); + } + + loop_cx.span().set_status(Status::Ok); + + // Drop the per-turn iteration-budget note from the system message + // before we persist so it doesn't snowball on each subsequent turn. + restore_system_content(&mut messages, original_system_content); + + // 9. Persist. Append mode rewrites the JSON blob in place; amend + // mode regenerates the title and inserts a new insight row, + // relying on store_insight to flip prior rows' is_current=false. + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + + let mut amended_insight_id: Option = None; + if req.amend { + let title_prompt = format!( + "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\ + Capture the key moment or theme. Return ONLY the title, nothing else.", + final_content + ); + let title_raw = chat_backend + .generate( + &title_prompt, + Some( + "You are my long term memory assistant. Use only the information provided. Do not invent details.", + ), + None, + ) + .await?; + let title = title_raw.trim().trim_matches('"').to_string(); + + // Amended rows intentionally do not inherit the parent's + // `fewshot_source_ids`. The parent's few-shot influence is still + // present in this row's content; if you want strict lineage + // tracking for training-set filtering, fetch the parent here and + // copy its value forward. + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: final_content.clone(), + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: effective_backend.clone(), + fewshot_source_ids: None, + }; + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let stored = dao + .store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?; + amended_insight_id = Some(stored.id); + } else { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.update_training_messages(&cx, req.library_id, &normalized, &json) + .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + } + + Ok(ChatTurnResult { + assistant_message: final_content, + tool_calls_made, + iterations_used, + truncated, + prompt_eval_count: last_prompt_eval_count, + eval_count: last_eval_count, + amended_insight_id, + backend_used: effective_backend, + model_used, + }) + } + + /// Truncate the stored conversation so the rendered message at + /// `discard_from_rendered_index` (and everything after it — including + /// the tool-call scaffolding that produced a discarded assistant reply) + /// is removed. The initial user turn cannot be discarded; attempting to + /// do so returns an error. + /// + /// Holds the per-file chat mutex so it serialises with `chat_turn`. + pub async fn rewind_history( + &self, + library_id: i32, + file_path: &str, + discard_from_rendered_index: usize, + ) -> Result<()> { + if discard_from_rendered_index == 0 { + bail!("cannot discard the initial user message"); + } + let normalized = normalize_path(file_path); + + let lock_key = (library_id, normalized.clone()); + let entry_lock = { + let mut locks = self.chat_locks.lock().await; + locks + .entry(lock_key.clone()) + .or_insert_with(|| Arc::new(TokioMutex::new(()))) + .clone() + }; + let _guard = entry_lock.lock().await; + + let insight = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.get_insight(&cx, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + .ok_or_else(|| anyhow!("no insight found for path"))? + }; + let raw_history = insight + .training_messages + .as_ref() + .ok_or_else(|| anyhow!("insight has no chat history"))?; + let messages: Vec = serde_json::from_str(raw_history) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + let cut_at = find_raw_cut(&messages, discard_from_rendered_index) + .ok_or_else(|| anyhow!("discard_from_rendered_index out of range"))?; + + let truncated = &messages[..cut_at]; + let json = serde_json::to_string(truncated) + .map_err(|e| anyhow!("failed to serialize truncated history: {}", e))?; + + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.update_training_messages(&cx, library_id, &normalized, &json) + .map_err(|e| anyhow!("failed to persist truncated history: {:?}", e))?; + Ok(()) + } + + /// Streaming variant of `chat_turn`. Emits user-facing events as the + /// conversation progresses: iteration starts, tool dispatch + result, + /// text deltas from the final assistant reply, and a terminal `Done` + /// frame. Persistence happens inside the stream after the loop ends. + /// + /// The stream takes ownership of the service via `Arc` (passed by + /// the caller) so it can live past the handler's await boundary. + pub fn chat_turn_stream( + self: Arc, + req: ChatTurnRequest, + ) -> BoxStream<'static, ChatStreamEvent> { + let svc = self; + let s = async_stream::stream! { + match svc.chat_turn_stream_inner(req, Ok).await { + Ok(mut rx) => { + while let Some(ev) = rx.recv().await { + yield ev; + } + } + Err(e) => { + yield ChatStreamEvent::Error(format!("{}", e)); + } + } + }; + Box::pin(s) + } + + /// Internal: drives the streaming loop on a background task, returning + /// a receiver the caller drains. Keeping the work on a spawned task + /// decouples the HTTP request lifetime from the chat execution, which + /// matters because the chat may run longer than any single network hop + /// and we want clean cancellation semantics via the channel close. + async fn chat_turn_stream_inner( + self: Arc, + req: ChatTurnRequest, + _ev_mapper: F, + ) -> Result> + where + F: Fn(ChatStreamEvent) -> Result + Send + 'static, + { + let (tx, rx) = tokio::sync::mpsc::channel::(64); + let svc = self.clone(); + tokio::spawn(async move { + let result = svc.run_streaming_turn(req, tx.clone()).await; + if let Err(e) = result { + let _ = tx.send(ChatStreamEvent::Error(format!("{}", e))).await; + } + }); + Ok(rx) + } + + async fn run_streaming_turn( + self: Arc, + req: ChatTurnRequest, + tx: tokio::sync::mpsc::Sender, + ) -> Result<()> { + if req.user_message.trim().is_empty() { + bail!("user_message must not be empty"); + } + if req.user_message.len() > 8192 { + bail!("user_message exceeds 8192 chars"); + } + let normalized = normalize_path(&req.file_path); + + let lock_key = (req.library_id, normalized.clone()); + let entry_lock = { + let mut locks = self.chat_locks.lock().await; + locks + .entry(lock_key.clone()) + .or_insert_with(|| Arc::new(TokioMutex::new(()))) + .clone() + }; + let _guard = entry_lock.lock().await; + + let insight = { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.get_insight(&cx, &normalized) + .map_err(|e| anyhow!("failed to load insight: {:?}", e))? + .ok_or_else(|| anyhow!("no insight found for path"))? + }; + let raw_history = insight + .training_messages + .as_ref() + .ok_or_else(|| { + anyhow!("insight has no chat history; regenerate this insight in agentic mode") + })? + .clone(); + let mut messages: Vec = serde_json::from_str(&raw_history) + .map_err(|e| anyhow!("failed to deserialize chat history: {}", e))?; + + // Backend selection — same rules as non-streaming chat_turn. + let stored_backend = insight.backend.clone(); + let effective_backend = req + .backend + .as_deref() + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| stored_backend.clone()); + if !matches!(effective_backend.as_str(), "local" | "hybrid") { + bail!( + "unknown backend '{}'; expected 'local' or 'hybrid'", + effective_backend + ); + } + if stored_backend == "local" && effective_backend == "hybrid" { + bail!( + "switching from local to hybrid mid-chat isn't supported yet; \ + regenerate the insight in hybrid mode if you want OpenRouter chat" + ); + } + let is_hybrid = effective_backend == "hybrid"; + + let max_iterations = req + .max_iterations + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .clamp(1, env_max_iterations()); + + let stored_model = insight.model_version.clone(); + let custom_model = req + .model + .clone() + .or_else(|| Some(stored_model.clone())) + .filter(|m| !m.is_empty()); + + let mut ollama_client = self.ollama.clone(); + let mut openrouter_client: Option = None; + + if is_hybrid { + let arc = self.openrouter.as_ref().ok_or_else(|| { + anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured") + })?; + let mut c: OpenRouterClient = (**arc).clone(); + if let Some(ref m) = custom_model { + c.primary_model = m.clone(); + } + if req.temperature.is_some() + || req.top_p.is_some() + || req.top_k.is_some() + || req.min_p.is_some() + { + c.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p); + } + if let Some(ctx) = req.num_ctx { + c.set_num_ctx(Some(ctx)); + } + openrouter_client = Some(c); + } else { + if let Some(ref m) = custom_model + && m != &self.ollama.primary_model + { + ollama_client = OllamaClient::new( + self.ollama.primary_url.clone(), + self.ollama.fallback_url.clone(), + m.clone(), + Some(m.clone()), + ); + } + if req.temperature.is_some() + || req.top_p.is_some() + || req.top_k.is_some() + || req.min_p.is_some() + { + ollama_client.set_sampling_params(req.temperature, req.top_p, req.top_k, req.min_p); + } + if let Some(ctx) = req.num_ctx { + ollama_client.set_num_ctx(Some(ctx)); + } + } + + let chat_backend: &dyn LlmClient = if let Some(ref c) = openrouter_client { + c + } else { + &ollama_client + }; + let model_used = chat_backend.primary_model().to_string(); + + // Tool set. + let local_first_user_has_image = messages + .iter() + .find(|m| m.role == "user") + .and_then(|m| m.images.as_ref()) + .map(|imgs| !imgs.is_empty()) + .unwrap_or(false); + let offer_describe_tool = !is_hybrid && local_first_user_has_image; + let tools = InsightGenerator::build_tool_definitions(offer_describe_tool); + + let image_base64: Option = if offer_describe_tool { + self.generator.load_image_as_base64(&normalized).ok() + } else { + None + }; + + // Truncate before appending the new user turn. + let budget_tokens = (req.num_ctx.unwrap_or(DEFAULT_NUM_CTX) as usize) + .saturating_sub(RESPONSE_HEADROOM_TOKENS); + let budget_bytes = budget_tokens.saturating_mul(BYTES_PER_TOKEN); + let truncated = apply_context_budget(&mut messages, budget_bytes); + if truncated { + let _ = tx.send(ChatStreamEvent::Truncated).await; + } + + messages.push(ChatMessage::user(req.user_message.clone())); + + let original_system_content = annotate_system_with_budget(&mut messages, max_iterations); + + let mut tool_calls_made = 0usize; + let mut iterations_used = 0usize; + let mut last_prompt_eval_count: Option = None; + let mut last_eval_count: Option = None; + let mut final_content = String::new(); + + for iteration in 0..max_iterations { + iterations_used = iteration + 1; + let _ = tx + .send(ChatStreamEvent::IterationStart { + n: iterations_used, + max: max_iterations, + }) + .await; + + let mut stream = chat_backend + .chat_with_tools_stream(messages.clone(), tools.clone()) + .await?; + + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let mut response = + final_message.ok_or_else(|| anyhow!("stream ended without a Done event"))?; + + // Normalize non-object tool arguments (same as non-streaming path). + if let Some(ref mut tcs) = response.tool_calls { + for tc in tcs.iter_mut() { + if !tc.function.arguments.is_object() { + tc.function.arguments = serde_json::Value::Object(Default::default()); + } + } + } + + messages.push(response.clone()); + + if let Some(ref tool_calls) = response.tool_calls + && !tool_calls.is_empty() + { + for (i, tool_call) in tool_calls.iter().enumerate() { + tool_calls_made += 1; + let call_index = tool_calls_made - 1; + let _ = tx + .send(ChatStreamEvent::ToolCall { + index: call_index, + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }) + .await; + let cx = opentelemetry::Context::new(); + let result = self + .generator + .execute_tool( + &tool_call.function.name, + &tool_call.function.arguments, + &ollama_client, + &image_base64, + &normalized, + &cx, + ) + .await; + let (result_preview, result_truncated) = truncate_tool_result(&result); + let _ = tx + .send(ChatStreamEvent::ToolResult { + index: call_index, + name: tool_call.function.name.clone(), + result: result_preview, + result_truncated, + }) + .await; + messages.push(ChatMessage::tool_result(result)); + let _ = i; // reserved for per-call ordering if needed + } + continue; + } + + final_content = response.content; + break; + } + + if final_content.is_empty() { + messages.push(ChatMessage::user( + "Please write your final answer now without calling any more tools.", + )); + let mut stream = chat_backend + .chat_with_tools_stream(messages.clone(), vec![]) + .await?; + let mut final_message: Option = None; + while let Some(ev) = stream.next().await { + let ev = ev?; + match ev { + LlmStreamEvent::TextDelta(delta) => { + let _ = tx.send(ChatStreamEvent::TextDelta(delta)).await; + } + LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + } => { + last_prompt_eval_count = prompt_eval_count; + last_eval_count = eval_count; + final_message = Some(message); + break; + } + } + } + let final_response = + final_message.ok_or_else(|| anyhow!("final stream ended without a Done event"))?; + final_content = final_response.content.clone(); + messages.push(final_response); + } + + // Drop the per-turn iteration-budget note from the system message + // before we persist so it doesn't snowball on each subsequent turn. + restore_system_content(&mut messages, original_system_content); + + // Persist. + let json = serde_json::to_string(&messages) + .map_err(|e| anyhow!("failed to serialize chat history: {}", e))?; + + let mut amended_insight_id: Option = None; + if req.amend { + let title_prompt = format!( + "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\n\ + Capture the key moment or theme. Return ONLY the title, nothing else.", + final_content + ); + let title_raw = chat_backend + .generate( + &title_prompt, + Some( + "You are my long term memory assistant. Use only the information provided. Do not invent details.", + ), + None, + ) + .await?; + let title = title_raw.trim().trim_matches('"').to_string(); + + // Amended rows intentionally do not inherit the parent's + // `fewshot_source_ids`. The parent's few-shot influence is still + // present in this row's content; if you want strict lineage + // tracking for training-set filtering, fetch the parent here and + // copy its value forward. + let new_row = InsertPhotoInsight { + library_id: req.library_id, + file_path: normalized.clone(), + title, + summary: final_content.clone(), + generated_at: Utc::now().timestamp(), + model_version: model_used.clone(), + is_current: true, + training_messages: Some(json), + backend: effective_backend.clone(), + fewshot_source_ids: None, + }; + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + let stored = dao + .store_insight(&cx, new_row) + .map_err(|e| anyhow!("failed to store amended insight: {:?}", e))?; + amended_insight_id = Some(stored.id); + } else { + let cx = opentelemetry::Context::new(); + let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); + dao.update_training_messages(&cx, req.library_id, &normalized, &json) + .map_err(|e| anyhow!("failed to persist chat history: {:?}", e))?; + } + + let _ = tx + .send(ChatStreamEvent::Done { + tool_calls_made, + iterations_used, + truncated, + prompt_eval_count: last_prompt_eval_count, + eval_count: last_eval_count, + amended_insight_id, + backend_used: effective_backend, + model_used, + }) + .await; + + Ok(()) + } +} + +/// Events emitted by `chat_turn_stream`. One stream per turn; ends after +/// `Done` or `Error`. +#[derive(Debug, Clone)] +pub enum ChatStreamEvent { + /// Starting iteration `n` of up to `max` (1-based). + IterationStart { n: usize, max: usize }, + /// History was trimmed to fit the context budget before the turn ran. + /// Emitted at most once, before any tool or text events. + Truncated, + /// Incremental content from the final assistant reply. Concatenate to + /// reconstruct the reply body. Tool-dispatch turns don't produce these. + TextDelta(String), + /// The model requested this tool call. Emitted just before execution. + /// `index` is a monotonically-increasing counter across the turn so the + /// client can pair `ToolCall` with its matching `ToolResult`. + ToolCall { + index: usize, + name: String, + arguments: serde_json::Value, + }, + /// The tool finished; `result` is the (possibly truncated) output. + ToolResult { + index: usize, + name: String, + result: String, + result_truncated: bool, + }, + /// Terminal success event with counters + persistence result. + Done { + tool_calls_made: usize, + iterations_used: usize, + truncated: bool, + prompt_eval_count: Option, + eval_count: Option, + amended_insight_id: Option, + backend_used: String, + model_used: String, + }, + /// Terminal failure event. No further events follow. + Error(String), +} + +/// Is this raw message visible in the rendered transcript? Must match +/// `load_history`'s filter exactly — `find_raw_cut` depends on it to map +/// rendered indices back to raw positions. +fn is_rendered(m: &ChatMessage) -> bool { + match m.role.as_str() { + "user" => true, + "assistant" => { + let has_tool_calls = m + .tool_calls + .as_ref() + .map(|c| !c.is_empty()) + .unwrap_or(false); + !(has_tool_calls && m.content.trim().is_empty()) + } + _ => false, + } +} + +/// Given a rendered index to start discarding from, find the raw index at +/// which to truncate. The cut position is the raw length after all prior +/// rendered messages — which also strips any tool-call scaffolding that +/// immediately precedes the discarded rendered message. +/// +/// Discarding *at* the end (`discard == rendered_count`) is a no-op success: +/// returns `Some(messages.len())`. The mobile client hits this when +/// regenerating after a failed turn — its optimistic user bubble lives at +/// the index just past the server's persisted history. Strictly past the end +/// (`discard > rendered_count`) returns `None`. +pub(crate) fn find_raw_cut( + messages: &[ChatMessage], + discard_from_rendered_index: usize, +) -> Option { + let mut rendered_count = 0usize; + let mut last_kept_raw_end = 0usize; + for (i, m) in messages.iter().enumerate() { + if !is_rendered(m) { + continue; + } + if rendered_count == discard_from_rendered_index { + return Some(last_kept_raw_end); + } + rendered_count += 1; + last_kept_raw_end = i + 1; + } + if discard_from_rendered_index == rendered_count { + return Some(messages.len()); + } + None +} + +/// Read AGENTIC_CHAT_MAX_ITERATIONS once per call. Cheap; keeps the code +/// free of static globals and lets the operator change the cap by env without +/// a restart in test harnesses (the running server still caches via Default). +fn env_max_iterations() -> usize { + std::env::var("AGENTIC_CHAT_MAX_ITERATIONS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(DEFAULT_MAX_ITERATIONS) + .max(1) +} + +/// Append a per-turn iteration-budget reminder to the replayed system +/// message so the model knows how many tool-calling rounds this turn gets. +/// Returns the original `content` so the caller can restore it before +/// persistence — otherwise the note would accumulate across turns. +/// +/// No-op (returns `None`) when `messages` has no leading system message. +fn annotate_system_with_budget( + messages: &mut [ChatMessage], + max_iterations: usize, +) -> Option { + let first = messages.first_mut()?; + if first.role != "system" { + return None; + } + let original = first.content.clone(); + first.content = format!( + "{}\n\n(Budget for this chat turn: up to {} tool-calling iterations. Produce your final reply before the budget is exhausted.)", + first.content, max_iterations + ); + Some(original) +} + +/// Restore a system-message content previously captured by +/// [`annotate_system_with_budget`]. No-op when `original` is `None` or the +/// first message isn't a system message. +fn restore_system_content(messages: &mut [ChatMessage], original: Option) { + let Some(original) = original else { return }; + if let Some(first) = messages.first_mut() + && first.role == "system" + { + first.content = original; + } +} + +/// View returned to clients for chat-UI rendering. +#[derive(Debug)] +pub struct HistoryView { + pub messages: Vec, + pub turn_count: usize, + pub model_version: String, + pub backend: String, +} + +#[derive(Debug)] +pub struct RenderedMessage { + pub role: String, + pub content: String, + pub is_initial: bool, + /// Tools invoked during this turn (only populated for assistant replies). + /// Empty for user messages and for assistant replies that didn't involve + /// tool calls. + pub tools: Vec, +} + +#[derive(Debug, Clone)] +pub struct ToolInvocation { + pub name: String, + pub arguments: serde_json::Value, + pub result: String, + /// True when `result` was trimmed for payload size. Full value remains + /// available in the raw training_messages blob. + pub result_truncated: bool, +} + +/// Soft cap for tool-result bodies returned via the history API. Keeps +/// payloads small for the mobile client — verbose SMS / geocoding responses +/// don't need to ship in full for inspection. +const TOOL_RESULT_PREVIEW_MAX: usize = 2000; + +fn truncate_tool_result(s: &str) -> (String, bool) { + if s.len() <= TOOL_RESULT_PREVIEW_MAX { + (s.to_string(), false) + } else { + // Cut on a char boundary. + let mut cut = TOOL_RESULT_PREVIEW_MAX; + while !s.is_char_boundary(cut) && cut > 0 { + cut -= 1; + } + (s[..cut].to_string(), true) + } +} + +/// Trim history to fit within `budget_bytes` of serialized JSON. Preserves +/// the system message and the first user message (with its base64 images +/// intact, since dropping those would invalidate the model's prior visual +/// reasoning). Drops the oldest assistant-tool_call + corresponding +/// tool-result pair on each pass until the budget is met or only the +/// preserved prefix remains. +/// +/// Returns true when at least one message was dropped. +pub(crate) fn apply_context_budget(messages: &mut Vec, budget_bytes: usize) -> bool { + if budget_bytes == 0 { + return false; + } + if estimate_bytes(messages) <= budget_bytes { + return false; + } + + // Find the index past the protected prefix: system messages + the first + // user message. Everything after is droppable in pairs. + let first_user_idx = messages.iter().position(|m| m.role == "user"); + let preserve_through = match first_user_idx { + Some(i) => i, // keep [0..=i] + None => return false, + }; + + let mut dropped_any = false; + loop { + if estimate_bytes(messages) <= budget_bytes { + break; + } + // Find the oldest assistant-with-tool_calls strictly after the + // preserved prefix. Drop it together with the following tool turn(s) + // until we hit the next assistant or user turn. + let drop_start = (preserve_through + 1..messages.len()).find(|&i| { + let m = &messages[i]; + m.role == "assistant" + && m.tool_calls + .as_ref() + .map(|c| !c.is_empty()) + .unwrap_or(false) + }); + let Some(start) = drop_start else { break }; + // Determine end: drop the assistant turn plus any contiguous tool + // result turns that follow. + let mut end = start + 1; + while end < messages.len() && messages[end].role == "tool" { + end += 1; + } + // Stop if dropping these would leave the just-appended user turn at + // the end alone with no preceding context — we still want it kept. + if end > messages.len() { + break; + } + messages.drain(start..end); + dropped_any = true; + } + + dropped_any +} + +fn estimate_bytes(messages: &[ChatMessage]) -> usize { + serde_json::to_string(messages) + .map(|s| s.len()) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::llm_client::{ToolCall, ToolCallFunction}; + + fn assistant_with_tool_call(name: &str) -> ChatMessage { + ChatMessage { + role: "assistant".to_string(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: None, + function: ToolCallFunction { + name: name.to_string(), + arguments: serde_json::Value::Object(Default::default()), + }, + }]), + images: None, + } + } + + fn assistant_text(text: &str) -> ChatMessage { + ChatMessage { + role: "assistant".to_string(), + content: text.to_string(), + tool_calls: None, + images: None, + } + } + + #[test] + fn truncation_preserves_system_and_first_user() { + let mut msgs = vec![ + ChatMessage::system("sys"), + ChatMessage::user("first user with lots of context".repeat(50)), + assistant_with_tool_call("get_x"), + ChatMessage::tool_result("x result ".repeat(200)), + assistant_with_tool_call("get_y"), + ChatMessage::tool_result("y result ".repeat(200)), + assistant_text("final answer"), + ]; + let original_len = msgs.len(); + let dropped = apply_context_budget(&mut msgs, 500); + assert!(dropped, "should drop something at this small budget"); + assert!(msgs.len() < original_len); + // First two messages preserved. + assert_eq!(msgs[0].role, "system"); + assert_eq!(msgs[1].role, "user"); + } + + #[test] + fn truncation_no_op_when_under_budget() { + let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")]; + let dropped = apply_context_budget(&mut msgs, 1_000_000); + assert!(!dropped); + assert_eq!(msgs.len(), 2); + } + + #[test] + fn truncation_returns_false_with_no_droppable_pairs() { + // Only system + user, no tool-call turns to drop. + let mut msgs = vec![ChatMessage::system("s"), ChatMessage::user("u")]; + let dropped = apply_context_budget(&mut msgs, 1); + assert!(!dropped); + } + + #[test] + fn rewind_strips_assistant_and_tool_scaffolding() { + // Rendered: [user1, asst1, user2, asst2] → cut at rendered index 3 + // (the final asst2) should drop the tool-call scaffolding + asst2, + // leaving raw up through user2. + let msgs = vec![ + ChatMessage::system("sys"), + ChatMessage::user("q1"), + assistant_text("a1"), + ChatMessage::user("q2"), + assistant_with_tool_call("lookup"), + ChatMessage::tool_result("data"), + assistant_text("a2 final"), + ]; + let cut = find_raw_cut(&msgs, 3).expect("cut found"); + // raw[0..cut] should end at user("q2") — indices 0..=3. + assert_eq!(cut, 4); + assert_eq!(msgs[cut - 1].role, "user"); + assert_eq!(msgs[cut - 1].content, "q2"); + } + + #[test] + fn rewind_at_second_rendered_cuts_after_first_user() { + // Rendered index 1 = the first assistant reply → dropping it should + // leave just the initial user message. + let msgs = vec![ + ChatMessage::system("s"), + ChatMessage::user("q1"), + assistant_with_tool_call("tool"), + ChatMessage::tool_result("r"), + assistant_text("a1"), + ]; + let cut = find_raw_cut(&msgs, 1).expect("cut found"); + assert_eq!(cut, 2); // sys + user("q1") + } + + #[test] + fn rewind_beyond_range_returns_none() { + let msgs = vec![ChatMessage::user("q1"), assistant_text("a1")]; + assert!(find_raw_cut(&msgs, 5).is_none()); + } + + #[test] + fn rewind_at_end_is_noop_success() { + // Mobile client retries after a failed turn that never persisted — + // its optimistic user bubble's index equals the server's rendered + // count. Should resolve to "no cut" rather than an out-of-range error. + let msgs = vec![ + ChatMessage::system("s"), + ChatMessage::user("q1"), + assistant_text("a1"), + ]; + let cut = find_raw_cut(&msgs, 2).expect("boundary cut should succeed"); + assert_eq!(cut, msgs.len()); + } +} diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs index 18e50c7..44a2a4e 100644 --- a/src/ai/insight_generator.rs +++ b/src/ai/insight_generator.rs @@ -9,8 +9,11 @@ use std::fs::File; use std::io::Cursor; use std::sync::{Arc, Mutex}; +use crate::ai::llm_client::LlmClient; use crate::ai::ollama::{ChatMessage, OllamaClient, Tool}; +use crate::ai::openrouter::OpenRouterClient; use crate::ai::sms_client::SmsApiClient; +use crate::ai::user_display_name; use crate::database::models::InsertPhotoInsight; use crate::database::{ CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, KnowledgeDao, LocationHistoryDao, @@ -20,7 +23,7 @@ use crate::libraries::Library; use crate::memories::extract_date_from_filename; use crate::otel::global_tracer; use crate::tags::TagDao; -use crate::utils::normalize_path; +use crate::utils::{earliest_fs_time, normalize_path}; #[derive(Deserialize)] struct NominatimResponse { @@ -39,6 +42,9 @@ struct NominatimAddress { #[derive(Clone)] pub struct InsightGenerator { ollama: OllamaClient, + /// Optional OpenRouter client, used when `backend=hybrid` is requested. + /// `None` when `OPENROUTER_API_KEY` is not configured. + openrouter: Option>, sms_client: SmsApiClient, insight_dao: Arc>>, exif_dao: Arc>>, @@ -59,6 +65,7 @@ pub struct InsightGenerator { impl InsightGenerator { pub fn new( ollama: OllamaClient, + openrouter: Option>, sms_client: SmsApiClient, insight_dao: Arc>>, exif_dao: Arc>>, @@ -72,6 +79,7 @@ impl InsightGenerator { ) -> Self { Self { ollama, + openrouter, sms_client, insight_dao, exif_dao, @@ -89,7 +97,7 @@ impl InsightGenerator { /// first root under which the file exists. Insights may be generated /// for any library — the generator itself doesn't know which — so we /// probe each root rather than trust a single `base_path`. - fn resolve_full_path(&self, rel_path: &str) -> Option { + pub(crate) fn resolve_full_path(&self, rel_path: &str) -> Option { use std::path::Path; for lib in &self.libraries { let candidate = Path::new(&lib.root_path).join(rel_path); @@ -122,7 +130,7 @@ impl InsightGenerator { /// Load image file, resize it, and encode as base64 for vision models /// Resizes to max 1024px on longest edge to reduce context usage - fn load_image_as_base64(&self, file_path: &str) -> Result { + pub(crate) fn load_image_as_base64(&self, file_path: &str) -> Result { use image::imageops::FilterType; let full_path = self.resolve_full_path(file_path).ok_or_else(|| { @@ -752,8 +760,6 @@ impl InsightGenerator { let full_path = self.resolve_full_path(&file_path)?; File::open(&full_path) .and_then(|f| f.metadata()) - .and_then(|m| m.created().or(m.modified())) - .map(|t| DateTime::::from(t).timestamp()) .inspect_err(|e| { log::warn!( "Failed to get file timestamp for insight {}: {}", @@ -762,6 +768,8 @@ impl InsightGenerator { ) }) .ok() + .and_then(|m| earliest_fs_time(&m)) + .map(|t| DateTime::::from(t).timestamp()) }) .unwrap_or_else(|| Utc::now().timestamp()) }; @@ -1218,6 +1226,8 @@ impl InsightGenerator { model_version: ollama_client.primary_model.clone(), is_current: true, training_messages: None, + backend: "local".to_string(), + fewshot_source_ids: None, }; let mut dao = self.insight_dao.lock().expect("Unable to lock InsightDao"); @@ -1252,10 +1262,14 @@ impl InsightGenerator { // Format a sample of messages for topic extraction let sample_size = messages.len().min(20); + let user_name = user_display_name(); let sample_text: Vec = messages .iter() .take(sample_size) - .map(|m| format!("{}: {}", if m.is_sent { "Me" } else { &m.contact }, m.body)) + .map(|m| { + let sender: &str = if m.is_sent { &user_name } else { &m.contact }; + format!("{}: {}", sender, m.body) + }) .collect(); let prompt = format!( @@ -1353,10 +1367,11 @@ Return ONLY the summary, nothing else."#, } // Format messages + let user_name = user_display_name(); let formatted: Vec = messages .iter() .map(|m| { - let sender = if m.is_sent { "Me" } else { &m.contact }; + let sender: &str = if m.is_sent { &user_name } else { &m.contact }; let timestamp = chrono::DateTime::from_timestamp(m.timestamp, 0) .map(|dt| { dt.with_timezone(&Local) @@ -1403,7 +1418,7 @@ Return ONLY the summary, nothing else."#, // ── Tool executors for agentic loop ──────────────────────────────── /// Dispatch a tool call to the appropriate executor - async fn execute_tool( + pub(crate) async fn execute_tool( &self, tool_name: &str, arguments: &serde_json::Value, @@ -1413,7 +1428,8 @@ Return ONLY the summary, nothing else."#, cx: &opentelemetry::Context, ) -> String { let result = match tool_name { - "search_rag" => self.tool_search_rag(arguments, cx).await, + "search_rag" => self.tool_search_rag(arguments, ollama, cx).await, + "search_messages" => self.tool_search_messages(arguments).await, "get_sms_messages" => self.tool_get_sms_messages(arguments, cx).await, "get_calendar_events" => self.tool_get_calendar_events(arguments, cx).await, "get_location_history" => self.tool_get_location_history(arguments, cx).await, @@ -1439,6 +1455,7 @@ Return ONLY the summary, nothing else."#, async fn tool_search_rag( &self, args: &serde_json::Value, + ollama: &OllamaClient, _cx: &opentelemetry::Context, ) -> String { let query = match args.get("query").and_then(|v| v.as_str()) { @@ -1471,13 +1488,267 @@ Return ONLY the summary, nothing else."#, limit ); - match self - .find_relevant_messages_rag(date, None, contact.as_deref(), None, limit, Some(&query)) + // Pull a wider candidate pool than the final limit so the LLM + // reranker has room to promote less-obvious hits. Candidates_factor + // is capped so a big `limit` doesn't blow past what the reranker + // can sensibly judge in one prompt. + let rerank_enabled = std::env::var("SEARCH_RAG_RERANK") + .ok() + .map(|v| v.to_lowercase() != "off" && v != "0") + .unwrap_or(true); + let candidate_limit = if rerank_enabled { + (limit * 3).min(40) + } else { + limit + }; + + let results = match self + .find_relevant_messages_rag( + date, + None, + contact.as_deref(), + None, + candidate_limit, + Some(&query), + ) .await { - Ok(results) if !results.is_empty() => results.join("\n\n"), - Ok(_) => "No relevant messages found.".to_string(), - Err(e) => format!("Error searching RAG: {}", e), + Ok(results) if !results.is_empty() => results, + Ok(_) => return "No relevant messages found.".to_string(), + Err(e) => return format!("Error searching RAG: {}", e), + }; + + let final_results = if rerank_enabled && results.len() > limit { + match self.rerank_with_llm(&query, &results, limit, ollama).await { + Ok(reordered) => reordered, + Err(e) => { + log::warn!("rerank failed, using vector order: {}", e); + results.into_iter().take(limit).collect() + } + } + } else { + results.into_iter().take(limit).collect::>() + }; + + final_results.join("\n\n") + } + + /// LLM-based reranker: ask the local model to pick the top-`limit` + /// passages from `candidates` that are most relevant to `query`. + /// Returns the reordered subset. + /// + /// Cheap-ish because the reranker prompt and output live outside the + /// agent's visible context — only the final selection lands in the + /// tool_result. On parse failure we fall back to the input order. + async fn rerank_with_llm( + &self, + query: &str, + candidates: &[String], + limit: usize, + ollama: &OllamaClient, + ) -> Result> { + let query_preview: String = query.chars().take(60).collect(); + log::info!( + "rerank: {} candidates -> top {} (query=\"{}\")", + candidates.len(), + limit, + query_preview + ); + + // Build numbered list (1-based for readability). Cap each passage + // at ~1000 chars so very long summaries don't eat the prompt. + let numbered: String = candidates + .iter() + .enumerate() + .map(|(i, c)| { + let trimmed = if c.len() > 1000 { + format!("{}…", &c[..1000]) + } else { + c.clone() + }; + format!("[{}] {}", i + 1, trimmed) + }) + .collect::>() + .join("\n\n"); + + let prompt = format!( + "You are ranking search results. From the numbered passages below, \ + select the {} most relevant to the query. Respond with ONLY a \ + comma-separated list of passage numbers in order from most to \ + least relevant. No explanation, no other text.\n\n\ + Query: {}\n\n\ + Passages:\n{}\n\n\ + Top {} passage numbers:", + limit, query, numbered, limit + ); + + let started = std::time::Instant::now(); + let response = ollama + .generate_no_think( + &prompt, + Some( + "You are a terse relevance ranker. You output only numbers separated by commas.", + ), + ) + .await?; + log::info!( + "rerank: finished in {} ms (prompt={} chars)", + started.elapsed().as_millis(), + prompt.len() + ); + + // Extract indices from the response. Accept "3, 1, 7" and also + // tolerate "[3, 1, 7]" or "3,1,7,..." with trailing junk. + let picks: Vec = response + .split(|c: char| !c.is_ascii_digit()) + .filter_map(|s| s.parse::().ok()) + .filter(|&n| n >= 1 && n <= candidates.len()) + .collect(); + + if picks.is_empty() { + return Err(anyhow::anyhow!( + "reranker returned no usable indices (raw: {})", + response.chars().take(120).collect::() + )); + } + + let mut seen = std::collections::HashSet::new(); + let mut reordered: Vec = Vec::with_capacity(limit); + let mut final_indices: Vec = Vec::with_capacity(limit); + for n in picks { + if seen.insert(n) { + reordered.push(candidates[n - 1].clone()); + final_indices.push(n); + if reordered.len() >= limit { + break; + } + } + } + // Top-up from original order if the reranker returned fewer than + // `limit` distinct entries. + if reordered.len() < limit { + for (i, c) in candidates.iter().enumerate() { + if !seen.contains(&(i + 1)) { + reordered.push(c.clone()); + final_indices.push(i + 1); + if reordered.len() >= limit { + break; + } + } + } + } + + // Debug snapshot: show what the reranker changed. Position p holds + // the 1-based index of the candidate that now sits at position p. + // A value that equals its position means "no change at that slot". + let swapped = final_indices + .iter() + .enumerate() + .filter(|(pos, idx)| **idx != pos + 1) + .count(); + log::info!( + "rerank: final indices (1-based): {:?} — {} of top {} swapped from vector order", + final_indices, + swapped, + final_indices.len() + ); + let show = final_indices.len().min(5); + log::debug!("rerank: vector-order top {}:", show); + for (i, c) in candidates.iter().enumerate().take(show) { + let preview: String = c.chars().take(100).collect(); + log::debug!("rerank: [{}] {}", i + 1, preview); + } + log::debug!("rerank: reranked top {}:", show); + for (pos, idx) in final_indices.iter().enumerate().take(show) { + let preview: String = candidates[*idx - 1].chars().take(100).collect(); + log::debug!("rerank: [{}] (orig #{}) {}", pos + 1, idx, preview); + } + + Ok(reordered) + } + + /// Tool: search_messages — keyword / semantic / hybrid search over all + /// SMS message bodies via the Django FTS5 + embeddings pipeline. Unlike + /// `search_rag` (daily summaries, date-weighted) this hits raw message + /// text across time and is the right choice for exact phrases, proper + /// nouns, URLs, or anything where specific wording matters. + async fn tool_search_messages(&self, args: &serde_json::Value) -> String { + let query = match args.get("query").and_then(|v| v.as_str()) { + Some(q) if !q.trim().is_empty() => q.trim(), + _ => { + // Redirect when the model reached for this tool with a + // date/contact-shaped intent — get_sms_messages is the right + // call. Without this hint, small models often just retry + // search_messages again with the same args. + let has_date = args.get("date").is_some(); + let has_contact = args.get("contact").is_some(); + if has_date || has_contact { + return "Error: search_messages needs a 'query' (keywords/phrase). \ + To fetch messages around a date or from a contact, call \ + get_sms_messages with { date, contact? } instead." + .to_string(); + } + return "Error: missing required parameter 'query'".to_string(); + } + }; + if query.len() < 3 { + return "Error: query must be at least 3 characters".to_string(); + } + let mode = args + .get("mode") + .and_then(|v| v.as_str()) + .map(|s| s.to_lowercase()) + .unwrap_or_else(|| "hybrid".to_string()); + if !matches!(mode.as_str(), "fts5" | "semantic" | "hybrid") { + return format!( + "Error: unknown mode '{}'; expected one of: fts5, semantic, hybrid", + mode + ); + } + let limit = args + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(20) + .clamp(1, 50) as usize; + + log::info!( + "tool_search_messages: query='{}', mode={}, limit={}", + query, + mode, + limit + ); + + match self.sms_client.search_messages(query, &mode, limit).await { + Ok(hits) if hits.is_empty() => "No messages matched.".to_string(), + Ok(hits) => { + let mut out = String::new(); + out.push_str(&format!( + "Found {} messages (mode: {}):\n\n", + hits.len(), + mode + )); + let user_name = user_display_name(); + for h in hits { + let date = chrono::DateTime::from_timestamp(h.date, 0) + .map(|dt| dt.format("%Y-%m-%d").to_string()) + .unwrap_or_else(|| h.date.to_string()); + let direction: &str = if h.type_ == 2 { + &user_name + } else { + &h.contact_name + }; + let score = h + .similarity_score + .map(|s| format!(" [score {:.2}]", s)) + .unwrap_or_default(); + out.push_str(&format!( + "[{}]{} {} — {}\n\n", + date, score, direction, h.body + )); + } + out + } + Err(e) => format!("Error searching messages: {}", e), } } @@ -1525,11 +1796,12 @@ Return ONLY the summary, nothing else."#, .await { Ok(messages) if !messages.is_empty() => { + let user_name = user_display_name(); let formatted: Vec = messages .iter() .take(limit) .map(|m| { - let sender = if m.is_sent { "Me" } else { &m.contact }; + let sender: &str = if m.is_sent { &user_name } else { &m.contact }; let ts = DateTime::from_timestamp(m.timestamp, 0) .map(|dt| { dt.with_timezone(&Local) @@ -2128,7 +2400,7 @@ Return ONLY the summary, nothing else."#, // ── Agentic insight generation ────────────────────────────────────── /// Build the list of tool definitions for the agentic loop - fn build_tool_definitions(has_vision: bool) -> Vec { + pub(crate) fn build_tool_definitions(has_vision: bool) -> Vec { let mut tools = vec![ Tool::function( "search_rag", @@ -2156,9 +2428,32 @@ Return ONLY the summary, nothing else."#, } }), ), + Tool::function( + "search_messages", + "CONTENT search over SMS message bodies by keywords/phrases/topics across all time. Use when you're looking for specific wording (phrases, proper nouns, URLs, topics) and DON'T have a date in mind. NOT for time-based queries — if you know the date or want messages around a date, call get_sms_messages instead. Modes: 'fts5' (keyword, supports \"phrase\" / prefix* / AND / NEAR(w1 w2, 5)), 'semantic' (embedding similarity), 'hybrid' (recommended — merges both via reciprocal rank fusion).", + serde_json::json!({ + "type": "object", + "required": ["query"], + "properties": { + "query": { + "type": "string", + "description": "Search query. Min 3 chars. For fts5 mode, supports phrase (\"\"), prefix (*), AND/OR/NOT, and NEAR proximity." + }, + "mode": { + "type": "string", + "enum": ["fts5", "semantic", "hybrid"], + "description": "Search strategy. Default: hybrid." + }, + "limit": { + "type": "integer", + "description": "Maximum number of results (default: 20, max: 50)" + } + } + }), + ), Tool::function( "get_sms_messages", - "Fetch SMS/text messages near a specific date. Returns the actual message conversation. Omit contact to search across all conversations.", + "TIME-BASED fetch of SMS/text messages around a specific date (and optionally from a specific contact). Returns the actual message conversation for that window. Use this whenever you know the date or want the context around a photo's timestamp. Omit contact to search across all conversations. For keyword/topic search without a date, use search_messages instead.", serde_json::json!({ "type": "object", "required": ["date"], @@ -2337,7 +2632,7 @@ Return ONLY the summary, nothing else."#, }, "object_entity_id": { "type": "integer", - "description": "Use when the object is a known entity (e.g. Cameron's entity ID for 'is_friend_of Cameron'). Takes precedence over object_value." + "description": "Use when the object is a known entity (e.g. another person's entity ID for 'is_friend_of '). Takes precedence over object_value." }, "object_value": { "type": "string", @@ -2376,6 +2671,167 @@ Return ONLY the summary, nothing else."#, /// Generate an AI insight for a photo using an agentic tool-calling loop. /// The model decides which tools to call to gather context before writing the final insight. + /// + /// `backend` selects the chat provider: `"local"` (default) routes the + /// agentic loop through the configured Ollama server with the image + /// attached to the first user message; `"hybrid"` asks the local Ollama + /// vision model to describe the image once, inlines the description as + /// text, and runs the loop through OpenRouter (chat only — embeddings + /// and describe calls stay local in either mode). + #[allow(clippy::too_many_arguments)] + /// Render a set of prior-conversation transcripts into a compact + /// trajectory block for inclusion in the system prompt. Tool results + /// are summarised to one line each so the prompt stays small. + fn render_fewshot_examples(examples: &[Vec]) -> String { + if examples.is_empty() { + return String::new(); + } + + let mut out = String::from("## Examples of strong context-gathering\n\n"); + out.push_str( + "The following are compressed trajectories from prior high-quality insights. \ + They show the *pattern* of tool use, not answers to copy.\n\n", + ); + + for (i, msgs) in examples.iter().enumerate() { + out.push_str(&format!("### Example {}\n\n", i + 1)); + out.push_str(&Self::render_single_trajectory(msgs)); + out.push('\n'); + } + + out.push_str("---\n\n"); + out + } + + fn render_single_trajectory(msgs: &[ChatMessage]) -> String { + let mut out = String::new(); + + if let Some(first_user) = msgs.iter().find(|m| m.role == "user") { + let trimmed = first_user + .content + .lines() + .filter(|l| !l.trim().is_empty()) + .take(8) + .collect::>() + .join("\n"); + out.push_str(&format!("Input:\n{}\n\n", trimmed)); + } + + out.push_str("Trajectory:\n"); + let mut step = 1; + let mut final_content: Option = None; + + for (i, m) in msgs.iter().enumerate() { + if m.role != "assistant" { + continue; + } + if let Some(ref calls) = m.tool_calls { + for call in calls { + let args_brief = Self::brief_json_args(&call.function.arguments); + let result_summary = msgs + .get(i + 1) + .filter(|r| r.role == "tool") + .map(|r| Self::summarize_tool_result(&call.function.name, &r.content)) + .unwrap_or_else(|| "(no result)".to_string()); + out.push_str(&format!( + "{}. {}({}) -> {}\n", + step, call.function.name, args_brief, result_summary + )); + step += 1; + } + } else if !m.content.is_empty() { + final_content = Some(m.content.clone()); + } + } + + if let Some(content) = final_content { + let short: String = content.chars().take(240).collect(); + out.push_str(&format!("\nFinal insight: {}...\n", short)); + } + + out + } + + fn brief_json_args(v: &serde_json::Value) -> String { + let Some(obj) = v.as_object() else { + return v.to_string(); + }; + obj.iter() + .map(|(k, v)| { + let rendered = match v { + serde_json::Value::String(s) if s.len() > 40 => { + format!("\"{}...\"", &s[..40]) + } + _ => v.to_string(), + }; + format!("{}={}", k, rendered) + }) + .collect::>() + .join(", ") + } + + /// Collapse a raw tool-result string (the text the model saw) into a + /// short phrase suitable for a few-shot trajectory. Detects the + /// "Found N ...", "No ...", and "Error ..." idioms used by the tool + /// implementations in this file. Unknown shapes fall back to a char + /// count, which is deliberately visible so drift shows up in output. + fn summarize_tool_result(tool_name: &str, raw: &str) -> String { + if raw.starts_with("Error ") { + return "error".to_string(); + } + if raw.starts_with("No ") || raw.starts_with("Could not ") { + return "empty (pivoted)".to_string(); + } + + if let Some(rest) = raw.strip_prefix("Found ") + && let Some(n_str) = rest.split_whitespace().next() + && let Ok(n) = n_str.parse::() + { + let kind = match tool_name { + "search_messages" | "get_sms_messages" => "messages", + "get_calendar_events" => "events", + "get_location_history" => "location records", + _ => "results", + }; + return format!("{} {}", n, kind); + } + + match tool_name { + "search_rag" => { + let n = raw.split("\n\n").filter(|s| !s.trim().is_empty()).count(); + format!("{} rag hits", n) + } + "get_file_tags" => { + let n = raw.split(',').filter(|s| !s.trim().is_empty()).count(); + format!("{} tags", n) + } + "describe_photo" => { + let short: String = raw.chars().take(80).collect(); + format!("described: \"{}...\"", short) + } + "reverse_geocode" => { + let short: String = raw.chars().take(60).collect(); + format!("place: {}", short) + } + "recall_entities" | "recall_facts_for_photo" => { + let n = raw.lines().skip(1).filter(|l| !l.trim().is_empty()).count(); + let kind = if tool_name == "recall_entities" { + "entities" + } else { + "facts" + }; + format!("{} {}", n, kind) + } + "store_entity" | "store_fact" => raw + .split_whitespace() + .find_map(|tok| tok.strip_prefix("ID:")) + .map(|id| format!("stored id={}", id.trim_end_matches(','))) + .unwrap_or_else(|| "stored".to_string()), + "get_current_datetime" => "time noted".to_string(), + _ => format!("{} chars", raw.len()), + } + } + pub async fn generate_agentic_insight_for_photo( &self, file_path: &str, @@ -2387,6 +2843,9 @@ Return ONLY the summary, nothing else."#, top_k: Option, min_p: Option, max_iterations: usize, + backend: Option, + fewshot_examples: Vec>, + fewshot_source_ids: Vec, ) -> Result<(Option, Option)> { let tracer = global_tracer(); let current_cx = opentelemetry::Context::current(); @@ -2398,8 +2857,30 @@ Return ONLY the summary, nothing else."#, span.set_attribute(KeyValue::new("file_path", file_path.clone())); span.set_attribute(KeyValue::new("max_iterations", max_iterations as i64)); - // 1. Create OllamaClient - let mut ollama_client = if let Some(ref model) = custom_model { + // 1a. Resolve backend label (defaults to "local"). + let backend_label = backend + .as_deref() + .map(|s| s.trim().to_lowercase()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "local".to_string()); + if !matches!(backend_label.as_str(), "local" | "hybrid") { + return Err(anyhow::anyhow!( + "unknown backend '{}'; expected 'local' or 'hybrid'", + backend_label + )); + } + span.set_attribute(KeyValue::new("backend", backend_label.clone())); + let is_hybrid = backend_label == "hybrid"; + + // 1b. Always build an Ollama client. In local mode it owns the chat + // loop; in hybrid mode it still handles describe_image + any + // tool-local calls (e.g. if a future tool needs embeddings). + // Sampling overrides only apply in local mode — in hybrid the + // user's params belong to the OpenRouter chat client. + let apply_sampling_to_ollama = !is_hybrid; + let mut ollama_client = if let Some(ref model) = custom_model + && !is_hybrid + { log::info!("Using custom model for agentic: {}", model); span.set_attribute(KeyValue::new("custom_model", model.clone())); OllamaClient::new( @@ -2409,108 +2890,159 @@ Return ONLY the summary, nothing else."#, Some(model.clone()), ) } else { - span.set_attribute(KeyValue::new("model", self.ollama.primary_model.clone())); + if !is_hybrid { + span.set_attribute(KeyValue::new("model", self.ollama.primary_model.clone())); + } self.ollama.clone() }; - if let Some(ctx) = num_ctx { - log::info!("Using custom context size: {}", ctx); - span.set_attribute(KeyValue::new("num_ctx", ctx as i64)); - ollama_client.set_num_ctx(Some(ctx)); + if apply_sampling_to_ollama { + if let Some(ctx) = num_ctx { + log::info!("Using custom context size: {}", ctx); + span.set_attribute(KeyValue::new("num_ctx", ctx as i64)); + ollama_client.set_num_ctx(Some(ctx)); + } + + if temperature.is_some() || top_p.is_some() || top_k.is_some() || min_p.is_some() { + log::info!( + "Using sampling params — temperature: {:?}, top_p: {:?}, top_k: {:?}, min_p: {:?}", + temperature, + top_p, + top_k, + min_p + ); + if let Some(t) = temperature { + span.set_attribute(KeyValue::new("temperature", t as f64)); + } + if let Some(p) = top_p { + span.set_attribute(KeyValue::new("top_p", p as f64)); + } + if let Some(k) = top_k { + span.set_attribute(KeyValue::new("top_k", k as i64)); + } + if let Some(m) = min_p { + span.set_attribute(KeyValue::new("min_p", m as f64)); + } + ollama_client.set_sampling_params(temperature, top_p, top_k, min_p); + } } - if temperature.is_some() || top_p.is_some() || top_k.is_some() || min_p.is_some() { - log::info!( - "Using sampling params — temperature: {:?}, top_p: {:?}, top_k: {:?}, min_p: {:?}", - temperature, - top_p, - top_k, - min_p - ); - if let Some(t) = temperature { - span.set_attribute(KeyValue::new("temperature", t as f64)); + // 1c. In hybrid mode, clone the configured OpenRouter client and + // apply per-request overrides. + let openrouter_client: Option = if is_hybrid { + let arc = self.openrouter.as_ref().ok_or_else(|| { + anyhow::anyhow!("hybrid backend unavailable: OPENROUTER_API_KEY not configured") + })?; + let mut c: OpenRouterClient = (**arc).clone(); + if let Some(ref m) = custom_model { + c.primary_model = m.clone(); + span.set_attribute(KeyValue::new("custom_model", m.clone())); } - if let Some(p) = top_p { - span.set_attribute(KeyValue::new("top_p", p as f64)); + span.set_attribute(KeyValue::new("openrouter_model", c.primary_model.clone())); + if temperature.is_some() || top_p.is_some() || top_k.is_some() || min_p.is_some() { + if let Some(t) = temperature { + span.set_attribute(KeyValue::new("temperature", t as f64)); + } + if let Some(p) = top_p { + span.set_attribute(KeyValue::new("top_p", p as f64)); + } + if let Some(k) = top_k { + span.set_attribute(KeyValue::new("top_k", k as i64)); + } + if let Some(m) = min_p { + span.set_attribute(KeyValue::new("min_p", m as f64)); + } + c.set_sampling_params(temperature, top_p, top_k, min_p); } - if let Some(k) = top_k { - span.set_attribute(KeyValue::new("top_k", k as i64)); + if let Some(ctx) = num_ctx { + span.set_attribute(KeyValue::new("num_ctx", ctx as i64)); + c.set_num_ctx(Some(ctx)); } - if let Some(m) = min_p { - span.set_attribute(KeyValue::new("min_p", m as f64)); - } - ollama_client.set_sampling_params(temperature, top_p, top_k, min_p); - } + Some(c) + } else { + None + }; let insight_cx = current_cx.with_span(span); - // 2a. Verify the model exists on at least one server before checking capabilities - if let Some(ref model_name) = custom_model { - let available_on_primary = - OllamaClient::is_model_available(&ollama_client.primary_url, model_name) - .await - .unwrap_or(false); + // 2. Verify chat model supports tool calling. + // - local: existing Ollama model availability + capability check. + // - hybrid: trust the operator's curated allowlist + // (OPENROUTER_ALLOWED_MODELS) — no live precheck. A bad model id + // surfaces as a chat-call error on the next step. + let has_vision = if is_hybrid { + // In hybrid mode the chat model never sees images directly — we + // describe-then-inject, so `has_vision` drives only whether we + // bother loading the image to describe it, which we always do. + true + } else { + if let Some(ref model_name) = custom_model { + let available_on_primary = + OllamaClient::is_model_available(&ollama_client.primary_url, model_name) + .await + .unwrap_or(false); - let available_on_fallback = if let Some(ref fallback_url) = ollama_client.fallback_url { - OllamaClient::is_model_available(fallback_url, model_name) - .await - .unwrap_or(false) - } else { - false + let available_on_fallback = + if let Some(ref fallback_url) = ollama_client.fallback_url { + OllamaClient::is_model_available(fallback_url, model_name) + .await + .unwrap_or(false) + } else { + false + }; + + if !available_on_primary && !available_on_fallback { + anyhow::bail!( + "model not available: '{}' not found on any configured server", + model_name + ); + } + } + + let model_name_for_caps = &ollama_client.primary_model; + let capabilities = match OllamaClient::check_model_capabilities( + &ollama_client.primary_url, + model_name_for_caps, + ) + .await + { + Ok(caps) => caps, + Err(_) => { + let fallback_url = ollama_client.fallback_url.as_deref().ok_or_else(|| { + anyhow::anyhow!( + "Failed to check model capabilities for '{}': model not found on primary server and no fallback configured", + model_name_for_caps + ) + })?; + OllamaClient::check_model_capabilities(fallback_url, model_name_for_caps) + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to check model capabilities for '{}': {}", + model_name_for_caps, + e + ) + })? + } }; - if !available_on_primary && !available_on_fallback { - anyhow::bail!( - "model not available: '{}' not found on any configured server", - model_name - ); + if !capabilities.has_tool_calling { + return Err(anyhow::anyhow!( + "tool calling not supported by model '{}'", + ollama_client.primary_model + )); } - } - // 2b. Check tool calling capability — try primary, fall back to fallback URL - let model_name_for_caps = &ollama_client.primary_model; - let capabilities = match OllamaClient::check_model_capabilities( - &ollama_client.primary_url, - model_name_for_caps, - ) - .await - { - Ok(caps) => caps, - Err(_) => { - // Model may only be on the fallback server - let fallback_url = ollama_client.fallback_url.as_deref().ok_or_else(|| { - anyhow::anyhow!( - "Failed to check model capabilities for '{}': model not found on primary server and no fallback configured", - model_name_for_caps - ) - })?; - OllamaClient::check_model_capabilities(fallback_url, model_name_for_caps) - .await - .map_err(|e| { - anyhow::anyhow!( - "Failed to check model capabilities for '{}': {}", - model_name_for_caps, - e - ) - })? - } + insight_cx + .span() + .set_attribute(KeyValue::new("model_has_vision", capabilities.has_vision)); + insight_cx + .span() + .set_attribute(KeyValue::new("model_has_tool_calling", true)); + + capabilities.has_vision }; - if !capabilities.has_tool_calling { - return Err(anyhow::anyhow!( - "tool calling not supported by model '{}'", - ollama_client.primary_model - )); - } - - let has_vision = capabilities.has_vision; - insight_cx - .span() - .set_attribute(KeyValue::new("model_has_vision", has_vision)); - insight_cx - .span() - .set_attribute(KeyValue::new("model_has_tool_calling", true)); - // 3. Fetch EXIF let exif = { let mut exif_dao = self.exif_dao.lock().expect("Unable to lock ExifDao"); @@ -2530,8 +3062,6 @@ Return ONLY the summary, nothing else."#, let full_path = self.resolve_full_path(&file_path)?; File::open(&full_path) .and_then(|f| f.metadata()) - .and_then(|m| m.created().or(m.modified())) - .map(|t| DateTime::::from(t).timestamp()) .inspect_err(|e| { log::warn!( "Failed to get file timestamp for agentic insight {}: {}", @@ -2540,6 +3070,8 @@ Return ONLY the summary, nothing else."#, ) }) .ok() + .and_then(|m| earliest_fs_time(&m)) + .map(|t| DateTime::::from(t).timestamp()) }) .unwrap_or_else(|| Utc::now().timestamp()) }; @@ -2564,27 +3096,26 @@ Return ONLY the summary, nothing else."#, .collect() }; - // 6. Clear existing entity-photo links for this file so the run starts fresh, - // and ensure the owner entity (Cameron) exists so the agent can reference it. - let cameron_entity_id: Option = { + // 6. Ensure the owner entity exists so the agent can reference it. + // Prior entity_photo_links for this file are intentionally preserved + // across regenerations — clearing them made `recall_facts_for_photo` + // always return empty and discarded hard-won knowledge. Re-linking + // the same entity is a no-op (INSERT OR IGNORE). + let owner_name = user_display_name(); + let owner_entity_id: Option = { let mut kdao = self .knowledge_dao .lock() .expect("Unable to lock KnowledgeDao"); - if let Err(e) = kdao.delete_photo_links_for_file(&insight_cx, &file_path) { - log::warn!( - "Failed to clear entity_photo_links for {}: {:?}", - file_path, - e - ); - } - // Upsert the owner entity so the agent always has a stable entity ID to reference. let owner = crate::database::models::InsertEntity { - name: "Cameron".to_string(), + name: owner_name.clone(), entity_type: "person".to_string(), - description: "The owner of this photo collection. All memories are written from Cameron's perspective.".to_string(), + description: format!( + "The owner of this photo collection. All memories are written from {}'s perspective.", + owner_name + ), embedding: None, confidence: 1.0, status: "active".to_string(), @@ -2593,17 +3124,20 @@ Return ONLY the summary, nothing else."#, }; match kdao.upsert_entity(&insight_cx, owner) { Ok(e) => { - log::info!("Cameron entity ID: {}", e.id); + log::info!("Owner entity '{}' ID: {}", owner_name, e.id); Some(e.id) } Err(e) => { - log::warn!("Failed to upsert Cameron entity: {:?}", e); + log::warn!("Failed to upsert owner entity '{}': {:?}", owner_name, e); None } } }; - // 7. Load image if vision capable + // 7. Load image if vision capable. + // In hybrid mode we ALSO describe it locally now so the + // description can be inlined as text — the OpenRouter chat model + // never receives the base64 image directly. let image_base64 = if has_vision { match self.load_image_as_base64(&file_path) { Ok(b64) => { @@ -2619,29 +3153,60 @@ Return ONLY the summary, nothing else."#, None }; + let hybrid_visual_description: Option = if is_hybrid { + match image_base64.as_deref() { + Some(b64) => match self.ollama.describe_image(b64).await { + Ok(desc) => { + log::info!( + "Hybrid: local vision describe succeeded ({} chars)", + desc.len() + ); + Some(desc) + } + Err(e) => { + log::warn!( + "Hybrid: local vision describe failed, continuing without: {}", + e + ); + None + } + }, + None => None, + } + } else { + None + }; + // 8. Build system message - let cameron_id_note = match cameron_entity_id { + let owner_id_note = match owner_entity_id { Some(id) => format!( - "\n\nYour identity in the knowledge store: Cameron (entity ID: {}). \ - When storing facts where you (Cameron) are the object — for example, someone is your friend, \ + "\n\nYour identity in the knowledge store: {name} (entity ID: {id}). \ + When storing facts where you ({name}) are the object — for example, someone is your friend, \ sibling, or colleague — use subject_entity_id for the other person and set object_value to \ - \"Cameron\" (or use store_fact with the other person as subject). When storing facts about \ - Cameron directly, use {} as the subject_entity_id.", - id, id + \"{name}\" (or use store_fact with the other person as subject). When storing facts about \ + {name} directly, use {id} as the subject_entity_id.", + name = owner_name, + id = id ), None => String::new(), }; + let fewshot_block = Self::render_fewshot_examples(&fewshot_examples); let base_system = format!( - "You are a personal photo memory assistant helping to reconstruct a memory from a photo.{cameron_id_note}\n\n\ + "You are a personal photo memory assistant helping to reconstruct a memory from a photo.{owner_id_note}\n\n\ + {fewshot_block}\ IMPORTANT INSTRUCTIONS:\n\ 1. You MUST call multiple tools to gather context BEFORE writing any final insight. Do not produce a final answer after only one or two tool calls.\n\ - 2. When calling get_sms_messages and search_rag, always make at least one call WITHOUT a contact filter to capture what else was happening in Cameron's life around this date — other conversations, events, and activities provide important wider context even when a specific contact is known.\n\ + 2. When calling get_sms_messages and search_rag, always make at least one call WITHOUT a contact filter to capture what else was happening in {owner_name}'s life around this date — other conversations, events, and activities provide important wider context even when a specific contact is known.\n\ 3. Use recall_facts_for_photo to load any previously stored knowledge about subjects in this photo.\n\ 4. Use recall_entities to look up known people, places, or things that appear in this photo.\n\ 5. When you identify people, places, events, or notable things in this photo: use store_entity to record them and store_fact to record key facts (relationships, roles, attributes). This builds a persistent memory for future insights.\n\ - 6. Only produce your final insight AFTER you have gathered context from at least 5-12 tool calls.\n\ - 7. If a tool returns no results, that is useful information — continue calling the remaining tools anyway.", - cameron_id_note = cameron_id_note + 6. Only produce your final insight AFTER you have gathered context from at least 5 tool calls.\n\ + 7. If a tool returns no results, that is useful information — continue calling the remaining tools anyway.\n\ + 8. You have a hard budget of {max_iterations} tool-calling iterations before the loop ends. Plan your context gathering so you can write a complete final insight within that budget.", + owner_id_note = owner_id_note, + fewshot_block = fewshot_block, + owner_name = owner_name, + max_iterations = max_iterations ); let system_content = if let Some(ref custom) = custom_system_prompt { format!("{}\n\n{}", custom, base_system) @@ -2672,8 +3237,13 @@ Return ONLY the summary, nothing else."#, .map(|c| format!("Contact/Person: {}", c)) .unwrap_or_else(|| "Contact/Person: unknown".to_string()); + let visual_block = hybrid_visual_description + .as_deref() + .map(|d| format!("Visual description (from local vision model):\n{}\n\n", d)) + .unwrap_or_default(); + let user_content = format!( - "Please analyze this photo and gather any relevant context from the surrounding weeks.\n\n\ + "{visual_block}Please analyze this photo and gather any relevant context from the surrounding weeks.\n\n\ Photo file path: {}\n\ Date taken: {}\n\ {}\n\ @@ -2686,21 +3256,32 @@ Return ONLY the summary, nothing else."#, contact_info, gps_info, tags_info, + visual_block = visual_block, ); - // 10. Define tools - let tools = Self::build_tool_definitions(has_vision); + // 10. Define tools. Hybrid mode omits `describe_photo` since the + // chat model receives the visual description inline. + let offer_describe_tool = has_vision && !is_hybrid; + let tools = Self::build_tool_definitions(offer_describe_tool); - // 11. Build initial messages + // 11. Build initial messages. In hybrid mode images are never + // attached to the wire message — the description is part of + // `user_content`. let system_msg = ChatMessage::system(system_content); let mut user_msg = ChatMessage::user(user_content); - if let Some(ref img) = image_base64 { + if !is_hybrid && let Some(ref img) = image_base64 { user_msg.images = Some(vec![img.clone()]); } let mut messages = vec![system_msg, user_msg]; - // 12. Agentic loop + // 12. Agentic loop — dispatch through the selected backend. + let chat_backend: &dyn LlmClient = if let Some(ref or_c) = openrouter_client { + or_c + } else { + &ollama_client + }; + let loop_span = tracer.start_with_context("ai.agentic.loop", &insight_cx); let loop_cx = insight_cx.with_span(loop_span); @@ -2713,7 +3294,7 @@ Return ONLY the summary, nothing else."#, iterations_used = iteration + 1; log::info!("Agentic iteration {}/{}", iteration + 1, max_iterations); - let (response, prompt_tokens, eval_tokens) = ollama_client + let (response, prompt_tokens, eval_tokens) = chat_backend .chat_with_tools(messages.clone(), tools.clone()) .await?; @@ -2744,7 +3325,7 @@ Return ONLY the summary, nothing else."#, { for tool_call in tool_calls { log::info!( - "Agentic tool call [{}]: {} {:?}", + "Agentic tool call [{}]: {} {}", iteration, tool_call.function.name, tool_call.function.arguments @@ -2775,10 +3356,11 @@ Return ONLY the summary, nothing else."#, "Agentic loop exhausted after {} iterations, requesting final answer", iterations_used ); - messages.push(ChatMessage::user( - "Based on the context gathered, please write the final photo insight: a title and a detailed personal summary. Write in first person as Cameron.", - )); - let (final_response, prompt_tokens, eval_tokens) = ollama_client + messages.push(ChatMessage::user(format!( + "Based on the context gathered, please write the final photo insight: a title and a detailed personal summary. Write in first person as {}.", + user_display_name() + ))); + let (final_response, prompt_tokens, eval_tokens) = chat_backend .chat_with_tools(messages.clone(), vec![]) .await?; last_prompt_eval_count = prompt_tokens; @@ -2792,10 +3374,18 @@ Return ONLY the summary, nothing else."#, .set_attribute(KeyValue::new("iterations_used", iterations_used as i64)); loop_cx.span().set_status(Status::Ok); - // 13. Generate title - let title = ollama_client - .generate_photo_title(&final_content, custom_system_prompt.as_deref()) + // 13. Generate title via the same backend so voice stays consistent. + let title_prompt = format!( + "Create a short title (maximum 8 words) for the following journal entry:\n\n{}\n\nCapture the key moment or theme. Return ONLY the title, nothing else.", + final_content + ); + let title_system = custom_system_prompt.as_deref().unwrap_or( + "You are my long term memory assistant. Use only the information provided. Do not invent details.", + ); + let title_raw = chat_backend + .generate(&title_prompt, Some(title_system), None) .await?; + let title = title_raw.trim().trim_matches('"').to_string(); log::info!("Agentic generated title: {}", title); log::info!( @@ -2814,15 +3404,23 @@ Return ONLY the summary, nothing else."#, }; // 15. Store insight (returns the persisted row including its new id) + let model_version = chat_backend.primary_model().to_string(); + let fewshot_source_ids_json = if fewshot_source_ids.is_empty() { + None + } else { + Some(serde_json::to_string(&fewshot_source_ids).unwrap_or_else(|_| "[]".to_string())) + }; let insight = InsertPhotoInsight { library_id: crate::libraries::PRIMARY_LIBRARY_ID, file_path: file_path.to_string(), title, summary: final_content, generated_at: Utc::now().timestamp(), - model_version: ollama_client.primary_model.clone(), + model_version, is_current: true, training_messages, + backend: backend_label.clone(), + fewshot_source_ids: fewshot_source_ids_json, }; let stored = { @@ -2942,6 +3540,7 @@ Return ONLY the summary, nothing else."#, #[cfg(test)] mod tests { use super::*; + use crate::ai::ollama::{ToolCall, ToolCallFunction}; #[test] fn combine_contexts_includes_tags_section_when_tags_present() { @@ -2983,4 +3582,219 @@ mod tests { let result = InsightGenerator::combine_contexts(None, None, None, None, None); assert_eq!(result, "No additional context available"); } + + // These tests assert the shape of the strings returned by the tool + // implementations above. If a tool's output format changes, update the + // tool AND the corresponding arm of `summarize_tool_result` — these + // tests exist to make that coupling loud. + + #[test] + fn summarize_errors_uniformly() { + assert_eq!( + InsightGenerator::summarize_tool_result("search_rag", "Error searching RAG: boom"), + "error" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "get_sms_messages", + "Error fetching SMS messages: timeout" + ), + "error" + ); + } + + #[test] + fn summarize_empty_results_uniformly() { + assert_eq!( + InsightGenerator::summarize_tool_result("search_rag", "No relevant messages found."), + "empty (pivoted)" + ); + assert_eq!( + InsightGenerator::summarize_tool_result("get_sms_messages", "No messages found."), + "empty (pivoted)" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "reverse_geocode", + "Could not resolve coordinates to a place name." + ), + "empty (pivoted)" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "recall_facts_for_photo", + "No knowledge facts found for this photo." + ), + "empty (pivoted)" + ); + } + + #[test] + fn summarize_found_count_per_tool() { + assert_eq!( + InsightGenerator::summarize_tool_result( + "get_sms_messages", + "Found 7 messages:\n[2023-08-15 10:00] Sarah: hi" + ), + "7 messages" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "search_messages", + "Found 3 messages (mode: hybrid):\n\n[2023-08-15] Sarah — hi" + ), + "3 messages" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "get_calendar_events", + "Found 2 calendar events:\n[2023-08-15 10:00] Wedding" + ), + "2 events" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "get_location_history", + "Found 5 location records:\n[2023-08-15 10:00] 39.0, -120.0" + ), + "5 location records" + ); + } + + #[test] + fn summarize_search_rag_counts_hits() { + let raw = "[2023-08-15] Sarah: venue confirmed\n\n[2023-08-14] Mom: travel plans\n\n[2023-08-13] Dad: weather"; + assert_eq!( + InsightGenerator::summarize_tool_result("search_rag", raw), + "3 rag hits" + ); + } + + #[test] + fn summarize_get_file_tags() { + assert_eq!( + InsightGenerator::summarize_tool_result("get_file_tags", "wedding, tahoe, 2023"), + "3 tags" + ); + } + + #[test] + fn summarize_describe_photo_truncates() { + let raw = "A wedding ceremony at Lake Tahoe with about 40 guests seated in rows facing a lakeside arch decorated with white flowers."; + let out = InsightGenerator::summarize_tool_result("describe_photo", raw); + assert!(out.starts_with("described: \"")); + assert!(out.contains("A wedding ceremony at Lake Tahoe")); + assert!(out.ends_with("...\"")); + } + + #[test] + fn summarize_reverse_geocode_returns_place() { + let out = + InsightGenerator::summarize_tool_result("reverse_geocode", "South Lake Tahoe, CA, USA"); + assert_eq!(out, "place: South Lake Tahoe, CA, USA"); + } + + #[test] + fn summarize_recall_entities_counts_lines() { + let raw = "Known entities:\n- Sarah (person)\n- Tahoe (place)\n- Wedding 2023 (event)"; + assert_eq!( + InsightGenerator::summarize_tool_result("recall_entities", raw), + "3 entities" + ); + } + + #[test] + fn summarize_recall_facts_counts_lines() { + let raw = "Knowledge for this photo:\n- Sarah: college friend\n- Tahoe: vacation spot"; + assert_eq!( + InsightGenerator::summarize_tool_result("recall_facts_for_photo", raw), + "2 facts" + ); + } + + #[test] + fn summarize_store_entity_extracts_id() { + assert_eq!( + InsightGenerator::summarize_tool_result( + "store_entity", + "Entity stored: ID:42 | person | Sarah | confidence:0.80" + ), + "stored id=42" + ); + } + + #[test] + fn summarize_store_fact_extracts_id() { + assert_eq!( + InsightGenerator::summarize_tool_result( + "store_fact", + "Stored new fact: ID:17 | confidence:0.60" + ), + "stored id=17" + ); + assert_eq!( + InsightGenerator::summarize_tool_result( + "store_fact", + "Corroborated existing fact: ID:17 | confidence:0.85" + ), + "stored id=17" + ); + } + + #[test] + fn summarize_current_datetime() { + assert_eq!( + InsightGenerator::summarize_tool_result( + "get_current_datetime", + "Current date/time: 2024-01-15 12:00:00 PST (Monday)" + ), + "time noted" + ); + } + + #[test] + fn summarize_unknown_tool_falls_back_to_char_count() { + let out = InsightGenerator::summarize_tool_result("never_heard_of_it", "some output"); + assert_eq!(out, "11 chars"); + } + + #[test] + fn render_fewshot_empty_returns_empty_string() { + assert!(InsightGenerator::render_fewshot_examples(&[]).is_empty()); + } + + #[test] + fn render_single_trajectory_walks_tool_calls_in_order() { + let arguments = serde_json::json!({ "query": "wedding", "date": "2023-08-15" }); + let msgs = vec![ + ChatMessage::system("ignored"), + ChatMessage::user("Photo file path: /photos/img.jpg\nDate taken: August 15, 2023"), + ChatMessage { + role: "assistant".to_string(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + function: ToolCallFunction { + name: "search_rag".to_string(), + arguments, + }, + id: None, + }]), + images: None, + }, + ChatMessage::tool_result("No relevant messages found."), + ChatMessage { + role: "assistant".to_string(), + content: "Final title\n\nFinal body.".to_string(), + tool_calls: None, + images: None, + }, + ]; + let out = InsightGenerator::render_single_trajectory(&msgs); + assert!(out.contains("Input:")); + assert!(out.contains("/photos/img.jpg")); + assert!(out.contains("1. search_rag(")); + assert!(out.contains("query=\"wedding\"")); + assert!(out.contains("-> empty (pivoted)")); + assert!(out.contains("Final insight: Final title")); + } } diff --git a/src/ai/llm_client.rs b/src/ai/llm_client.rs new file mode 100644 index 0000000..8d68978 --- /dev/null +++ b/src/ai/llm_client.rs @@ -0,0 +1,172 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::stream::BoxStream; +use serde::{Deserialize, Serialize}; + +/// Provider-agnostic surface for LLM backends (Ollama, OpenRouter, …). +/// +/// Impls translate these canonical shapes at the wire boundary: tool-call +/// arguments stay as `serde_json::Value` in memory and are stringified only +/// when a provider requires it (OpenAI-compatible APIs do), and `images` +/// stays as base64 strings here and is rewritten into content-parts where +/// needed. +// First consumer lands in a later PR (OpenRouter impl + hybrid mode routing). +#[allow(dead_code)] +#[async_trait] +pub trait LlmClient: Send + Sync { + /// Single-shot text generation. Optional system prompt and optional + /// base64 images (ignored by providers without vision support). + async fn generate( + &self, + prompt: &str, + system: Option<&str>, + images: Option>, + ) -> Result; + + /// Multi-turn chat with tool definitions. Returns the assistant message + /// (which may contain tool_calls) plus optional prompt/eval token counts. + async fn chat_with_tools( + &self, + messages: Vec, + tools: Vec, + ) -> Result<(ChatMessage, Option, Option)>; + + /// Streaming variant of `chat_with_tools`. The returned stream yields + /// `TextDelta` items as content is produced, then a single terminal + /// `Done` carrying the complete assembled message (with tool_calls, if + /// any) plus token usage counts. Implementations that can't stream may + /// fall back to calling `chat_with_tools` and emitting the full reply + /// as one `Done` event. + async fn chat_with_tools_stream( + &self, + messages: Vec, + tools: Vec, + ) -> Result>>; + + /// Batch embedding generation. Dimensionality is provider/model specific. + async fn generate_embeddings(&self, texts: &[&str]) -> Result>>; + + /// One-shot vision description of an image. Used to convert images into + /// plain text for the hybrid-mode conversation flow. + async fn describe_image(&self, image_base64: &str) -> Result; + + /// Enumerate available models with their capabilities. + async fn list_models(&self) -> Result>; + + /// Look up capabilities for a single model. + async fn model_capabilities(&self, model: &str) -> Result; + + /// Primary model identifier this client was constructed with. + fn primary_model(&self) -> &str; +} + +/// Events emitted by streaming `chat_with_tools_stream`. A stream is a +/// sequence of zero or more `TextDelta` events followed by exactly one +/// `Done`. Callers should treat `Done` as terminal — further items (if any +/// slip through due to upstream misbehavior) are safe to ignore. +#[derive(Debug, Clone)] +pub enum LlmStreamEvent { + /// Incremental content token(s) from the model. Concatenate in order to + /// reconstruct the assistant's final text. + TextDelta(String), + /// Terminal event with the full assembled message (content + any + /// tool_calls). `message.content` equals the concatenation of every + /// preceding `TextDelta.0`. + Done { + message: ChatMessage, + prompt_eval_count: Option, + eval_count: Option, + }, +} + +/// Tool definition sent to the model (OpenAI-compatible function schema). +#[derive(Serialize, Clone, Debug)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // always "function" + pub function: ToolFunction, +} + +#[derive(Serialize, Clone, Debug)] +pub struct ToolFunction { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +impl Tool { + pub fn function(name: &str, description: &str, parameters: serde_json::Value) -> Self { + Self { + tool_type: "function".to_string(), + function: ToolFunction { + name: name.to_string(), + description: description.to_string(), + parameters, + }, + } + } +} + +/// A message in the chat conversation history. +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ChatMessage { + pub role: String, // "system" | "user" | "assistant" | "tool" + /// Empty string (not null) when tool_calls is present — Ollama quirk. + #[serde(default)] + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Base64 images — only on user messages to vision-capable models. + #[serde(skip_serializing_if = "Option::is_none")] + pub images: Option>, +} + +impl ChatMessage { + pub fn system(content: impl Into) -> Self { + Self { + role: "system".to_string(), + content: content.into(), + tool_calls: None, + images: None, + } + } + pub fn user(content: impl Into) -> Self { + Self { + role: "user".to_string(), + content: content.into(), + tool_calls: None, + images: None, + } + } + pub fn tool_result(content: impl Into) -> Self { + Self { + role: "tool".to_string(), + content: content.into(), + tool_calls: None, + images: None, + } + } +} + +/// Tool call returned by the model in an assistant message. +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ToolCall { + pub function: ToolCallFunction, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ToolCallFunction { + pub name: String, + /// Canonical shape: native JSON. Providers that use JSON-encoded-string + /// arguments (OpenAI-compatible) translate at their wire boundary. + pub arguments: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ModelCapabilities { + pub name: String, + pub has_vision: bool, + pub has_tool_calling: bool, +} diff --git a/src/ai/mod.rs b/src/ai/mod.rs index 4e682fb..94e8541 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -1,17 +1,37 @@ pub mod daily_summary_job; pub mod handlers; +pub mod insight_chat; pub mod insight_generator; +pub mod llm_client; pub mod ollama; +pub mod openrouter; pub mod sms_client; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library #[allow(unused_imports)] -pub use daily_summary_job::{generate_daily_summaries, strip_summary_boilerplate}; +pub use daily_summary_job::{ + DAILY_SUMMARY_MESSAGE_LIMIT, DAILY_SUMMARY_SYSTEM_PROMPT, build_daily_summary_prompt, + generate_daily_summaries, strip_summary_boilerplate, +}; pub use handlers::{ + chat_history_handler, chat_rewind_handler, chat_stream_handler, chat_turn_handler, delete_insight_handler, export_training_data_handler, generate_agentic_insight_handler, generate_insight_handler, get_all_insights_handler, get_available_models_handler, - get_insight_handler, rate_insight_handler, + get_insight_handler, get_openrouter_models_handler, rate_insight_handler, }; pub use insight_generator::InsightGenerator; -pub use ollama::{ModelCapabilities, OllamaClient}; +#[allow(unused_imports)] +pub use llm_client::{ + ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction, ToolFunction, +}; +pub use ollama::{EMBEDDING_MODEL, OllamaClient}; pub use sms_client::{SmsApiClient, SmsMessage}; + +/// Display name used for the user in message transcripts and first-person +/// prompt text. Reads the `USER_NAME` env var; defaults to `"Me"`. Models +/// often confuse `"Me:"` in a transcript with their own role — setting +/// `USER_NAME=Cameron` (or similar) in the environment eliminates that +/// ambiguity across daily summaries, insight generation, and chat. +pub fn user_display_name() -> String { + std::env::var("USER_NAME").unwrap_or_else(|_| "Me".to_string()) +} diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs index 184bc61..c56e1e7 100644 --- a/src/ai/ollama.rs +++ b/src/ai/ollama.rs @@ -1,14 +1,43 @@ use anyhow::{Context, Result}; +use async_trait::async_trait; use chrono::NaiveDate; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; +use crate::ai::llm_client::{LlmClient, LlmStreamEvent}; +use futures::stream::{BoxStream, StreamExt}; + +// Re-export shared types so existing `crate::ai::ollama::{...}` imports +// continue to resolve. +pub use crate::ai::llm_client::{ChatMessage, ModelCapabilities, Tool}; +#[allow(unused_imports)] +pub use crate::ai::llm_client::{ToolCall, ToolCallFunction, ToolFunction}; + // Cache duration: 15 minutes const CACHE_DURATION_SECS: u64 = 15 * 60; +/// Default total request timeout for generation calls, in seconds. +/// Overridable via `OLLAMA_REQUEST_TIMEOUT_SECONDS` env var for slow +/// CPU-offloaded models where inference can take several minutes. +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 120; + +fn configured_request_timeout_secs() -> u64 { + std::env::var("OLLAMA_REQUEST_TIMEOUT_SECONDS") + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|&s| s > 0) + .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS) +} + +/// Embedding model used across the app. Callers that persist a +/// `model_version` alongside an embedding should read this constant so the +/// stored label always matches what `generate_embeddings` actually ran. +pub const EMBEDDING_MODEL: &str = "nomic-embed-text:v1.5"; + // Cached entry with timestamp #[derive(Clone)] struct CachedEntry { @@ -50,6 +79,12 @@ pub struct OllamaClient { top_p: Option, top_k: Option, min_p: Option, + /// Sticky preference shared across clones: when the fallback server + /// succeeded most recently, try it first on the next call. Avoids + /// re-probing the primary with a model it doesn't have loaded across + /// every iteration of the agent loop. `Arc` so cloning + /// `OllamaClient` shares the flag rather than resetting it. + prefer_fallback: Arc, } impl OllamaClient { @@ -62,7 +97,7 @@ impl OllamaClient { Self { client: Client::builder() .connect_timeout(Duration::from_secs(5)) // Quick connection timeout - .timeout(Duration::from_secs(120)) // Total request timeout for generation + .timeout(Duration::from_secs(configured_request_timeout_secs())) .build() .unwrap_or_else(|_| Client::new()), primary_url, @@ -74,9 +109,44 @@ impl OllamaClient { top_p: None, top_k: None, min_p: None, + prefer_fallback: Arc::new(AtomicBool::new(false)), } } + /// Return the server attempt order as `(label, url, model)` tuples. + /// Respects the sticky `prefer_fallback` flag so the most recently + /// successful server is tried first. + fn attempt_order(&self) -> Vec<(&'static str, String, String)> { + let primary = ( + "primary", + self.primary_url.clone(), + self.primary_model.clone(), + ); + let fallback = self.fallback_url.as_ref().map(|url| { + let model = self + .fallback_model + .clone() + .unwrap_or_else(|| self.primary_model.clone()); + ("fallback", url.clone(), model) + }); + + let prefer_fallback = fallback.is_some() && self.prefer_fallback.load(Ordering::Relaxed); + + let mut order = Vec::with_capacity(2); + if prefer_fallback { + if let Some(fb) = fallback.clone() { + order.push(fb); + } + order.push(primary); + } else { + order.push(primary); + if let Some(fb) = fallback { + order.push(fb); + } + } + order + } + pub fn set_num_ctx(&mut self, num_ctx: Option) { self.num_ctx = num_ctx; } @@ -311,6 +381,7 @@ impl OllamaClient { prompt: &str, system: Option<&str>, images: Option>, + think: Option, ) -> Result { let request = OllamaRequest { model: model.to_string(), @@ -319,6 +390,7 @@ impl OllamaClient { system: system.map(|s| s.to_string()), options: self.build_options(), images, + think, }; let response = self @@ -339,6 +411,12 @@ impl OllamaClient { } let result: OllamaResponse = response.json().await?; + log_chat_metrics( + result.prompt_eval_count, + result.prompt_eval_duration, + result.eval_count, + result.eval_duration, + ); Ok(result.response) } @@ -346,11 +424,31 @@ impl OllamaClient { self.generate_with_images(prompt, system, None).await } + /// Variant of `generate` that sets Ollama's top-level `think: false`. + /// Used by latency-sensitive callers like the rerank pass, where the + /// task has nothing to reason about and chain-of-thought tokens are + /// wasted wall time. Server-side no-op on non-reasoning models. + pub async fn generate_no_think(&self, prompt: &str, system: Option<&str>) -> Result { + self.generate_with_options(prompt, system, None, Some(false)) + .await + } + pub async fn generate_with_images( &self, prompt: &str, system: Option<&str>, images: Option>, + ) -> Result { + self.generate_with_options(prompt, system, images, None) + .await + } + + async fn generate_with_options( + &self, + prompt: &str, + system: Option<&str>, + images: Option>, + think: Option, ) -> Result { log::debug!("=== Ollama Request ==="); log::debug!("Primary model: {}", self.primary_model); @@ -376,6 +474,7 @@ impl OllamaClient { prompt, system, images.clone(), + think, ) .await; @@ -399,7 +498,14 @@ impl OllamaClient { fallback_model ); match self - .try_generate(fallback_url, fallback_model, prompt, system, images.clone()) + .try_generate( + fallback_url, + fallback_model, + prompt, + system, + images.clone(), + think, + ) .await { Ok(response) => { @@ -471,6 +577,7 @@ Capture the key moment or theme. Return ONLY the title, nothing else."#, ) -> Result { let location_str = location.unwrap_or("Unknown"); let sms_str = sms_summary.unwrap_or("No messages"); + let user_name = crate::ai::user_display_name(); let prompt = if image_base64.is_some() { if let Some(contact_name) = contact { @@ -482,13 +589,14 @@ Location: {} Person/Contact: {} Messages: {} -Analyze the image and use specific details from both the visual content and the context above. The photo is from a folder for {}, so they are likely in or related to this photo. Mention people's names (especially {}), places, or activities if they appear in either the image or the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#, +Analyze the image and use specific details from both the visual content and the context above. The photo is from a folder for {}, so they are likely in or related to this photo. Mention people's names (especially {}), places, or activities if they appear in either the image or the context. Write in first person as {} with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#, date.format("%B %d, %Y"), location_str, contact_name, sms_str, contact_name, - contact_name + contact_name, + user_name ) } else { format!( @@ -498,10 +606,11 @@ Date: {} Location: {} Messages: {} -Analyze the image and use specific details from both the visual content and the context above. Mention people's names, places, or activities if they appear in either the image or the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#, +Analyze the image and use specific details from both the visual content and the context above. Mention people's names, places, or activities if they appear in either the image or the context. Write in first person as {} with the tone of a journal entry. If limited information is available, keep it simple and factual based on what you see and know. If the location is unknown omit it"#, date.format("%B %d, %Y"), location_str, - sms_str + sms_str, + user_name ) } } else if let Some(contact_name) = contact { @@ -513,13 +622,14 @@ Analyze the image and use specific details from both the visual content and the Person/Contact: {} Messages: {} - Use only the specific details provided above. The photo is from a folder for {}, so they are likely related to this moment. Mention people's names (especially {}), places, or activities if they appear in the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#, + Use only the specific details provided above. The photo is from a folder for {}, so they are likely related to this moment. Mention people's names (especially {}), places, or activities if they appear in the context. Write in first person as {} with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#, date.format("%B %d, %Y"), location_str, contact_name, sms_str, contact_name, - contact_name + contact_name, + user_name ) } else { format!( @@ -529,10 +639,11 @@ Analyze the image and use specific details from both the visual content and the Location: {} Messages: {} - Use only the specific details provided above. Mention people's names, places, or activities if they appear in the context. Write in first person as Cameron with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#, + Use only the specific details provided above. Mention people's names, places, or activities if they appear in the context. Write in first person as {} with the tone of a journal entry. If limited information is available, keep it simple and factual. If the location is unknown omit it"#, date.format("%B %d, %Y"), location_str, - sms_str + sms_str, + user_name ) }; @@ -561,68 +672,229 @@ Analyze the image and use specific details from both the visual content and the /// Send a chat request with tool definitions to /api/chat. /// Returns the assistant's response message (may contain tool_calls or final content). - /// Uses primary/fallback URL routing same as other generation methods. + /// Tries servers in preference order — most recently successful first — + /// so a fallback-only model doesn't re-404 against the primary on every + /// iteration of the agent loop. pub async fn chat_with_tools( &self, messages: Vec, tools: Vec, ) -> Result<(ChatMessage, Option, Option)> { - // Try primary server first - log::info!( - "Attempting chat_with_tools with primary server: {} (model: {})", - self.primary_url, - self.primary_model - ); - let primary_result = self - .try_chat_with_tools(&self.primary_url, messages.clone(), tools.clone()) - .await; - - match primary_result { - Ok(result) => { - log::info!("Successfully got chat_with_tools response from primary server"); - Ok(result) - } - Err(e) => { - log::warn!("Primary server chat_with_tools failed: {}", e); - - // Try fallback server if available - if let Some(fallback_url) = &self.fallback_url { - let fallback_model = - self.fallback_model.as_ref().unwrap_or(&self.primary_model); + let order = self.attempt_order(); + let mut errors: Vec = Vec::new(); + for (label, url, model) in &order { + log::info!( + "Attempting chat_with_tools with {} server: {} (model: {})", + label, + url, + model + ); + match self + .try_chat_with_tools(url, messages.clone(), tools.clone()) + .await + { + Ok(result) => { log::info!( - "Attempting chat_with_tools with fallback server: {} (model: {})", - fallback_url, - fallback_model + "Successfully got chat_with_tools response from {} server", + label ); - match self - .try_chat_with_tools(fallback_url, messages, tools) - .await - { - Ok(result) => { - log::info!( - "Successfully got chat_with_tools response from fallback server" - ); - Ok(result) - } - Err(fallback_e) => { - log::error!( - "Fallback server chat_with_tools also failed: {}", - fallback_e - ); - Err(anyhow::anyhow!( - "Both primary and fallback servers failed. Primary: {}, Fallback: {}", - e, - fallback_e - )) - } - } - } else { - log::error!("No fallback server configured"); - Err(e) + self.prefer_fallback + .store(*label == "fallback", Ordering::Relaxed); + return Ok(result); + } + Err(e) => { + log::warn!("{} server chat_with_tools failed: {}", label, e); + errors.push(format!("{}: {}", label, e)); } } } + + if order.len() <= 1 { + log::error!("No fallback server configured; chat_with_tools exhausted"); + } else { + log::error!( + "All {} servers failed for chat_with_tools ({})", + order.len(), + errors.join(" / ") + ); + } + Err(anyhow::anyhow!( + "chat_with_tools failed on all servers: {}", + errors.join(" / ") + )) + } + + /// Streaming variant of `chat_with_tools`. Tries primary, then falls + /// back if the initial connection fails; once the stream has begun + /// emitting, mid-stream errors propagate to the caller. Emits + /// `TextDelta` events as content tokens arrive and a single terminal + /// `Done` event when the model marks the turn complete (tool_calls, if + /// any, live on the final message). + pub async fn chat_with_tools_stream( + &self, + messages: Vec, + tools: Vec, + ) -> Result>> { + // Same preference logic as `chat_with_tools`. Only the initial + // connection is retried across servers — once the stream begins, + // mid-stream errors propagate to the caller. + let order = self.attempt_order(); + let mut last_err: Option = None; + + for (label, url, _model) in &order { + match self + .try_chat_with_tools_stream(url, messages.clone(), tools.clone()) + .await + { + Ok(s) => { + self.prefer_fallback + .store(*label == "fallback", Ordering::Relaxed); + return Ok(s); + } + Err(e) => { + log::warn!("Streaming chat on {} server failed: {}", label, e); + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or_else(|| anyhow::anyhow!("No Ollama server configured"))) + } + + async fn try_chat_with_tools_stream( + &self, + base_url: &str, + messages: Vec, + tools: Vec, + ) -> Result>> { + let url = format!("{}/api/chat", base_url); + let model = if base_url == self.primary_url { + &self.primary_model + } else { + self.fallback_model + .as_deref() + .unwrap_or(&self.primary_model) + }; + let options = self.build_options(); + + let request_body = OllamaChatRequest { + model, + messages: &messages, + stream: true, + tools, + options, + }; + + let response = self + .client + .post(&url) + .json(&request_body) + .send() + .await + .with_context(|| format!("Failed to connect to Ollama at {}", url))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!( + "Ollama stream request failed with status {}: {}", + status, + body + ); + } + + // Ollama streams NDJSON: each line is a full `OllamaStreamChunk`. + // We buffer partial lines across chunks from the byte stream. + let byte_stream = response.bytes_stream(); + let stream = async_stream::stream! { + let mut buf: Vec = Vec::new(); + let mut accumulated = String::new(); + let mut tool_calls: Option> = None; + let mut role = "assistant".to_string(); + let mut prompt_eval_count: Option = None; + let mut eval_count: Option = None; + let mut prompt_eval_duration: Option = None; + let mut eval_duration: Option = None; + let mut done_seen = false; + + let mut byte_stream = byte_stream; + while let Some(chunk) = byte_stream.next().await { + let chunk = match chunk { + Ok(b) => b, + Err(e) => { + yield Err(anyhow::anyhow!("stream read failed: {}", e)); + return; + } + }; + buf.extend_from_slice(&chunk); + + // Drain complete lines; hold any trailing partial. + while let Some(nl) = buf.iter().position(|b| *b == b'\n') { + let line = buf.drain(..=nl).collect::>(); + let line_str = match std::str::from_utf8(&line) { + Ok(s) => s.trim(), + Err(_) => continue, + }; + if line_str.is_empty() { + continue; + } + match serde_json::from_str::(line_str) { + Ok(chunk) => { + // Accumulate content delta. + if !chunk.message.content.is_empty() { + accumulated.push_str(&chunk.message.content); + yield Ok(LlmStreamEvent::TextDelta(chunk.message.content)); + } + if !chunk.message.role.is_empty() { + role = chunk.message.role; + } + // Ollama only attaches tool_calls on the final chunk. + if let Some(tcs) = chunk.message.tool_calls + && !tcs.is_empty() + { + tool_calls = Some(tcs); + } + if chunk.done { + prompt_eval_count = chunk.prompt_eval_count; + eval_count = chunk.eval_count; + prompt_eval_duration = chunk.prompt_eval_duration; + eval_duration = chunk.eval_duration; + done_seen = true; + break; + } + } + Err(e) => { + log::warn!("malformed Ollama stream line: {} ({})", line_str, e); + } + } + } + if done_seen { + break; + } + } + + // Emit the terminal Done event with the assembled message. + log_chat_metrics( + prompt_eval_count, + prompt_eval_duration, + eval_count, + eval_duration, + ); + let message = ChatMessage { + role, + content: accumulated, + tool_calls, + images: None, + }; + yield Ok(LlmStreamEvent::Done { + message, + prompt_eval_count, + eval_count, + }); + }; + + Ok(Box::pin(stream)) } async fn try_chat_with_tools( @@ -665,8 +937,12 @@ Analyze the image and use specific details from both the visual content and the if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); - log::error!( - "chat_with_tools request body that caused {}: {}", + // warn, not error — the outer `chat_with_tools` may recover via + // the fallback server. When both fail, the outer layer emits the + // actual error log. + log::warn!( + "chat_with_tools request to {} got {}: {}", + base_url, status, request_json ); @@ -682,6 +958,17 @@ Analyze the image and use specific details from both the visual content and the .await .with_context(|| "Failed to parse Ollama chat response")?; + // Log performance counters returned by Ollama. Durations are + // reported in nanoseconds; we render ms + tokens/sec for skim-ability + // in the server log. Missing fields are left off the line rather + // than printed as `None`. + log_chat_metrics( + chat_response.prompt_eval_count, + chat_response.prompt_eval_duration, + chat_response.eval_count, + chat_response.eval_duration, + ); + Ok(( chat_response.message, chat_response.prompt_eval_count, @@ -703,7 +990,7 @@ Analyze the image and use specific details from both the visual content and the /// Returns a vector of 768-dimensional vectors /// This is much more efficient than calling generate_embedding multiple times pub async fn generate_embeddings(&self, texts: &[&str]) -> Result>> { - let embedding_model = "nomic-embed-text:v1.5"; + let embedding_model = EMBEDDING_MODEL; log::debug!("=== Ollama Batch Embedding Request ==="); log::debug!("Model: {}", embedding_model); @@ -818,6 +1105,54 @@ Analyze the image and use specific details from both the visual content and the } } +#[async_trait] +impl LlmClient for OllamaClient { + async fn generate( + &self, + prompt: &str, + system: Option<&str>, + images: Option>, + ) -> Result { + self.generate_with_images(prompt, system, images).await + } + + async fn chat_with_tools( + &self, + messages: Vec, + tools: Vec, + ) -> Result<(ChatMessage, Option, Option)> { + OllamaClient::chat_with_tools(self, messages, tools).await + } + + async fn chat_with_tools_stream( + &self, + messages: Vec, + tools: Vec, + ) -> Result>> { + OllamaClient::chat_with_tools_stream(self, messages, tools).await + } + + async fn generate_embeddings(&self, texts: &[&str]) -> Result>> { + OllamaClient::generate_embeddings(self, texts).await + } + + async fn describe_image(&self, image_base64: &str) -> Result { + self.generate_photo_description(image_base64).await + } + + async fn list_models(&self) -> Result> { + Self::list_models_with_capabilities(&self.primary_url).await + } + + async fn model_capabilities(&self, model: &str) -> Result { + Self::check_model_capabilities(&self.primary_url, model).await + } + + fn primary_model(&self) -> &str { + &self.primary_model + } +} + #[derive(Serialize)] struct OllamaRequest { model: String, @@ -829,6 +1164,12 @@ struct OllamaRequest { options: Option, #[serde(skip_serializing_if = "Option::is_none")] images: Option>, + /// Ollama's top-level reasoning-mode toggle (~0.4+). `Some(false)` + /// asks the server to skip thinking on models that expose a toggle + /// (Qwen3, Ollama-integrated DeepSeek-R1 distills, GPT-OSS, etc). + /// Ignored by non-reasoning models. None = use the model's default. + #[serde(skip_serializing_if = "Option::is_none")] + think: Option, } #[derive(Serialize)] @@ -845,90 +1186,6 @@ struct OllamaOptions { min_p: Option, } -/// Tool definition sent in /api/chat requests (OpenAI-compatible format) -#[derive(Serialize, Clone, Debug)] -pub struct Tool { - #[serde(rename = "type")] - pub tool_type: String, // always "function" - pub function: ToolFunction, -} - -#[derive(Serialize, Clone, Debug)] -pub struct ToolFunction { - pub name: String, - pub description: String, - pub parameters: serde_json::Value, -} - -impl Tool { - pub fn function(name: &str, description: &str, parameters: serde_json::Value) -> Self { - Self { - tool_type: "function".to_string(), - function: ToolFunction { - name: name.to_string(), - description: description.to_string(), - parameters, - }, - } - } -} - -/// A message in the chat conversation history -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct ChatMessage { - pub role: String, // "system" | "user" | "assistant" | "tool" - /// Empty string (not null) when tool_calls is present — Ollama quirk - #[serde(default)] - pub content: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - /// Base64 images — only on user messages to vision-capable models - #[serde(skip_serializing_if = "Option::is_none")] - pub images: Option>, -} - -impl ChatMessage { - pub fn system(content: impl Into) -> Self { - Self { - role: "system".to_string(), - content: content.into(), - tool_calls: None, - images: None, - } - } - pub fn user(content: impl Into) -> Self { - Self { - role: "user".to_string(), - content: content.into(), - tool_calls: None, - images: None, - } - } - pub fn tool_result(content: impl Into) -> Self { - Self { - role: "tool".to_string(), - content: content.into(), - tool_calls: None, - images: None, - } - } -} - -/// Tool call returned by the model in an assistant message -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct ToolCall { - pub function: ToolCallFunction, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct ToolCallFunction { - pub name: String, - /// Native JSON object (NOT a JSON-encoded string like OpenAI) - pub arguments: serde_json::Value, -} - #[derive(Serialize)] struct OllamaChatRequest<'a> { model: &'a str, @@ -950,13 +1207,102 @@ struct OllamaChatResponse { done_reason: String, #[serde(default)] prompt_eval_count: Option, + /// Nanoseconds spent evaluating the prompt (context ingestion). + #[serde(default)] + prompt_eval_duration: Option, #[serde(default)] eval_count: Option, + /// Nanoseconds spent generating the response tokens. + #[serde(default)] + eval_duration: Option, +} + +/// One chunk in the NDJSON stream from `/api/chat` with `stream: true`. +/// Early chunks carry content deltas in `message.content`; the final chunk +/// has `done: true`, optional `tool_calls`, and usage counters. +#[derive(Deserialize, Debug)] +struct OllamaStreamChunk { + #[serde(default)] + message: OllamaStreamMessage, + #[serde(default)] + done: bool, + #[serde(default)] + prompt_eval_count: Option, + #[serde(default)] + prompt_eval_duration: Option, + #[serde(default)] + eval_count: Option, + #[serde(default)] + eval_duration: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OllamaStreamMessage { + #[serde(default)] + role: String, + #[serde(default)] + content: String, + #[serde(default)] + tool_calls: Option>, } #[derive(Deserialize)] struct OllamaResponse { response: String, + #[serde(default)] + prompt_eval_count: Option, + #[serde(default)] + prompt_eval_duration: Option, + #[serde(default)] + eval_count: Option, + #[serde(default)] + eval_duration: Option, +} + +fn log_chat_metrics( + prompt_eval_count: Option, + prompt_eval_duration_ns: Option, + eval_count: Option, + eval_duration_ns: Option, +) { + // Compute tokens/sec when both count and duration are present. + fn tokens_per_sec(count: Option, duration_ns: Option) -> Option { + match (count, duration_ns) { + (Some(c), Some(d)) if c > 0 && d > 0 => Some((c as f64) * 1_000_000_000.0 / (d as f64)), + _ => None, + } + } + let prompt_ms = prompt_eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0); + let eval_ms = eval_duration_ns.map(|ns| ns as f64 / 1_000_000.0); + let prompt_tps = tokens_per_sec(prompt_eval_count, prompt_eval_duration_ns); + let eval_tps = tokens_per_sec(eval_count, eval_duration_ns); + + let mut parts: Vec = Vec::new(); + if let Some(c) = prompt_eval_count { + let mut s = format!("prompt={} tok", c); + if let Some(ms) = prompt_ms { + s.push_str(&format!(" ({:.0} ms", ms)); + if let Some(tps) = prompt_tps { + s.push_str(&format!(", {:.1} tok/s", tps)); + } + s.push(')'); + } + parts.push(s); + } + if let Some(c) = eval_count { + let mut s = format!("gen={} tok", c); + if let Some(ms) = eval_ms { + s.push_str(&format!(" ({:.0} ms", ms)); + if let Some(tps) = eval_tps { + s.push_str(&format!(", {:.1} tok/s", tps)); + } + s.push(')'); + } + parts.push(s); + } + if !parts.is_empty() { + log::info!("Ollama chat metrics — {}", parts.join(", ")); + } } #[derive(Deserialize)] @@ -975,13 +1321,6 @@ struct OllamaShowResponse { capabilities: Vec, } -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct ModelCapabilities { - pub name: String, - pub has_vision: bool, - pub has_tool_calling: bool, -} - #[derive(Serialize)] struct OllamaBatchEmbedRequest { model: String, diff --git a/src/ai/openrouter.rs b/src/ai/openrouter.rs new file mode 100644 index 0000000..62a3cca --- /dev/null +++ b/src/ai/openrouter.rs @@ -0,0 +1,998 @@ +// First consumer lands in a later PR (hybrid backend routing). Tests exercise +// the translation helpers directly. +#![allow(dead_code)] + +use anyhow::{Context, Result, anyhow, bail}; +use async_trait::async_trait; +use reqwest::Client; +use serde::Deserialize; +use serde_json::{Value, json}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use crate::ai::llm_client::{ + ChatMessage, LlmClient, LlmStreamEvent, ModelCapabilities, Tool, ToolCall, ToolCallFunction, +}; +use futures::stream::{BoxStream, StreamExt}; + +const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1"; +const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small"; +const CACHE_DURATION_SECS: u64 = 15 * 60; + +#[derive(Clone)] +struct CachedEntry { + data: T, + cached_at: Instant, +} + +impl CachedEntry { + fn new(data: T) -> Self { + Self { + data, + cached_at: Instant::now(), + } + } + + fn is_expired(&self) -> bool { + self.cached_at.elapsed().as_secs() > CACHE_DURATION_SECS + } +} + +lazy_static::lazy_static! { + static ref MODEL_CAPABILITIES_CACHE: Arc>>>> = + Arc::new(Mutex::new(HashMap::new())); +} + +/// OpenAI-compatible client for OpenRouter (https://openrouter.ai). +/// +/// Translates canonical `ChatMessage` / `Tool` shapes to OpenAI wire format: +/// - Tool-call `arguments` serialized as JSON-encoded strings (vs Ollama's +/// native JSON). +/// - Image content rewritten into content-parts array with `image_url` entries. +/// - `role=tool` messages attach a `tool_call_id` inferred from the preceding +/// assistant turn's tool call. +#[derive(Clone)] +pub struct OpenRouterClient { + client: Client, + pub api_key: String, + pub base_url: String, + pub primary_model: String, + pub embedding_model: String, + num_ctx: Option, + temperature: Option, + top_p: Option, + top_k: Option, + min_p: Option, + /// Optional `HTTP-Referer` header OpenRouter uses for attribution. + pub referer: Option, + /// Optional `X-Title` header OpenRouter uses for attribution. + pub app_title: Option, +} + +impl OpenRouterClient { + pub fn new(api_key: String, base_url: Option, primary_model: String) -> Self { + Self { + client: Client::builder() + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(180)) + .build() + .unwrap_or_else(|_| Client::new()), + api_key, + base_url: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string()), + primary_model, + embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(), + num_ctx: None, + temperature: None, + top_p: None, + top_k: None, + min_p: None, + referer: None, + app_title: None, + } + } + + pub fn set_embedding_model(&mut self, model: String) { + self.embedding_model = model; + } + + #[allow(dead_code)] + pub fn set_num_ctx(&mut self, num_ctx: Option) { + self.num_ctx = num_ctx; + } + + #[allow(dead_code)] + pub fn set_sampling_params( + &mut self, + temperature: Option, + top_p: Option, + top_k: Option, + min_p: Option, + ) { + self.temperature = temperature; + self.top_p = top_p; + self.top_k = top_k; + self.min_p = min_p; + } + + pub fn set_attribution(&mut self, referer: Option, app_title: Option) { + self.referer = referer; + self.app_title = app_title; + } + + fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + let mut b = builder.bearer_auth(&self.api_key); + if let Some(r) = &self.referer { + b = b.header("HTTP-Referer", r); + } + if let Some(t) = &self.app_title { + b = b.header("X-Title", t); + } + b + } + + /// Translate canonical messages to the OpenAI-compatible wire shape. + /// + /// Walks in order so it can attach `tool_call_id` to `role=tool` messages + /// based on the most recent assistant turn's tool call. + fn messages_to_openai(messages: &[ChatMessage]) -> Vec { + let mut out = Vec::with_capacity(messages.len()); + let mut last_tool_call_ids: Vec = Vec::new(); + let mut next_tool_result_idx: usize = 0; + + for msg in messages { + let mut obj = serde_json::Map::new(); + obj.insert("role".into(), Value::String(msg.role.clone())); + + // Content: string OR content-parts array (when images present). + match &msg.images { + Some(images) if !images.is_empty() => { + let mut parts: Vec = Vec::new(); + if !msg.content.is_empty() { + parts.push(json!({"type": "text", "text": msg.content})); + } + for img in images { + let url = image_to_data_url(img); + parts.push(json!({ + "type": "image_url", + "image_url": { "url": url } + })); + } + obj.insert("content".into(), Value::Array(parts)); + } + _ => { + obj.insert("content".into(), Value::String(msg.content.clone())); + } + } + + // Assistant message with tool_calls: stringify arguments, remember + // the ids so the subsequent tool messages can reference them. + if let Some(tcs) = &msg.tool_calls + && msg.role == "assistant" + { + let converted: Vec = tcs + .iter() + .enumerate() + .map(|(i, call)| { + let id = call.id.clone().unwrap_or_else(|| format!("call_{}", i)); + let args_str = serde_json::to_string(&call.function.arguments) + .unwrap_or_else(|_| "{}".to_string()); + json!({ + "id": id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": args_str, + } + }) + }) + .collect(); + last_tool_call_ids = converted + .iter() + .filter_map(|v| v.get("id").and_then(|x| x.as_str()).map(String::from)) + .collect(); + next_tool_result_idx = 0; + obj.insert("tool_calls".into(), Value::Array(converted)); + } + + // Tool result messages: attach tool_call_id from the last assistant turn. + if msg.role == "tool" { + let id = last_tool_call_ids + .get(next_tool_result_idx) + .cloned() + .unwrap_or_else(|| "call_0".to_string()); + obj.insert("tool_call_id".into(), Value::String(id)); + next_tool_result_idx += 1; + } + + out.push(Value::Object(obj)); + } + + out + } + + /// Parse an OpenAI-compatible assistant message back into canonical shape. + fn openai_message_to_chat(msg: &Value) -> Result { + let obj = msg + .as_object() + .ok_or_else(|| anyhow!("response message is not an object"))?; + let role = obj + .get("role") + .and_then(|v| v.as_str()) + .unwrap_or("assistant") + .to_string(); + let content = obj + .get("content") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let tool_calls = if let Some(tcs) = obj.get("tool_calls").and_then(|v| v.as_array()) { + let mut parsed = Vec::with_capacity(tcs.len()); + for tc in tcs { + let id = tc.get("id").and_then(|v| v.as_str()).map(String::from); + let function = tc + .get("function") + .ok_or_else(|| anyhow!("tool_call missing function field"))?; + let name = function + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let args_value = match function.get("arguments") { + // OpenAI-compat: stringified JSON. + Some(Value::String(s)) => { + serde_json::from_str::(s).unwrap_or_else(|_| json!({})) + } + // Some providers emit arguments as an object directly — accept both. + Some(v @ Value::Object(_)) => v.clone(), + _ => json!({}), + }; + parsed.push(ToolCall { + id, + function: ToolCallFunction { + name, + arguments: args_value, + }, + }); + } + Some(parsed) + } else { + None + }; + + Ok(ChatMessage { + role, + content, + tool_calls, + images: None, + }) + } + + fn build_options(&self) -> Vec<(&'static str, Value)> { + let mut v = Vec::new(); + if let Some(t) = self.temperature { + v.push(("temperature", json!(t))); + } + if let Some(p) = self.top_p { + v.push(("top_p", json!(p))); + } + if let Some(k) = self.top_k { + v.push(("top_k", json!(k))); + } + if let Some(m) = self.min_p { + v.push(("min_p", json!(m))); + } + if let Some(c) = self.num_ctx { + // OpenAI uses max_tokens for generation bound; num_ctx isn't + // directly transferable. Skip rather than silently mis-map. + let _ = c; + } + v + } +} + +#[async_trait] +impl LlmClient for OpenRouterClient { + async fn generate( + &self, + prompt: &str, + system: Option<&str>, + images: Option>, + ) -> Result { + let mut messages: Vec = Vec::new(); + if let Some(sys) = system { + messages.push(ChatMessage::system(sys)); + } + let mut user = ChatMessage::user(prompt); + user.images = images; + messages.push(user); + + let (reply, _, _) = self.chat_with_tools(messages, Vec::new()).await?; + Ok(reply.content) + } + + async fn chat_with_tools( + &self, + messages: Vec, + tools: Vec, + ) -> Result<(ChatMessage, Option, Option)> { + let url = format!("{}/chat/completions", self.base_url); + let mut body = serde_json::Map::new(); + body.insert("model".into(), Value::String(self.primary_model.clone())); + body.insert( + "messages".into(), + Value::Array(Self::messages_to_openai(&messages)), + ); + body.insert("stream".into(), Value::Bool(false)); + if !tools.is_empty() { + body.insert( + "tools".into(), + serde_json::to_value(&tools).context("serializing tools")?, + ); + } + for (k, v) in self.build_options() { + body.insert(k.into(), v); + } + + log::info!( + "OpenRouter chat_with_tools: model={} messages={} tools={}", + self.primary_model, + messages.len(), + tools.len() + ); + + let resp = self + .authed(self.client.post(&url)) + .json(&Value::Object(body)) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter chat request failed: {} — {}", status, body); + } + + let parsed: Value = resp.json().await.context("parsing chat response")?; + let choice = parsed + .get("choices") + .and_then(|v| v.as_array()) + .and_then(|a| a.first()) + .ok_or_else(|| { + anyhow!( + "response missing choices[0]: {}", + extract_openrouter_error_detail(&parsed) + ) + })?; + let msg = choice.get("message").ok_or_else(|| { + anyhow!( + "choices[0] missing message: {}", + extract_openrouter_error_detail(&parsed) + ) + })?; + let chat_msg = Self::openai_message_to_chat(msg)?; + + let usage = parsed.get("usage"); + let prompt_tokens = usage + .and_then(|u| u.get("prompt_tokens")) + .and_then(|v| v.as_i64()) + .map(|n| n as i32); + let completion_tokens = usage + .and_then(|u| u.get("completion_tokens")) + .and_then(|v| v.as_i64()) + .map(|n| n as i32); + + Ok((chat_msg, prompt_tokens, completion_tokens)) + } + + async fn chat_with_tools_stream( + &self, + messages: Vec, + tools: Vec, + ) -> Result>> { + let url = format!("{}/chat/completions", self.base_url); + let mut body = serde_json::Map::new(); + body.insert("model".into(), Value::String(self.primary_model.clone())); + body.insert( + "messages".into(), + Value::Array(Self::messages_to_openai(&messages)), + ); + body.insert("stream".into(), Value::Bool(true)); + // Ask for usage data in the final chunk (OpenAI + OpenRouter + // both honor this options bag). + body.insert( + "stream_options".into(), + serde_json::json!({ "include_usage": true }), + ); + if !tools.is_empty() { + body.insert( + "tools".into(), + serde_json::to_value(&tools).context("serializing tools")?, + ); + } + for (k, v) in self.build_options() { + body.insert(k.into(), v); + } + + let resp = self + .authed(self.client.post(&url)) + .json(&Value::Object(body)) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter stream request failed: {} — {}", status, body); + } + + // OpenAI-compat SSE stream. Each event is `data: \n\n`, with + // `data: [DONE]` signalling completion. Tool calls arrive as + // `delta.tool_calls[i]` chunks that must be concatenated by index. + let byte_stream = resp.bytes_stream(); + let stream = async_stream::stream! { + let mut byte_stream = byte_stream; + let mut buf: Vec = Vec::new(); + let mut accumulated_content = String::new(); + // tool call state: index -> (id, name, args_string) + let mut tool_state: std::collections::BTreeMap< + usize, + (Option, Option, String), + > = std::collections::BTreeMap::new(); + let mut role = "assistant".to_string(); + let mut prompt_tokens: Option = None; + let mut completion_tokens: Option = None; + let mut done_seen = false; + + while let Some(chunk) = byte_stream.next().await { + let chunk = match chunk { + Ok(b) => b, + Err(e) => { + yield Err(anyhow!("stream read failed: {}", e)); + return; + } + }; + buf.extend_from_slice(&chunk); + + // SSE frames are delimited by a blank line. Walk the buffer + // for "\n\n" markers; anything before them is a complete + // frame (possibly multi-line). + while let Some(sep) = find_double_newline(&buf) { + let frame = buf.drain(..sep + 2).collect::>(); + let frame_str = match std::str::from_utf8(&frame) { + Ok(s) => s, + Err(_) => continue, + }; + // A frame is one or more lines; the payload is on data: + // lines. Ignore comments and other fields. + for line in frame_str.lines() { + let line = line.trim_end_matches('\r'); + let payload = match line.strip_prefix("data: ") { + Some(p) => p, + None => continue, + }; + if payload == "[DONE]" { + done_seen = true; + break; + } + let v: Value = match serde_json::from_str(payload) { + Ok(v) => v, + Err(e) => { + log::warn!( + "malformed OpenRouter SSE frame: {} ({})", + payload, + e + ); + continue; + } + }; + + // Usage can arrive in a dedicated final frame with + // empty choices. + if let Some(usage) = v.get("usage") { + prompt_tokens = usage + .get("prompt_tokens") + .and_then(|n| n.as_i64()) + .map(|n| n as i32); + completion_tokens = usage + .get("completion_tokens") + .and_then(|n| n.as_i64()) + .map(|n| n as i32); + } + + let Some(choices) = v.get("choices").and_then(|c| c.as_array()) + else { + continue; + }; + let Some(choice) = choices.first() else { continue }; + let delta = match choice.get("delta") { + Some(d) => d, + None => continue, + }; + if let Some(r) = delta.get("role").and_then(|v| v.as_str()) { + role = r.to_string(); + } + if let Some(content) = + delta.get("content").and_then(|v| v.as_str()) + && !content.is_empty() + { + accumulated_content.push_str(content); + yield Ok(LlmStreamEvent::TextDelta(content.to_string())); + } + if let Some(tcs) = delta.get("tool_calls").and_then(|v| v.as_array()) { + for tc_delta in tcs { + let idx = tc_delta + .get("index") + .and_then(|n| n.as_u64()) + .unwrap_or(0) as usize; + let entry = tool_state + .entry(idx) + .or_insert((None, None, String::new())); + if let Some(id) = + tc_delta.get("id").and_then(|v| v.as_str()) + { + entry.0 = Some(id.to_string()); + } + if let Some(func) = tc_delta.get("function") { + if let Some(name) = + func.get("name").and_then(|v| v.as_str()) + { + entry.1 = Some(name.to_string()); + } + if let Some(args) = + func.get("arguments").and_then(|v| v.as_str()) + { + entry.2.push_str(args); + } + } + } + } + } + if done_seen { + break; + } + } + if done_seen { + break; + } + } + + // Finalize tool calls: parse accumulated argument strings. + let tool_calls: Option> = if tool_state.is_empty() { + None + } else { + let mut v = Vec::with_capacity(tool_state.len()); + for (_idx, (id, name, args)) in tool_state { + let arguments: Value = if args.trim().is_empty() { + Value::Object(Default::default()) + } else { + serde_json::from_str(&args).unwrap_or_else(|_| { + Value::Object(Default::default()) + }) + }; + v.push(ToolCall { + id, + function: ToolCallFunction { + name: name.unwrap_or_default(), + arguments, + }, + }); + } + Some(v) + }; + + let message = ChatMessage { + role, + content: accumulated_content, + tool_calls, + images: None, + }; + yield Ok(LlmStreamEvent::Done { + message, + prompt_eval_count: prompt_tokens, + eval_count: completion_tokens, + }); + }; + + Ok(Box::pin(stream)) + } + + async fn generate_embeddings(&self, texts: &[&str]) -> Result>> { + let url = format!("{}/embeddings", self.base_url); + let body = json!({ + "model": self.embedding_model, + "input": texts, + }); + + let resp = self + .authed(self.client.post(&url)) + .json(&body) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter embedding request failed: {} — {}", status, body); + } + + #[derive(Deserialize)] + struct EmbedResponse { + data: Vec, + } + #[derive(Deserialize)] + struct EmbedItem { + embedding: Vec, + } + + let parsed: EmbedResponse = resp.json().await.context("parsing embed response")?; + Ok(parsed.data.into_iter().map(|i| i.embedding).collect()) + } + + async fn describe_image(&self, image_base64: &str) -> Result { + let prompt = "Briefly describe what you see in this image in 1-2 sentences. \ + Focus on the people, location, and activity."; + self.generate( + prompt, + Some("You are a scene description assistant. Be concise and factual."), + Some(vec![image_base64.to_string()]), + ) + .await + } + + async fn list_models(&self) -> Result> { + { + let cache = MODEL_CAPABILITIES_CACHE.lock().unwrap(); + if let Some(entry) = cache.get(&self.base_url) + && !entry.is_expired() + { + return Ok(entry.data.clone()); + } + } + + let url = format!("{}/models", self.base_url); + let resp = self + .authed(self.client.get(&url)) + .send() + .await + .with_context(|| format!("GET {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter list_models failed: {} — {}", status, body); + } + + let parsed: Value = resp.json().await.context("parsing models response")?; + let data = parsed + .get("data") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow!("models response missing data[]"))?; + + let caps: Vec = data.iter().map(parse_model_capabilities).collect(); + + { + let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap(); + cache.insert(self.base_url.clone(), CachedEntry::new(caps.clone())); + } + + Ok(caps) + } + + async fn model_capabilities(&self, model: &str) -> Result { + let all = self.list_models().await?; + all.into_iter() + .find(|m| m.name == model) + .ok_or_else(|| anyhow!("model '{}' not found on OpenRouter", model)) + } + + fn primary_model(&self) -> &str { + &self.primary_model + } +} + +/// Extract a diagnostic fragment from an OpenRouter response body that +/// doesn't match the expected `{choices: [...]}` shape. OpenRouter will +/// sometimes return 200 OK with `{"error": {"message": "...", "code": ...}}` +/// when the upstream provider (Anthropic/OpenAI/Google/etc) errored out +/// — rate limits, content moderation, model overload, provider timeout. +/// Surface the structured error if present; otherwise fall back to a +/// truncated raw-JSON view so the log line is actionable. +fn extract_openrouter_error_detail(parsed: &Value) -> String { + if let Some(err) = parsed.get("error") { + let message = err + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("(no message)"); + let code = err + .get("code") + .map(|v| match v { + Value::String(s) => s.clone(), + other => other.to_string(), + }) + .unwrap_or_else(|| "?".to_string()); + let short_message: String = message.chars().take(240).collect(); + return format!("error code={} message=\"{}\"", code, short_message); + } + let raw = parsed.to_string(); + raw.chars().take(300).collect() +} + +/// Find the byte offset of the first `\n\n` (end of an SSE frame) in `buf`. +/// Returns the index of the first `\n` of the pair, so the full separator is +/// `buf[idx..=idx+1]`. Also handles `\r\n\r\n` since some servers emit it. +fn find_double_newline(buf: &[u8]) -> Option { + for i in 0..buf.len().saturating_sub(1) { + if buf[i] == b'\n' && buf[i + 1] == b'\n' { + return Some(i); + } + // \r\n\r\n: the second \n of this pattern is at i+2; flag at i so the + // drain call (which consumes ..sep+2) takes exactly the frame. + if i + 3 < buf.len() + && buf[i] == b'\r' + && buf[i + 1] == b'\n' + && buf[i + 2] == b'\r' + && buf[i + 3] == b'\n' + { + return Some(i + 1); + } + } + None +} + +/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through. +fn image_to_data_url(img: &str) -> String { + if img.starts_with("data:") { + img.to_string() + } else { + format!("data:image/jpeg;base64,{}", img) + } +} + +fn parse_model_capabilities(m: &Value) -> ModelCapabilities { + let name = m + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let has_tool_calling = m + .get("supported_parameters") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().any(|x| x.as_str() == Some("tools"))) + .unwrap_or(false); + let has_vision = m + .get("architecture") + .and_then(|v| v.get("input_modalities")) + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().any(|x| x.as_str() == Some("image"))) + .unwrap_or(false); + ModelCapabilities { + name, + has_vision, + has_tool_calling, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_call_arguments_stringified_on_send() { + let mut msg = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: Some("call_abc".into()), + function: ToolCallFunction { + name: "search_sms".into(), + arguments: json!({"query": "hello", "limit": 5}), + }, + }]), + images: None, + }; + msg.tool_calls.as_mut().unwrap()[0].function.arguments = + json!({"query": "hello", "limit": 5}); + + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let tcs = wire[0] + .get("tool_calls") + .and_then(|v| v.as_array()) + .expect("tool_calls present"); + let args = tcs[0] + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()) + .expect("arguments stringified"); + let parsed: Value = serde_json::from_str(args).unwrap(); + assert_eq!(parsed["query"], "hello"); + assert_eq!(parsed["limit"], 5); + } + + #[test] + fn tool_call_arguments_parsed_on_receive() { + let response_msg = json!({ + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_xyz", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"city\":\"Boston\",\"units\":\"celsius\"}" + } + }] + }); + + let parsed = OpenRouterClient::openai_message_to_chat(&response_msg).unwrap(); + let tcs = parsed.tool_calls.unwrap(); + assert_eq!(tcs.len(), 1); + assert_eq!(tcs[0].function.name, "get_weather"); + assert_eq!(tcs[0].function.arguments["city"], "Boston"); + assert_eq!(tcs[0].function.arguments["units"], "celsius"); + assert_eq!(tcs[0].id.as_deref(), Some("call_xyz")); + } + + #[test] + fn tool_call_arguments_accept_native_json_on_receive() { + // Some providers return arguments as an object directly; accept both. + let response_msg = json!({ + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": { + "name": "foo", + "arguments": {"nested": {"k": 1}} + } + }] + }); + let parsed = OpenRouterClient::openai_message_to_chat(&response_msg).unwrap(); + let tc = &parsed.tool_calls.unwrap()[0]; + assert_eq!(tc.function.arguments["nested"]["k"], 1); + } + + #[test] + fn images_become_content_parts() { + let mut msg = ChatMessage::user("What is in this photo?"); + msg.images = Some(vec!["BASE64DATA".into()]); + + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let content = wire[0].get("content").and_then(|v| v.as_array()).unwrap(); + assert_eq!(content.len(), 2); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "What is in this photo?"); + assert_eq!(content[1]["type"], "image_url"); + assert_eq!( + content[1]["image_url"]["url"], + "data:image/jpeg;base64,BASE64DATA" + ); + } + + #[test] + fn data_url_images_pass_through_unchanged() { + let mut msg = ChatMessage::user(""); + msg.images = Some(vec!["data:image/png;base64,ABCDEF".into()]); + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let content = wire[0].get("content").and_then(|v| v.as_array()).unwrap(); + // No text part when content is empty. + assert_eq!(content.len(), 1); + assert_eq!( + content[0]["image_url"]["url"], + "data:image/png;base64,ABCDEF" + ); + } + + #[test] + fn text_only_message_stays_string() { + let msg = ChatMessage::user("hello"); + let wire = OpenRouterClient::messages_to_openai(&[msg]); + assert_eq!(wire[0]["content"], "hello"); + assert!(wire[0]["content"].as_str().is_some()); + } + + #[test] + fn tool_result_inherits_tool_call_id_from_prior_assistant() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: Some("call_42".into()), + function: ToolCallFunction { + name: "lookup".into(), + arguments: json!({}), + }, + }]), + images: None, + }; + let tool_result = ChatMessage::tool_result("found it"); + + let wire = OpenRouterClient::messages_to_openai(&[assistant, tool_result]); + assert_eq!(wire[1]["role"], "tool"); + assert_eq!(wire[1]["tool_call_id"], "call_42"); + } + + #[test] + fn multiple_tool_results_map_to_sequential_call_ids() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ + ToolCall { + id: Some("call_A".into()), + function: ToolCallFunction { + name: "a".into(), + arguments: json!({}), + }, + }, + ToolCall { + id: Some("call_B".into()), + function: ToolCallFunction { + name: "b".into(), + arguments: json!({}), + }, + }, + ]), + images: None, + }; + let r1 = ChatMessage::tool_result("a result"); + let r2 = ChatMessage::tool_result("b result"); + + let wire = OpenRouterClient::messages_to_openai(&[assistant, r1, r2]); + assert_eq!(wire[1]["tool_call_id"], "call_A"); + assert_eq!(wire[2]["tool_call_id"], "call_B"); + } + + #[test] + fn missing_tool_call_id_gets_synthetic_fallback() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: None, + function: ToolCallFunction { + name: "noid".into(), + arguments: json!({}), + }, + }]), + images: None, + }; + let wire = OpenRouterClient::messages_to_openai(&[assistant]); + let tcs = wire[0] + .get("tool_calls") + .and_then(|v| v.as_array()) + .unwrap(); + assert_eq!(tcs[0]["id"], "call_0"); + } + + #[test] + fn parse_model_capabilities_extracts_tools_and_vision() { + let m = json!({ + "id": "anthropic/claude-sonnet-4", + "supported_parameters": ["temperature", "top_p", "tools", "max_tokens"], + "architecture": { + "input_modalities": ["text", "image"] + } + }); + let caps = parse_model_capabilities(&m); + assert_eq!(caps.name, "anthropic/claude-sonnet-4"); + assert!(caps.has_tool_calling); + assert!(caps.has_vision); + } + + #[test] + fn parse_model_capabilities_handles_missing_fields() { + let m = json!({ + "id": "some/text-only-model" + }); + let caps = parse_model_capabilities(&m); + assert_eq!(caps.name, "some/text-only-model"); + assert!(!caps.has_tool_calling); + assert!(!caps.has_vision); + } +} diff --git a/src/ai/sms_client.rs b/src/ai/sms_client.rs index 1b6b605..ad6d28e 100644 --- a/src/ai/sms_client.rs +++ b/src/ai/sms_client.rs @@ -250,6 +250,45 @@ impl SmsApiClient { .collect()) } + /// Search message bodies via the Django side's FTS5 / semantic / hybrid + /// endpoint. `mode` selects the ranking strategy: + /// - "fts5" keyword-only, supports phrase / prefix / boolean / NEAR + /// - "semantic" embedding similarity + /// - "hybrid" both merged via reciprocal rank fusion (recommended) + pub async fn search_messages( + &self, + query: &str, + mode: &str, + limit: usize, + ) -> Result> { + let url = format!( + "{}/api/messages/search/?q={}&mode={}&limit={}", + self.base_url, + urlencoding::encode(query), + urlencoding::encode(mode), + limit + ); + + let mut request = self.client.get(&url); + if let Some(token) = &self.token { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(anyhow::anyhow!( + "SMS search request failed: {} - {}", + status, + body + )); + } + + let data: SmsSearchResponse = response.json().await?; + Ok(data.results) + } + pub async fn summarize_context( &self, messages: &[SmsMessage], @@ -260,12 +299,13 @@ impl SmsApiClient { } // Create prompt for Ollama with sender/receiver distinction + let user_name = crate::ai::user_display_name(); let messages_text: String = messages .iter() .take(60) // Limit to avoid token overflow .map(|m| { if m.is_sent { - format!("Me: {}", m.body) + format!("{}: {}", user_name, m.body) } else { format!("{}: {}", m.contact, m.body) } @@ -314,3 +354,28 @@ struct SmsApiMessage { #[serde(rename = "type")] type_: i32, } + +#[derive(Debug, Clone, Deserialize)] +pub struct SmsSearchHit { + #[allow(dead_code)] + pub message_id: i64, + pub contact_name: String, + #[allow(dead_code)] + pub contact_address: String, + pub body: String, + pub date: i64, + /// Message direction code: 1 = received, 2 = sent. + #[serde(rename = "type")] + pub type_: i32, + /// Present for semantic / hybrid modes; absent for fts5. + #[serde(default)] + pub similarity_score: Option, +} + +#[derive(Deserialize)] +struct SmsSearchResponse { + results: Vec, + #[allow(dead_code)] + #[serde(default)] + search_method: String, +} diff --git a/src/bin/populate_knowledge.rs b/src/bin/populate_knowledge.rs index bc37960..b3239ca 100644 --- a/src/bin/populate_knowledge.rs +++ b/src/bin/populate_knowledge.rs @@ -134,6 +134,7 @@ async fn main() -> anyhow::Result<()> { let generator = InsightGenerator::new( ollama, + None, sms_client, insight_dao.clone(), exif_dao, @@ -249,6 +250,9 @@ async fn main() -> anyhow::Result<()> { args.top_k, args.min_p, args.max_iterations, + None, + Vec::new(), + Vec::new(), ) .await { diff --git a/src/bin/test_daily_summary.rs b/src/bin/test_daily_summary.rs index fbbb621..d04aa32 100644 --- a/src/bin/test_daily_summary.rs +++ b/src/bin/test_daily_summary.rs @@ -1,7 +1,10 @@ use anyhow::Result; use chrono::NaiveDate; use clap::Parser; -use image_api::ai::{OllamaClient, SmsApiClient, strip_summary_boilerplate}; +use image_api::ai::{ + EMBEDDING_MODEL, OllamaClient, SmsApiClient, build_daily_summary_prompt, + strip_summary_boilerplate, user_display_name, +}; use image_api::database::{DailySummaryDao, InsertDailySummary, SqliteDailySummaryDao}; use std::env; use std::sync::{Arc, Mutex}; @@ -25,6 +28,26 @@ struct Args { #[arg(short, long)] model: Option, + /// Context window size passed as Ollama `num_ctx`. Omit for server default. + #[arg(long)] + num_ctx: Option, + + /// Sampling temperature. Omit for server default. + #[arg(long)] + temperature: Option, + + /// Top-p (nucleus) sampling. Omit for server default. + #[arg(long)] + top_p: Option, + + /// Top-k sampling. Omit for server default. + #[arg(long)] + top_k: Option, + + /// Min-p sampling. Omit for server default. + #[arg(long)] + min_p: Option, + /// Test mode: Generate but don't save to database (shows output only) #[arg(short = 't', long, default_value_t = false)] test_mode: bool, @@ -86,12 +109,28 @@ async fn main() -> Result<()> { .unwrap_or_else(|_| "nemotron-3-nano:30b".to_string()) }); - let ollama = OllamaClient::new( + let mut ollama = OllamaClient::new( ollama_primary_url, ollama_fallback_url.clone(), model_to_use.clone(), Some(model_to_use), // Use same model for fallback ); + if let Some(ctx) = args.num_ctx { + ollama.set_num_ctx(Some(ctx)); + } + if args.temperature.is_some() + || args.top_p.is_some() + || args.top_k.is_some() + || args.min_p.is_some() + { + ollama.set_sampling_params(args.temperature, args.top_p, args.top_k, args.min_p); + } + + // Surface what's actually configured so comparison runs are auditable. + println!( + "num_ctx={:?} temperature={:?} top_p={:?} top_k={:?} min_p={:?}", + args.num_ctx, args.temperature, args.top_p, args.top_k, args.min_p + ); let sms_api_url = env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string()); @@ -160,9 +199,14 @@ async fn main() -> Result<()> { println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"); if args.verbose { + let user_name = user_display_name(); println!("\nMessage preview:"); for (i, msg) in messages.iter().take(3).enumerate() { - let sender = if msg.is_sent { "Me" } else { &msg.contact }; + let sender: &str = if msg.is_sent { + &user_name + } else { + &msg.contact + }; let preview = msg.body.chars().take(60).collect::(); println!(" {}. {}: {}...", i + 1, sender, preview); } @@ -172,64 +216,11 @@ async fn main() -> Result<()> { println!(); } - // Format messages for LLM - let messages_text: String = messages - .iter() - .take(200) - .map(|m| { - if m.is_sent { - format!("Me: {}", m.body) - } else { - format!("{}: {}", m.contact, m.body) - } - }) - .collect::>() - .join("\n"); - - let prompt = format!( - r#"Summarize this day's conversation between me and {}. - -CRITICAL FORMAT RULES: -- Do NOT start with "Based on the conversation..." or "Here is a summary..." or similar preambles -- Do NOT repeat the date at the beginning -- Start DIRECTLY with the content - begin with a person's name or action -- Write in past tense, as if recording what happened - -NARRATIVE (3-5 sentences): -- What specific topics, activities, or events were discussed? -- What places, people, or organizations were mentioned? -- What plans were made or decisions discussed? -- Clearly distinguish between what "I" did versus what {} did - -KEYWORDS (comma-separated): -5-10 specific keywords that capture this conversation's unique content: -- Proper nouns (people, places, brands) -- Specific activities ("drum corps audition" not just "music") -- Distinctive terms that make this day unique - -Date: {} ({}) -Messages: -{} - -YOUR RESPONSE (follow this format EXACTLY): -Summary: [Start directly with content, NO preamble] - -Keywords: [specific, unique terms]"#, - args.contact, - args.contact, - date.format("%B %d, %Y"), - weekday, - messages_text - ); + let (prompt, system_prompt) = build_daily_summary_prompt(&args.contact, date, messages); println!("Generating summary..."); - let summary = ollama - .generate( - &prompt, - Some("You are a conversation summarizer. Create clear, factual summaries with precise subject attribution AND extract distinctive keywords. Focus on specific, unique terms that differentiate this conversation from others."), - ) - .await?; + let summary = ollama.generate(&prompt, Some(system_prompt)).await?; println!("\n📝 GENERATED SUMMARY:"); println!("─────────────────────────────────────────"); @@ -256,8 +247,7 @@ Keywords: [specific, unique terms]"#, message_count: messages.len() as i32, embedding, created_at: chrono::Utc::now().timestamp(), - // model_version: "nomic-embed-text:v1.5".to_string(), - model_version: "mxbai-embed-large:335m".to_string(), + model_version: EMBEDDING_MODEL.to_string(), }; let mut dao = summary_dao.lock().expect("Unable to lock DailySummaryDao"); diff --git a/src/database/daily_summary_dao.rs b/src/database/daily_summary_dao.rs index 6ea560a..276a5a2 100644 --- a/src/database/daily_summary_dao.rs +++ b/src/database/daily_summary_dao.rs @@ -268,7 +268,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { .into_iter() .take(limit) .map(|(similarity, summary)| { - log::info!( + log::debug!( "Summary match: similarity={:.3}, date={}, contact={}, summary=\"{}\"", similarity, summary.date, @@ -388,7 +388,7 @@ impl DailySummaryDao for SqliteDailySummaryDao { .into_iter() .take(limit) .map(|(combined, similarity, days, summary)| { - log::info!( + log::debug!( "Summary match: combined={:.3} (sim={:.3}, days={}), date={}, contact={}, summary=\"{}\"", combined, similarity, diff --git a/src/database/insights_dao.rs b/src/database/insights_dao.rs index 553b579..2821b5b 100644 --- a/src/database/insights_dao.rs +++ b/src/database/insights_dao.rs @@ -38,6 +38,16 @@ pub trait InsightDao: Sync + Send { file_path: &str, ) -> Result, DbError>; + /// Fetch a single insight by primary key, regardless of `is_current`. + /// Used by the few-shot injection flow where the caller picks specific + /// historical insights (which may have been superseded) as training + /// exemplars for a fresh generation. + fn get_insight_by_id( + &mut self, + context: &opentelemetry::Context, + insight_id: i32, + ) -> Result, DbError>; + fn delete_insight( &mut self, context: &opentelemetry::Context, @@ -60,6 +70,17 @@ pub trait InsightDao: Sync + Send { &mut self, context: &opentelemetry::Context, ) -> Result, DbError>; + + /// Replace the `training_messages` JSON blob on the current row for + /// `(library_id, rel_path)`. Used by chat-turn append mode to persist + /// the extended conversation without inserting a new insight version. + fn update_training_messages( + &mut self, + context: &opentelemetry::Context, + library_id: i32, + file_path: &str, + training_messages_json: &str, + ) -> Result<(), DbError>; } pub struct SqliteInsightDao { @@ -187,6 +208,25 @@ impl InsightDao for SqliteInsightDao { .map_err(|_| DbError::new(DbErrorKind::QueryError)) } + fn get_insight_by_id( + &mut self, + context: &opentelemetry::Context, + insight_id: i32, + ) -> Result, DbError> { + trace_db_call(context, "query", "get_insight_by_id", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + photo_insights + .find(insight_id) + .first::(connection.deref_mut()) + .optional() + .map_err(|_| anyhow::anyhow!("Query error")) + }) + .map_err(|_| DbError::new(DbErrorKind::QueryError)) + } + fn delete_insight( &mut self, context: &opentelemetry::Context, @@ -265,4 +305,30 @@ impl InsightDao for SqliteInsightDao { }) .map_err(|_| DbError::new(DbErrorKind::QueryError)) } + + fn update_training_messages( + &mut self, + context: &opentelemetry::Context, + lib_id: i32, + path: &str, + training_messages_json: &str, + ) -> Result<(), DbError> { + trace_db_call(context, "update", "update_training_messages", |_span| { + use schema::photo_insights::dsl::*; + + let mut connection = self.connection.lock().expect("Unable to get InsightDao"); + + diesel::update( + photo_insights + .filter(library_id.eq(lib_id)) + .filter(rel_path.eq(path)) + .filter(is_current.eq(true)), + ) + .set(training_messages.eq(Some(training_messages_json.to_string()))) + .execute(connection.deref_mut()) + .map(|_| ()) + .map_err(|_| anyhow::anyhow!("Update error")) + }) + .map_err(|_| DbError::new(DbErrorKind::UpdateError)) + } } diff --git a/src/database/models.rs b/src/database/models.rs index d95876b..96d9c53 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -100,6 +100,14 @@ pub struct InsertPhotoInsight { pub model_version: String, pub is_current: bool, pub training_messages: Option, + /// `"local"` (Ollama with images) | `"hybrid"` (local vision + OpenRouter chat). + pub backend: String, + /// JSON array of insight ids whose `training_messages` were compressed + /// and injected into the system prompt as few-shot exemplars when this + /// row was generated. `None` means no few-shot was used (pristine + /// generation). Used downstream to filter out contaminated rows when + /// assembling an unbiased training / evaluation set. + pub fewshot_source_ids: Option, } #[derive(Serialize, Queryable, Clone, Debug)] @@ -115,6 +123,9 @@ pub struct PhotoInsight { pub is_current: bool, pub training_messages: Option, pub approved: Option, + /// `"local"` (Ollama with images) | `"hybrid"` (local vision + OpenRouter chat). + pub backend: String, + pub fewshot_source_ids: Option, } // --- Libraries --- diff --git a/src/database/schema.rs b/src/database/schema.rs index 3352ca6..e49f21f 100644 --- a/src/database/schema.rs +++ b/src/database/schema.rs @@ -142,6 +142,8 @@ diesel::table! { is_current -> Bool, training_messages -> Nullable, approved -> Nullable, + backend -> Text, + fewshot_source_ids -> Nullable, } } diff --git a/src/exif.rs b/src/exif.rs index c096f71..0cd29d9 100644 --- a/src/exif.rs +++ b/src/exif.rs @@ -1,5 +1,5 @@ use std::fs::File; -use std::io::BufReader; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::path::Path; use anyhow::{Result, anyhow}; @@ -25,6 +25,60 @@ pub struct ExifData { pub date_taken: Option, } +/// TIFF-based RAW formats where `JPEGInterchangeFormat` offsets are +/// absolute file offsets (the file itself is a TIFF container). +fn is_tiff_raw(path: &Path) -> bool { + matches!( + path.extension() + .and_then(|e| e.to_str()) + .map(|s| s.to_lowercase()) + .as_deref(), + Some( + "tiff" | "tif" | "nef" | "cr2" | "arw" | "dng" | "raf" | "orf" | "rw2" | "pef" | "srw" + ) + ) +} + +/// Returns the bytes of the embedded JPEG thumbnail in a TIFF-based RAW or +/// TIFF file. Used to thumbnail formats whose RAW pixel data can't be decoded +/// by our normal tools (e.g. Sony ARW). Returns `None` if no preview is +/// present, the file isn't a TIFF container, or the data doesn't look like +/// a valid JPEG. +pub fn extract_embedded_jpeg_preview(path: &Path) -> Option> { + if !is_tiff_raw(path) { + return None; + } + + let file = File::open(path).ok()?; + let mut bufreader = BufReader::new(file); + let exif = Reader::new().read_from_container(&mut bufreader).ok()?; + + let offset = exif + .get_field(Tag::JPEGInterchangeFormat, In::THUMBNAIL)? + .value + .get_uint(0)?; + let length = exif + .get_field(Tag::JPEGInterchangeFormatLength, In::THUMBNAIL)? + .value + .get_uint(0)?; + if length == 0 { + return None; + } + + let mut file = File::open(path).ok()?; + file.seek(SeekFrom::Start(offset as u64)).ok()?; + let mut buf = vec![0u8; length as usize]; + file.read_exact(&mut buf).ok()?; + + // JPEG SOI marker sanity check — MakerNote offsets sometimes point at + // TIFF-wrapped previews or other non-JPEG data. + if buf.len() < 2 || buf[0] != 0xFF || buf[1] != 0xD8 { + return None; + } + + Some(buf) +} + pub fn supports_exif(path: &Path) -> bool { if let Some(ext) = path.extension() { let ext_lower = ext.to_string_lossy().to_lowercase(); diff --git a/src/file_types.rs b/src/file_types.rs index c1249d0..f312916 100644 --- a/src/file_types.rs +++ b/src/file_types.rs @@ -3,9 +3,22 @@ use walkdir::DirEntry; /// Supported image file extensions pub const IMAGE_EXTENSIONS: &[&str] = &[ - "jpg", "jpeg", "png", "webp", "tiff", "tif", "heif", "heic", "avif", "nef", + "jpg", "jpeg", "png", "webp", "tiff", "tif", "heif", "heic", "avif", "nef", "arw", ]; +/// Extensions the `image` crate cannot decode — we fall back to ffmpeg to +/// extract an embedded preview or decode the frame. +pub const FFMPEG_THUMBNAIL_EXTENSIONS: &[&str] = &["heif", "heic", "nef", "arw"]; + +/// Returns true if thumbnail generation should go through ffmpeg instead of +/// the `image` crate (RAW formats, HEIF/HEIC). +pub fn needs_ffmpeg_thumbnail(path: &Path) -> bool { + match path.extension().and_then(|e| e.to_str()) { + Some(ext) => FFMPEG_THUMBNAIL_EXTENSIONS.contains(&ext.to_lowercase().as_str()), + None => false, + } +} + /// Supported video file extensions pub const VIDEO_EXTENSIONS: &[&str] = &["mp4", "mov", "avi", "mkv"]; diff --git a/src/files.rs b/src/files.rs index 561414c..75343b1 100644 --- a/src/files.rs +++ b/src/files.rs @@ -15,6 +15,7 @@ use crate::database::ExifDao; use crate::file_types; use crate::geo::{gps_bounding_box, haversine_distance}; use crate::memories::extract_date_from_filename; +use crate::utils::earliest_fs_time; use crate::{AppState, create_thumbnails}; use actix_web::web::Data; use actix_web::{ @@ -138,8 +139,8 @@ fn in_memory_date_sort( lib_roots.get(&lib_id).and_then(|root| { let full_path = Path::new(root).join(&f.file_name); std::fs::metadata(full_path) - .and_then(|md| md.created().or(md.modified())) .ok() + .and_then(|md| earliest_fs_time(&md)) .map(|system_time| { >>::into(system_time).timestamp() }) diff --git a/src/main.rs b/src/main.rs index 570cf58..db63ef0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,8 @@ use crate::state::AppState; use crate::tags::*; use crate::video::actors::{ GeneratePreviewClipMessage, ProcessMessage, QueueVideosMessage, ScanDirectoryMessage, - VideoPlaylistManager, create_playlist, generate_video_thumbnail, + VideoPlaylistManager, create_playlist, generate_image_thumbnail_ffmpeg, + generate_video_thumbnail, }; use log::{debug, error, info, trace, warn}; use opentelemetry::trace::{Span, Status, TraceContextExt, Tracer}; @@ -1060,6 +1061,47 @@ async fn delete_favorite( } } +/// Sentinel path written next to a would-be thumbnail when a file cannot be +/// decoded by either the `image` crate or ffmpeg. Its presence causes future +/// scans to skip the file instead of re-logging the failure. +pub fn unsupported_thumbnail_sentinel(thumb_path: &Path) -> PathBuf { + let mut s = thumb_path.as_os_str().to_owned(); + s.push(".unsupported"); + PathBuf::from(s) +} + +fn generate_image_thumbnail(src: &Path, thumb_path: &Path) -> std::io::Result<()> { + // RAW formats (ARW/NEF/CR2/etc): try the file's embedded JPEG preview + // first. Avoids ffmpeg choking on proprietary RAW compression (Sony ARW + // in particular), and is faster than decoding RAW pixels anyway. + if let Some(preview) = exif::extract_embedded_jpeg_preview(src) { + let img = image::load_from_memory(&preview).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("decode embedded preview {:?}: {}", src, e), + ) + })?; + let scaled = img.thumbnail(200, u32::MAX); + scaled + .save_with_format(thumb_path, image::ImageFormat::Jpeg) + .map_err(|e| std::io::Error::other(format!("save {:?}: {}", thumb_path, e)))?; + return Ok(()); + } + + if file_types::needs_ffmpeg_thumbnail(src) { + return generate_image_thumbnail_ffmpeg(src, thumb_path); + } + + let img = image::open(src).map_err(|e| { + std::io::Error::new(std::io::ErrorKind::InvalidData, format!("{:?}: {}", src, e)) + })?; + let scaled = img.thumbnail(200, u32::MAX); + scaled + .save(thumb_path) + .map_err(|e| std::io::Error::other(format!("save {:?}: {}", thumb_path, e)))?; + Ok(()) +} + fn create_thumbnails(libs: &[libraries::Library]) { let tracer = global_tracer(); let span = tracer.start("creating thumbnails"); @@ -1080,17 +1122,26 @@ fn create_thumbnails(libs: &[libraries::Library]) { .into_par_iter() .filter_map(|entry| entry.ok()) .filter(|entry| entry.file_type().is_file()) - .filter(|entry| { - if is_video(entry) { - let relative_path = &entry.path().strip_prefix(&images).unwrap(); - let thumb_path = Path::new(thumbnail_directory).join(relative_path); - std::fs::create_dir_all( - thumb_path - .parent() - .unwrap_or_else(|| panic!("Thumbnail {:?} has no parent?", thumb_path)), - ) - .expect("Error creating directory"); + .for_each(|entry| { + let src = entry.path(); + let Ok(relative_path) = src.strip_prefix(&images) else { + return; + }; + let thumb_path = Path::new(thumbnail_directory).join(relative_path); + if thumb_path.exists() || unsupported_thumbnail_sentinel(&thumb_path).exists() { + return; + } + + let Some(parent) = thumb_path.parent() else { + return; + }; + if let Err(e) = std::fs::create_dir_all(parent) { + error!("Failed to create thumbnail dir {:?}: {}", parent, e); + return; + } + + if is_video(&entry) { let mut video_span = tracer.start_with_context( "generate_video_thumbnail", &opentelemetry::Context::new() @@ -1103,37 +1154,24 @@ fn create_thumbnails(libs: &[libraries::Library]) { ]); debug!("Generating video thumbnail: {:?}", thumb_path); - generate_video_thumbnail(entry.path(), &thumb_path); + generate_video_thumbnail(src, &thumb_path); video_span.end(); - false - } else { - is_image(entry) + } else if is_image(&entry) { + match generate_image_thumbnail(src, &thumb_path) { + Ok(_) => info!("Saved thumbnail: {:?}", thumb_path), + Err(e) => { + let sentinel = unsupported_thumbnail_sentinel(&thumb_path); + error!( + "Unable to thumbnail {:?}: {}. Writing sentinel {:?}", + src, e, sentinel + ); + if let Err(se) = std::fs::write(&sentinel, b"") { + warn!("Failed to write sentinel {:?}: {}", sentinel, se); + } + } + } } - }) - .filter(|entry| { - let path = entry.path(); - let relative_path = &path.strip_prefix(&images).unwrap(); - let thumb_path = Path::new(thumbnail_directory).join(relative_path); - !thumb_path.exists() - }) - .map(|entry| (image::open(entry.path()), entry.path().to_path_buf())) - .filter(|(img, path)| { - if let Err(e) = img { - error!("Unable to open image: {:?}. {}", path, e); - } - img.is_ok() - }) - .map(|(img, path)| (img.unwrap(), path)) - .map(|(image, path)| (image.thumbnail(200, u32::MAX), path)) - .map(|(image, path)| { - let relative_path = &path.strip_prefix(&images).unwrap(); - let thumb_path = Path::new(thumbnail_directory).join(relative_path); - std::fs::create_dir_all(thumb_path.parent().unwrap()) - .expect("There was an issue creating directory"); - info!("Saving thumbnail: {:?}", thumb_path); - image.save(thumb_path).expect("Failure saving thumbnail"); - }) - .for_each(drop); + }); } debug!("Finished making thumbnails"); @@ -1355,6 +1393,11 @@ fn main() -> std::io::Result<()> { .service(ai::delete_insight_handler) .service(ai::get_all_insights_handler) .service(ai::get_available_models_handler) + .service(ai::get_openrouter_models_handler) + .service(ai::chat_turn_handler) + .service(ai::chat_stream_handler) + .service(ai::chat_history_handler) + .service(ai::chat_rewind_handler) .service(ai::rate_insight_handler) .service(ai::export_training_data_handler) .service(libraries::list_libraries) @@ -1739,7 +1782,8 @@ fn process_new_files( // not just photos with parseable EXIF. for (file_path, relative_path) in &files { let thumb_path = thumbnail_directory.join(relative_path); - let needs_thumbnail = !thumb_path.exists(); + let needs_thumbnail = + !thumb_path.exists() && !unsupported_thumbnail_sentinel(&thumb_path).exists(); let needs_row = !existing_exif_paths.contains_key(relative_path); if needs_thumbnail || needs_row { diff --git a/src/memories.rs b/src/memories.rs index 875a72c..0e2aad5 100644 --- a/src/memories.rs +++ b/src/memories.rs @@ -19,6 +19,7 @@ use crate::files::is_image_or_video; use crate::libraries::Library; use crate::otel::{extract_context_from_request, global_tracer}; use crate::state::AppState; +use crate::utils::earliest_fs_time; // Helper that encapsulates path-exclusion semantics #[derive(Debug)] @@ -336,8 +337,8 @@ fn get_memory_date_with_priority( return Some((date, Some(exif_timestamp), modified)); } - // Priority 3: Fall back to metadata - let system_time = meta.created().ok().or_else(|| meta.modified().ok())?; + // Priority 3: Fall back to metadata (earlier of created/modified — see utils::earliest_fs_time) + let system_time = earliest_fs_time(&meta)?; let dt_utc: DateTime = system_time.into(); let date_in_timezone = if let Some(tz) = client_timezone { diff --git a/src/state.rs b/src/state.rs index 78b98ad..8e13d28 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,3 +1,5 @@ +use crate::ai::insight_chat::{ChatLockMap, InsightChatService}; +use crate::ai::openrouter::OpenRouterClient; use crate::ai::{InsightGenerator, OllamaClient, SmsApiClient}; use crate::database::{ CalendarEventDao, DailySummaryDao, ExifDao, InsightDao, KnowledgeDao, LocationHistoryDao, @@ -31,8 +33,20 @@ pub struct AppState { pub preview_clips_path: String, pub excluded_dirs: Vec, pub ollama: OllamaClient, + /// `None` when `OPENROUTER_API_KEY` is not configured. Consulted only + /// when a request explicitly opts into `backend=hybrid`. Currently + /// reached via `insight_generator`; kept here so future handlers + /// (insight_chat) can route to it without threading it through the + /// generator. + #[allow(dead_code)] + pub openrouter: Option>, + /// Curated list of OpenRouter model ids exposed to clients. Sourced from + /// `OPENROUTER_ALLOWED_MODELS` (comma-separated). Empty when unset. + pub openrouter_allowed_models: Vec, pub sms_client: SmsApiClient, pub insight_generator: InsightGenerator, + /// Chat continuation service. Hold an Arc so handlers can clone cheaply. + pub insight_chat: Arc, } impl AppState { @@ -61,8 +75,11 @@ impl AppState { preview_clips_path: String, excluded_dirs: Vec, ollama: OllamaClient, + openrouter: Option>, + openrouter_allowed_models: Vec, sms_client: SmsApiClient, insight_generator: InsightGenerator, + insight_chat: Arc, preview_dao: Arc>>, ) -> Self { assert!( @@ -92,8 +109,11 @@ impl AppState { preview_clips_path, excluded_dirs, ollama, + openrouter, + openrouter_allowed_models, sms_client, insight_generator, + insight_chat, } } @@ -127,6 +147,9 @@ impl Default for AppState { ollama_fallback_model, ); + let openrouter = build_openrouter_from_env(); + let openrouter_allowed_models = parse_openrouter_allowed_models(); + let sms_api_url = env::var("SMS_API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string()); let sms_api_token = env::var("SMS_API_TOKEN").ok(); @@ -168,6 +191,7 @@ impl Default for AppState { // Initialize InsightGenerator with all data sources let insight_generator = InsightGenerator::new( ollama.clone(), + openrouter.clone(), sms_client.clone(), insight_dao.clone(), exif_dao.clone(), @@ -180,6 +204,18 @@ impl Default for AppState { libraries_vec.clone(), ); + // Chat continuation reuses the generator for tool dispatch + image + // loading. The lock map starts empty and grows lazily per file. + let chat_locks: ChatLockMap = + Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); + let insight_chat = Arc::new(InsightChatService::new( + Arc::new(insight_generator.clone()), + ollama.clone(), + openrouter.clone(), + insight_dao.clone(), + chat_locks, + )); + // Ensure preview clips directory exists let preview_clips_path = env::var("PREVIEW_CLIPS_DIRECTORY").unwrap_or_else(|_| "preview_clips".to_string()); @@ -195,13 +231,47 @@ impl Default for AppState { preview_clips_path, Self::parse_excluded_dirs(), ollama, + openrouter, + openrouter_allowed_models, sms_client, insight_generator, + insight_chat, preview_dao, ) } } +/// Build an `OpenRouterClient` from environment variables. Returns `None` +/// when `OPENROUTER_API_KEY` is unset (the hybrid backend is then +/// unavailable and requests for it return a clear error). +fn build_openrouter_from_env() -> Option> { + let api_key = env::var("OPENROUTER_API_KEY").ok()?; + let base_url = env::var("OPENROUTER_BASE_URL").ok(); + let default_model = env::var("OPENROUTER_DEFAULT_MODEL") + .unwrap_or_else(|_| "anthropic/claude-sonnet-4".to_string()); + let mut client = OpenRouterClient::new(api_key, base_url, default_model); + client.set_attribution( + env::var("OPENROUTER_HTTP_REFERER").ok(), + env::var("OPENROUTER_APP_TITLE").ok(), + ); + if let Ok(model) = env::var("OPENROUTER_EMBEDDING_MODEL") { + client.set_embedding_model(model); + } + Some(Arc::new(client)) +} + +/// Parse `OPENROUTER_ALLOWED_MODELS` (comma-separated) into a vec. Returns +/// empty when unset, in which case `/insights/openrouter/models` reports no +/// curated picks and the server falls back to `OPENROUTER_DEFAULT_MODEL`. +fn parse_openrouter_allowed_models() -> Vec { + env::var("OPENROUTER_ALLOWED_MODELS") + .unwrap_or_default() + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + #[cfg(test)] impl AppState { /// Creates an AppState instance for testing with temporary directories @@ -255,6 +325,7 @@ impl AppState { }; let insight_generator = InsightGenerator::new( ollama.clone(), + None, sms_client.clone(), insight_dao.clone(), exif_dao.clone(), @@ -267,6 +338,16 @@ impl AppState { vec![test_lib], ); + let chat_locks: ChatLockMap = + Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); + let insight_chat = Arc::new(InsightChatService::new( + Arc::new(insight_generator.clone()), + ollama.clone(), + None, + insight_dao.clone(), + chat_locks, + )); + // Initialize test preview DAO let preview_dao: Arc>> = Arc::new(Mutex::new(Box::new(SqlitePreviewDao::new()))); @@ -286,8 +367,11 @@ impl AppState { preview_clips_path.to_string_lossy().to_string(), Vec::new(), // No excluded directories for test state ollama, + None, + Vec::new(), sms_client, insight_generator, + insight_chat, preview_dao, ) } diff --git a/src/utils.rs b/src/utils.rs index 1779c15..fdfef9b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,5 @@ +use std::time::SystemTime; + /// Normalize a file path to use forward slashes for cross-platform consistency /// This ensures paths stored in the database always use `/` regardless of OS /// @@ -12,6 +14,20 @@ pub fn normalize_path(path: &str) -> String { path.replace('\\', "/") } +/// Pick the earlier of a file's created and modified timestamps. +/// +/// On copied/restored files (e.g., a backup library), `created` is stamped at +/// copy time while `modified` is preserved from the source — so the earlier +/// of the two is a better proxy for when the content originated. Falls back +/// to whichever timestamp is available if one platform lacks the other. +pub fn earliest_fs_time(md: &std::fs::Metadata) -> Option { + match (md.created().ok(), md.modified().ok()) { + (Some(c), Some(m)) => Some(c.min(m)), + (Some(t), None) | (None, Some(t)) => Some(t), + (None, None) => None, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/video/actors.rs b/src/video/actors.rs index 284c8e3..c7300ef 100644 --- a/src/video/actors.rs +++ b/src/video/actors.rs @@ -4,7 +4,6 @@ use crate::libraries::Library; use crate::otel::global_tracer; use crate::video::ffmpeg::generate_preview_clip; use actix::prelude::*; -use futures::TryFutureExt; use log::{debug, error, info, trace, warn}; use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, Tracer}; @@ -48,6 +47,24 @@ impl Handler for StreamActor { } } +pub fn playlist_file_for(playlist_dir: &str, video_path: &Path) -> PathBuf { + let filename = video_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + PathBuf::from(format!("{}/{}.m3u8", playlist_dir, filename)) +} + +/// Sentinel path written next to a would-be playlist when ffmpeg cannot +/// transcode the source (e.g. truncated mp4 with no moov atom). Its presence +/// causes future scans to skip the file instead of re-running ffmpeg every +/// pass. Delete the `.unsupported` file to force a retry. +pub fn playlist_unsupported_sentinel(playlist_file: &Path) -> PathBuf { + let mut s = playlist_file.as_os_str().to_owned(); + s.push(".unsupported"); + PathBuf::from(s) +} + pub async fn create_playlist(video_path: &str, playlist_file: &str) -> Result { if Path::new(playlist_file).exists() { debug!("Playlist already exists: {}", playlist_file); @@ -66,9 +83,11 @@ pub async fn create_playlist(video_path: &str, playlist_file: &str) -> Result bool { - let output = tokio::process::Command::new("ffprobe") - .arg("-v") - .arg("error") - .arg("-select_streams") - .arg("v:0") - .arg("-show_entries") - .arg("stream=codec_name") - .arg("-of") - .arg("default=noprint_wrappers=1:nokey=1") - .arg(video_path) - .output() - .await; +/// Use ffmpeg to extract a 200px-wide thumbnail from formats the `image` crate +/// can't decode (RAW: NEF/ARW, HEIC/HEIF). Writes JPEG bytes to `destination` +/// regardless of its extension. +pub fn generate_image_thumbnail_ffmpeg(path: &Path, destination: &Path) -> std::io::Result<()> { + let output = Command::new("ffmpeg") + .arg("-y") + .arg("-i") + .arg(path) + .arg("-vframes") + .arg("1") + .arg("-vf") + .arg("scale=200:-1") + .arg("-f") + .arg("image2") + .arg("-c:v") + .arg("mjpeg") + .arg(destination) + .output()?; - match output { - Ok(output) if output.status.success() => { - let codec = String::from_utf8_lossy(&output.stdout); - let codec = codec.trim(); - debug!("Detected codec for {}: {}", video_path, codec); - codec == "h264" - } - Ok(output) => { - warn!( - "ffprobe failed for {}: {}", - video_path, - String::from_utf8_lossy(&output.stderr) - ); - false - } - Err(e) => { - warn!("Failed to run ffprobe for {}: {}", video_path, e); - false - } + if !output.status.success() { + return Err(std::io::Error::other(format!( + "ffmpeg failed ({}): {}", + output.status, + String::from_utf8_lossy(&output.stderr).trim() + ))); } + Ok(()) } -/// Check if a video has rotation metadata -/// Returns the rotation angle in degrees (0, 90, 180, 270) or 0 if none detected -/// Checks both legacy stream tags and modern display matrix side data -async fn get_video_rotation(video_path: &str) -> i32 { - // Check legacy rotate stream tag (older videos) +/// Video stream metadata needed to pick HLS encode settings. Populated by +/// a single ffprobe call to avoid spawning multiple subprocesses per video. +#[derive(Debug, Default)] +struct VideoStreamMeta { + is_h264: bool, + /// Rotation in degrees (0/90/180/270). Checks both the legacy `rotate` + /// stream tag and the modern display-matrix side data. + rotation: i32, +} + +/// Probe video stream metadata in one ffprobe call. Returns default (codec +/// unknown, rotation 0) on any failure — callers fall back to transcoding. +async fn probe_video_stream_meta(video_path: &str) -> VideoStreamMeta { let output = tokio::process::Command::new("ffprobe") .arg("-v") .arg("error") .arg("-select_streams") .arg("v:0") + .arg("-print_format") + .arg("json") .arg("-show_entries") - .arg("stream_tags=rotate") - .arg("-of") - .arg("default=noprint_wrappers=1:nokey=1") + .arg("stream=codec_name:stream_tags=rotate:side_data_list") .arg(video_path) .output() .await; - if let Ok(output) = output - && output.status.success() - { - let rotation_str = String::from_utf8_lossy(&output.stdout); - let rotation_str = rotation_str.trim(); - if !rotation_str.is_empty() - && let Ok(rotation) = rotation_str.parse::() - && rotation != 0 - { - debug!( - "Detected rotation {}° from stream tag for {}", - rotation, video_path - ); - return rotation; - } + let Ok(output) = output else { + warn!("Failed to run ffprobe for {}", video_path); + return VideoStreamMeta::default(); + }; + if !output.status.success() { + warn!( + "ffprobe failed for {}: {}", + video_path, + String::from_utf8_lossy(&output.stderr).trim() + ); + return VideoStreamMeta::default(); } - // Check display matrix side data (modern videos, e.g. iPhone) + let Ok(json) = serde_json::from_slice::(&output.stdout) else { + warn!("ffprobe returned non-JSON for {}", video_path); + return VideoStreamMeta::default(); + }; + + let stream = &json["streams"][0]; + let is_h264 = stream + .get("codec_name") + .and_then(|v| v.as_str()) + .map(|s| s == "h264") + .unwrap_or(false); + + // Prefer legacy `tags.rotate` (older containers); fall back to the + // display-matrix side data (iPhone and other modern recorders). + let rotation = stream + .get("tags") + .and_then(|t| t.get("rotate")) + .and_then(|r| r.as_str()) + .and_then(|s| s.parse::().ok()) + .filter(|r| *r != 0) + .or_else(|| { + stream + .get("side_data_list") + .and_then(|l| l.as_array()) + .and_then(|arr| { + arr.iter() + .find_map(|sd| sd.get("rotation").and_then(|r| r.as_f64())) + }) + .map(|f| f.abs() as i32) + .filter(|r| *r != 0) + }) + .unwrap_or(0); + + debug!( + "Probed {}: codec_h264={}, rotation={}°", + video_path, is_h264, rotation + ); + + VideoStreamMeta { is_h264, rotation } +} + +/// Probe the max keyframe interval (GOP) in the first ~30s of a video. +/// Returns `None` on probe failure or if we couldn't see at least two keyframes. +/// +/// Used to decide between stream-copy and transcode: HLS needs segments to +/// start on keyframes, so if the source GOP exceeds `hls_time`, copying +/// produces oversized/glitchy segments and we need to re-encode. +async fn get_max_gop_seconds(video_path: &str) -> Option { let output = tokio::process::Command::new("ffprobe") .arg("-v") .arg("error") .arg("-select_streams") .arg("v:0") + .arg("-skip_frame") + .arg("nokey") .arg("-show_entries") - .arg("side_data=rotation") + .arg("frame=pts_time") .arg("-of") - .arg("default=noprint_wrappers=1:nokey=1") + .arg("csv=p=0") + .arg("-read_intervals") + .arg("%+30") .arg(video_path) .output() - .await; + .await + .ok()?; - if let Ok(output) = output - && output.status.success() - { - let rotation_str = String::from_utf8_lossy(&output.stdout); - let rotation_str = rotation_str.trim(); - if !rotation_str.is_empty() - && let Ok(rotation) = rotation_str.parse::() - { - let rotation = rotation.abs() as i32; - if rotation != 0 { - debug!( - "Detected rotation {}° from display matrix for {}", - rotation, video_path - ); - return rotation; - } - } + if !output.status.success() { + warn!( + "ffprobe GOP check failed for {}: {}", + video_path, + String::from_utf8_lossy(&output.stderr).trim() + ); + return None; } - 0 + let times: Vec = String::from_utf8_lossy(&output.stdout) + .lines() + .filter_map(|l| l.trim().parse::().ok()) + .collect(); + + if times.len() < 2 { + return None; + } + + let max_gop = times + .windows(2) + .map(|w| w[1] - w[0]) + .fold(0.0_f64, f64::max); + debug!( + "Max GOP in first {} keyframes of {}: {:.2}s", + times.len(), + video_path, + max_gop + ); + Some(max_gop) } pub struct VideoPlaylistManager { @@ -246,15 +321,21 @@ impl Handler for VideoPlaylistManager { msg.directory ); + let playlist_output_dir = self.playlist_dir.clone(); + let playlist_dir_str = playlist_output_dir.to_str().unwrap().to_string(); + let video_files = WalkDir::new(&msg.directory) .into_iter() .filter_map(|e| e.ok()) .filter(|e| e.file_type().is_file()) .filter(is_video) + .filter(|e| { + let playlist = playlist_file_for(&playlist_dir_str, e.path()); + !playlist.exists() && !playlist_unsupported_sentinel(&playlist).exists() + }) .collect::>(); let scan_dir_name = msg.directory.clone(); - let playlist_output_dir = self.playlist_dir.clone(); let playlist_generator = self.playlist_generator.clone(); Box::pin(async move { @@ -285,6 +366,9 @@ impl Handler for VideoPlaylistManager { path_as_str ); } + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + debug!("Playlist already exists for '{:?}', skipping", path); + } Err(e) => { warn!("Failed to generate playlist for path '{:?}'. {:?}", path, e); } @@ -318,14 +402,19 @@ impl Handler for VideoPlaylistManager { ); let playlist_output_dir = self.playlist_dir.clone(); + let playlist_dir_str = playlist_output_dir.to_str().unwrap().to_string(); let playlist_generator = self.playlist_generator.clone(); for video_path in msg.video_paths { + let playlist = playlist_file_for(&playlist_dir_str, &video_path); + if playlist.exists() || playlist_unsupported_sentinel(&playlist).exists() { + continue; + } let path_str = video_path.to_string_lossy().to_string(); debug!("Queueing playlist generation for: {}", path_str); playlist_generator.do_send(GeneratePlaylistMessage { - playlist_path: playlist_output_dir.to_str().unwrap().to_string(), + playlist_path: playlist_dir_str.clone(), video_path, }); } @@ -357,8 +446,17 @@ pub struct PlaylistGenerator { impl PlaylistGenerator { pub(crate) fn new() -> Self { + // Concurrency is tunable via HLS_CONCURRENCY so operators can dial + // it to their hardware: 1 on weak Synology boxes to avoid thermal + // throttling, higher on desktops with spare cores. + let concurrency = std::env::var("HLS_CONCURRENCY") + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|&n| n > 0) + .unwrap_or(2); + info!("PlaylistGenerator: concurrency={}", concurrency); PlaylistGenerator { - semaphore: Arc::new(Semaphore::new(2)), + semaphore: Arc::new(Semaphore::new(concurrency)), } } } @@ -418,14 +516,42 @@ impl Handler for PlaylistGenerator { return Err(std::io::Error::from(std::io::ErrorKind::AlreadyExists)); } - // Check if video is already h264 encoded - let is_h264 = is_h264_encoded(&video_file).await; - - // Check for rotation metadata - let rotation = get_video_rotation(&video_file).await; + // One ffprobe call for codec + rotation metadata. + let stream_meta = probe_video_stream_meta(&video_file).await; + let is_h264 = stream_meta.is_h264; + let rotation = stream_meta.rotation; let has_rotation = rotation != 0; - let use_copy = is_h264 && !has_rotation; + // Stream-copy is only safe when the source GOP fits inside a + // single HLS segment. Otherwise ffmpeg has to extend segments + // past hls_time to land on a keyframe, producing uneven + // segments and seeking glitches. + const HLS_SEGMENT_SECONDS: f64 = 3.0; + let gop_ok = if is_h264 && !has_rotation { + match get_max_gop_seconds(&video_file).await { + Some(g) if g > HLS_SEGMENT_SECONDS => { + info!( + "Video {} has long GOP ({:.1}s > {}s), transcoding for segment alignment", + video_file, g, HLS_SEGMENT_SECONDS + ); + false + } + Some(_) => true, + None => { + // Probe failed — be conservative and transcode rather + // than risk broken segments from a mystery source. + debug!( + "GOP probe failed for {}, transcoding to be safe", + video_file + ); + false + } + } + } else { + false + }; + + let use_copy = is_h264 && !has_rotation && gop_ok; if has_rotation { info!( @@ -439,59 +565,182 @@ impl Handler for PlaylistGenerator { } else if use_copy { info!("Video {} is already h264, using stream copy", video_file); span.add_event("Using stream copy (h264 detected)", vec![]); + } else if is_h264 { + info!( + "Video {} is h264 but needs transcoding for GOP alignment", + video_file + ); + span.add_event("Transcoding for GOP alignment", vec![]); } else { info!("Video {} needs transcoding to h264", video_file); span.add_event("Transcoding to h264", vec![]); } - tokio::spawn(async move { - let mut cmd = tokio::process::Command::new("ffmpeg"); - cmd.arg("-i").arg(&video_file); + // Encode to a .tmp playlist and explicit segment names so a failed + // encode leaves predictable artifacts we can clean up — and so a + // concurrent scan doesn't see a half-written .m3u8 as "done". + let playlist_tmp = format!("{}.tmp", playlist_file); + let video_stem = msg + .video_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("video"); + let segment_pattern = format!("{}/{}_%03d.ts", playlist_path, video_stem); - if use_copy { - // Video is already h264, just copy the stream - // Note: rotation metadata will be preserved in the stream - cmd.arg("-c:v").arg("copy"); - cmd.arg("-c:a").arg("aac"); // Still need to ensure audio is compatible + let mut cmd = tokio::process::Command::new("ffmpeg"); + cmd.arg("-y").arg("-i").arg(&video_file); + + if use_copy { + cmd.arg("-c:v").arg("copy"); + cmd.arg("-c:a").arg("aac"); + } else { + let nvenc = crate::video::ffmpeg::is_nvenc_available().await; + if nvenc { + // NVENC: no CRF, use VBR + target CQ. p1 = fastest + // preset — prioritizes encoder throughput over bitrate + // efficiency. CQ 23 roughly matches libx264 crf 21 + // visually; NVENC has slightly lower compression + // efficiency per quality. + cmd.arg("-c:v").arg("h264_nvenc"); + cmd.arg("-preset").arg("p1"); + cmd.arg("-rc").arg("vbr"); + cmd.arg("-cq").arg("23"); + cmd.arg("-pix_fmt").arg("yuv420p"); } else { - // Need to transcode - autorotate is enabled by default and will apply rotation cmd.arg("-c:v").arg("h264"); cmd.arg("-crf").arg("21"); cmd.arg("-preset").arg("veryfast"); - cmd.arg("-vf").arg("scale=1080:-2,setsar=1:1"); - cmd.arg("-c:a").arg("aac"); } + cmd.arg("-vf").arg("scale='min(1080,iw)':-2,setsar=1:1"); + cmd.arg("-c:a").arg("aac"); + // Force an IDR frame every hls_time seconds so each HLS + // segment starts on a keyframe — accurate seeking without + // players having to decode from a prior segment. + cmd.arg("-force_key_frames").arg("expr:gte(t,n_forced*3)"); + } - // Common HLS settings - cmd.arg("-hls_time").arg("3"); - cmd.arg("-hls_list_size").arg("100"); - cmd.arg(&playlist_file); - cmd.stdout(Stdio::null()); - cmd.stderr(Stdio::piped()); + // -f hls is required because the playlist is written to a .tmp + // path during encoding — ffmpeg normally infers the muxer from + // the output extension and doesn't recognize ".m3u8.tmp". + cmd.arg("-f").arg("hls"); + cmd.arg("-hls_time").arg("3"); + cmd.arg("-hls_list_size").arg("0"); + cmd.arg("-hls_playlist_type").arg("vod"); + // independent_segments advertises that each segment can be + // decoded without reference to any other — the matching guarantee + // for the forced keyframes above. + cmd.arg("-hls_flags").arg("independent_segments"); + cmd.arg("-hls_segment_filename").arg(&segment_pattern); + cmd.arg(&playlist_tmp); + cmd.stdout(Stdio::null()); + cmd.stderr(Stdio::piped()); + cmd.kill_on_drop(true); - let ffmpeg_result = cmd - .output() - .inspect_err(|e| error!("Failed to run ffmpeg on child process: {}", e)) - .map_err(|e| std::io::Error::other(e.to_string())) - .await; + // Spawn + wait under a timeout so a hung ffmpeg (corrupt source, + // NFS stall, etc.) doesn't permanently hold a semaphore slot. + // Default is generous — a long 4K transcode on CPU can take hours. + let timeout_secs = std::env::var("HLS_TIMEOUT_SECONDS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(7200); - // Hang on to the permit until we're done decoding and then explicitly drop - drop(permit); - - if let Ok(ref res) = ffmpeg_result { - debug!("ffmpeg output: {:?}", res); + let ffmpeg_result = match cmd.spawn() { + Ok(child) => { + match tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + child.wait_with_output(), + ) + .await + { + Ok(res) => res + .inspect_err(|e| { + error!("Failed to wait on ffmpeg child process: {}", e) + }) + .map_err(|e| std::io::Error::other(e.to_string())), + Err(_) => Err(std::io::Error::other(format!( + "ffmpeg exceeded {}s timeout", + timeout_secs + ))), + } } + Err(e) => { + error!("Failed to spawn ffmpeg: {}", e); + Err(std::io::Error::other(e.to_string())) + } + }; + drop(permit); + + let success = matches!(&ffmpeg_result, Ok(out) if out.status.success()); + + if success { + if let Err(e) = tokio::fs::rename(&playlist_tmp, &playlist_file).await { + error!( + "ffmpeg succeeded but rename {} -> {} failed: {}", + playlist_tmp, playlist_file, e + ); + cleanup_partial_hls(&playlist_tmp, playlist_path.as_str(), video_stem).await; + span.set_status(Status::error(format!("rename failed: {}", e))); + return Err(e); + } + debug!("Playlist complete: {}", playlist_file); span.set_status(Status::Ok); - - ffmpeg_result - }); - - Ok(()) + Ok(()) + } else { + let detail = match &ffmpeg_result { + Ok(out) => format!( + "exit {}: {}", + out.status, + String::from_utf8_lossy(&out.stderr).trim() + ), + Err(e) => format!("ffmpeg failed: {}", e), + }; + error!("ffmpeg failed for {}: {}", video_file, detail); + cleanup_partial_hls(&playlist_tmp, playlist_path.as_str(), video_stem).await; + let sentinel = playlist_unsupported_sentinel(Path::new(&playlist_file)); + if let Err(se) = tokio::fs::write(&sentinel, b"").await { + warn!( + "Failed to write playlist sentinel {}: {}", + sentinel.display(), + se + ); + } else { + info!( + "Wrote playlist sentinel {} so future scans skip {}", + sentinel.display(), + video_file + ); + } + span.set_status(Status::error(detail.clone())); + Err(std::io::Error::other(detail)) + } }) } } +/// Delete the temp playlist and any segment files that ffmpeg may have written +/// before failing. Called both on ffmpeg error and on rename failure so a +/// retry on the next scan starts from a clean slate. +async fn cleanup_partial_hls(playlist_tmp: &str, playlist_dir: &str, video_stem: &str) { + let _ = tokio::fs::remove_file(playlist_tmp).await; + + let segment_prefix = format!("{}_", video_stem); + let Ok(mut entries) = tokio::fs::read_dir(playlist_dir).await else { + return; + }; + while let Ok(Some(entry)) = entries.next_entry().await { + let Some(name) = entry.file_name().to_str().map(str::to_owned) else { + continue; + }; + if name.starts_with(&segment_prefix) + && name.ends_with(".ts") + && let Err(e) = tokio::fs::remove_file(entry.path()).await + { + warn!("Failed to remove partial segment {}: {}", name, e); + } + } +} + #[derive(Message)] #[rtype(result = "()")] pub struct GeneratePreviewClipMessage { diff --git a/src/video/ffmpeg.rs b/src/video/ffmpeg.rs index 5ed9308..a31fd0c 100644 --- a/src/video/ffmpeg.rs +++ b/src/video/ffmpeg.rs @@ -22,16 +22,16 @@ async fn check_nvenc_available() -> bool { } /// Returns whether NVENC is available, caching the result after first check. -async fn is_nvenc_available() -> bool { +pub async fn is_nvenc_available() -> bool { if let Some(&available) = NVENC_AVAILABLE.get() { return available; } let available = check_nvenc_available().await; let _ = NVENC_AVAILABLE.set(available); if available { - info!("CUDA NVENC hardware acceleration detected and enabled for preview clips"); + info!("CUDA NVENC hardware acceleration detected and enabled"); } else { - info!("NVENC not available, using CPU encoding for preview clips"); + info!("NVENC not available, using CPU encoding"); } available }