diff --git a/.gitignore b/.gitignore index 2bd4d6e..5dceed2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ database/target *.db-shm *.db-wal .env +# Server-local TTS pronunciation overrides (tts_pronunciations.example.json is the template) +/tts_pronunciations.json /tmp /docs /specs diff --git a/CLAUDE.md b/CLAUDE.md index b63ed4c..4faec1c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -645,6 +645,14 @@ OPENROUTER_APP_TITLE=ImageApi # Optional attribution header # re-embedding — mixed vector spaces break similarity search. LLM_BACKEND=ollama +# Embedding model contract. Corpus and queries must be embedded by the same +# model with matching prefixes — after changing the embed model or any of +# these, run `cargo run --bin reembed_embeddings` (all tables) or search is +# garbage. Prefix values may contain a literal \n (expanded to a newline). +EMBEDDING_DIM=768 # 768 = nomic-embed-text v1.5; 1024 = Qwen3-Embedding-0.6B +EMBED_QUERY_PREFIX= # nomic: "search_query: " | Qwen3: "Instruct: \nQuery: " +EMBED_DOCUMENT_PREFIX= # nomic: "search_document: " | Qwen3: leave empty + # llama.cpp / llama-swap (used when LLM_BACKEND=llamacpp). OpenAI-compatible # proxy hosting one or more llama-server processes. Chat models receive # images directly via content-parts (all models assumed vision-capable). @@ -668,6 +676,8 @@ LLAMA_SWAP_TTS_REF_SECONDS=30 # Max voice-clone reference clip # (Chatterbox is zero-shot; ~10-20s clean ref is ideal) LLAMA_SWAP_TTS_REQUEST_TIMEOUT_SECONDS=600 # Per-request synth timeout (long chunked insights take # minutes); overrides the shared client timeout for /tts/speech +TTS_PRONUNCIATIONS_PATH=tts_pronunciations.json # JSON map of pronunciation overrides applied before synth + # (see tts_pronunciations.example.json); hot-reloaded on change # Insight Chat Continuation AGENTIC_CHAT_MAX_ITERATIONS=6 # Cap on tool-calling iterations per chat turn (default 6) diff --git a/README.md b/README.md index 39ebe30..8a6421b 100644 --- a/README.md +++ b/README.md @@ -153,17 +153,39 @@ behind the same llama-swap proxy. Only requires `LLAMA_SWAP_URL` (the TTS client is built whenever that's set — independent of `LLM_BACKEND`). Endpoints: - `POST /tts/speech` — body `{ text, voice?, format?, exaggeration?, cfg_weight?, temperature? }`; returns `{ audio_base64, format }`. Input is cleaned - server-side (markdown + emoji stripped) and the generation knobs are clamped + server-side (markdown + emoji stripped, then pronunciation overrides applied — + see below) and the generation knobs are clamped to Chatterbox's ranges. Synthesis is serialized (one at a time — the upstream has no GPU lock of its own); a concurrent request gets a fast `429`. -- `GET /tts/voices` — list the voice library. +- `POST /tts/speech/jobs` — durable variant for long syntheses: same body as + `/tts/speech`, returns `202 { job_id, status }` immediately. Jobs queue on the + GPU permit instead of fast-failing `429`. +- `GET /tts/speech/jobs/{id}` — poll a job: `{ job_id, status, format, + audio_base64?, error? }` with status `queued|running|done|error|cancelled`. + Results are kept in memory ~10 min after completion, then the job 404s. +- `DELETE /tts/speech/jobs/{id}` — cancel a queued/running job. +- `GET /tts/voices` — list the voice library. Served from an in-memory cache + (so the listing doesn't make llama-swap spin up the TTS model and evict the + resident LLM); pass `?refresh=1` to force an upstream re-query. The cache is + invalidated by voice create/delete. - `POST /tts/voices/upload` — multipart `voice_name` + `voice_file`; clone a voice from an uploaded clip (≤25 MB). - `POST /tts/voices/from-library` — body `{ voice_name, path, library? }`; clone from a library file (audio forwarded as-is; video has its audio extracted via ffmpeg). +- `DELETE /tts/voices/{name}` — remove a cloned voice from the library. + +Created voice names are tagged with the ref-clip cap in effect (e.g. +`grandma-30s`) so the library shows which reference length produced each clone. + +Words the model mispronounces (place names, initialisms) can be rewritten +before synthesis via a JSON map — copy `tts_pronunciations.example.json` to +`tts_pronunciations.json` and edit; changes apply without a restart. Full +matching rules are documented in `src/ai/pronunciation.rs`. Env: +- `TTS_PRONUNCIATIONS_PATH` - pronunciation-override JSON file + [default: `tts_pronunciations.json` in the working directory] - `LLAMA_SWAP_TTS_MODEL` - TTS model id in llama-swap's `config.yaml` [default: `chatterbox`] - `LLAMA_SWAP_TTS_VOICE` - default voice used when a `/tts/speech` request omits `voice` (optional) - `LLAMA_SWAP_TTS_REF_SECONDS` - max voice-clone reference clip length in seconds diff --git a/src/ai/gpu.rs b/src/ai/gpu.rs new file mode 100644 index 0000000..728a144 --- /dev/null +++ b/src/ai/gpu.rs @@ -0,0 +1,88 @@ +// GPU lease — in-process coordination for llama-swap model contention. +// +// llama-swap runs the heavyweight models (chat / vision / Chatterbox TTS) as +// a mutually-exclusive set on one GPU (matrix DSL `(q27 | … | tts) & e`): a +// request for a non-resident model is HELD by llama-swap until the resident +// model's in-flight requests drain, then the models swap. That hold counts +// against the *holder's* reqwest timeout — measured live: a queued TTS burned +// 77s of its budget behind a single LLM turn, and an LLM request behind a +// running synthesis waited the entire remaining synth. Uncoordinated +// cross-model traffic therefore times out instead of queueing. +// +// The lease moves that wait into this process, BEFORE the HTTP request is +// sent and before its timeout starts: +// - chat/vision requests (the LLM-side slots) share the READ lease; +// - TTS synthesis and voice-library ops (anything that spins Chatterbox up +// and evicts the LLM) take the WRITE lease; +// - embeddings take NO lease: the `embed` slot is in llama-swap's +// always-resident group (the `& e` term) and never participates in a swap, +// so leasing it would only stall searches behind a queued synthesis. +// +// tokio's RwLock is fair (FIFO, write-preferring): a queued TTS gets the GPU +// right after the current LLM request drains, and later LLM requests queue +// behind it — bounded waits in both directions, no starvation, no timeout +// budget burned while waiting. +// +// RULES: hold a lease for exactly one HTTP request (for streaming, the +// stream's lifetime) and NEVER acquire one while already holding one — once a +// writer is queued, new read acquisitions block, so nested acquisition can +// deadlock. + +use std::sync::LazyLock; +use std::time::Instant; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +static GPU_LEASE: LazyLock> = LazyLock::new(|| RwLock::new(())); + +/// Waits longer than this are logged — they mean a cross-model swap was +/// avoided and quantify what the request *would* have burned of its timeout. +const SLOW_WAIT_LOG_SECS: f64 = 2.0; + +/// Shared lease for LLM-side requests (chat / vision slots). +pub async fn llm_lease() -> RwLockReadGuard<'static, ()> { + let started = Instant::now(); + let guard = GPU_LEASE.read().await; + log_slow_wait("llm", started); + guard +} + +/// Exclusive lease for TTS-side requests (speech synthesis + voice-library +/// ops that spin up Chatterbox). +pub async fn tts_lease() -> RwLockWriteGuard<'static, ()> { + let started = Instant::now(); + let guard = GPU_LEASE.write().await; + log_slow_wait("tts", started); + guard +} + +fn log_slow_wait(kind: &str, started: Instant) { + let waited = started.elapsed().as_secs_f64(); + if waited > SLOW_WAIT_LOG_SECS { + log::info!("GPU lease ({kind}): waited {waited:.1}s for the other model class to drain"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // One sequential test, not several: the lease is a single global, so + // parallel tests interleaving reads and writes on it can hit the very + // nested-acquisition deadlock the module comment warns about. + #[tokio::test] + async fn write_lease_excludes_readers_then_reads_share() { + let w = tts_lease().await; + // A reader must not acquire while the writer is held. + let pending = tokio::spawn(async { drop(llm_lease().await) }); + tokio::task::yield_now().await; + assert!(!pending.is_finished()); + drop(w); + pending.await.expect("reader acquires after writer drops"); + + // With no writer queued, read leases are shared. + let a = llm_lease().await; + let b = llm_lease().await; + drop(a); + drop(b); + } +} diff --git a/src/ai/handlers.rs b/src/ai/handlers.rs index 5e46418..cb21b14 100644 --- a/src/ai/handlers.rs +++ b/src/ai/handlers.rs @@ -468,6 +468,13 @@ pub async fn generate_insight_handler( let path_for_task = path.clone(); let generator_for_task = generator.clone(); let result = tokio::task::spawn(async move { + // Cross-model barrier: if a TTS synthesis holds the GPU, wait it + // out BEFORE the generation wall-clock starts. The per-request + // lease keeps reqwest budgets honest, but this job-level timeout + // would otherwise burn while the first chat call queues behind a + // multi-minute synthesis. Dropped immediately — holding it across + // the generation would deadlock the chat calls' own leases. + drop(crate::ai::gpu::llm_lease().await); tokio::time::timeout( std::time::Duration::from_secs(timeout_secs), generator_for_task.generate_insight_for_photo_with_config( @@ -510,7 +517,9 @@ pub async fn generate_insight_handler( } Ok(Ok(Err(e))) => { log::error!("Insight generation failed for {}: {:?}", path, e); - if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) { + // `{:#}` = one-line context chain; the job's error_message is + // returned to the client verbatim, so no Debug/backtrace here. + if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:#}", e)) { log::error!("Failed to mark job {} as failed: {:?}", job_id, err); } } @@ -844,6 +853,9 @@ pub async fn generate_agentic_insight_handler( let path_for_task = path.clone(); let generator_for_task = generator.clone(); let result = tokio::task::spawn(async move { + // Cross-model barrier — see generate_insight_handler: wait out any + // running TTS synthesis before the generation wall-clock starts. + drop(crate::ai::gpu::llm_lease().await); tokio::time::timeout( std::time::Duration::from_secs(timeout_secs), generator_for_task.generate_agentic_insight_for_photo( @@ -884,7 +896,9 @@ pub async fn generate_agentic_insight_handler( } Ok(Ok(Err(e))) => { log::error!("Agentic insight generation failed for {}: {:?}", path, e); - if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) { + // `{:#}` = one-line context chain; the job's error_message is + // returned to the client verbatim, so no Debug/backtrace here. + if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:#}", e)) { log::error!("Failed to mark job {} as failed: {:?}", job_id, err); } } diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs index a6a50f1..3673c43 100644 --- a/src/ai/insight_generator.rs +++ b/src/ai/insight_generator.rs @@ -33,30 +33,40 @@ use crate::utils::{earliest_fs_time, normalize_path}; /// and labels the truncation via `found_header`. const LOCATION_HISTORY_DISPLAY_LIMIT: usize = 20; +/// Strip common markdown decoration (bold/italic markers, heading hashes, +/// backticks, quotes) from both ends of a model-emitted title. Models wrap +/// the line despite the prompt: `**Title: A Day in the Woods**`, +/// `## Title: ...`, `"..."`. +pub(crate) fn strip_title_markdown(s: &str) -> &str { + s.trim_matches(|c: char| matches!(c, '*' | '_' | '`' | '#' | '"') || c.is_whitespace()) +} + /// Parse a "Title: ...\n\n" response into (title, body). /// Falls back to the first sentence as the title if the model didn't /// follow the format. pub(crate) fn parse_title_body(raw: &str) -> (String, String) { let trimmed = raw.trim(); - // Try "Title: \n\n<body>" or "Title: <title>\n<body>" - if let Some(rest) = trimmed + // Try "Title: <title>\n<body>", tolerating markdown decoration around + // the title line. + let (first_line, rest) = match trimmed.find('\n') { + Some(pos) => (&trimmed[..pos], trimmed[pos..].trim()), + None => (trimmed, ""), + }; + let first_line = strip_title_markdown(first_line); + if let Some(t) = first_line .strip_prefix("Title:") - .or_else(|| trimmed.strip_prefix("title:")) + .or_else(|| first_line.strip_prefix("title:")) { - let rest = rest.trim_start(); - if let Some(split_pos) = rest.find("\n\n").or_else(|| rest.find('\n')) { - let title = rest[..split_pos].trim(); - let body = rest[split_pos..].trim(); - if !title.is_empty() && !body.is_empty() { - return (title.to_string(), body.to_string()); - } + let title = strip_title_markdown(t); + if !title.is_empty() && !rest.is_empty() { + return (title.to_string(), rest.to_string()); } } // Fallback: first sentence (up to first `. ` or `.\n`) becomes the title if let Some(pos) = trimmed.find(". ").or_else(|| trimmed.find(".\n")) { - let title = &trimmed[..pos]; + let title = strip_title_markdown(&trimmed[..pos]); let body = trimmed[pos + 1..].trim(); if title.len() <= 100 && !body.is_empty() { return (title.to_string(), body.to_string()); @@ -65,7 +75,7 @@ pub(crate) fn parse_title_body(raw: &str) -> (String, String) { // Last resort: truncate to 60 chars for title, full text as body let title: String = trimmed.chars().take(60).collect(); - let title = title.trim_end().to_string(); + let title = strip_title_markdown(title.trim_end()).to_string(); (title, trimmed.to_string()) } @@ -535,7 +545,7 @@ impl InsightGenerator { // (`LLM_BACKEND` switch). Must match the backend that populated the // daily-summary embeddings or similarity search will be garbage. let query_embedding = - crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &query).await?; + crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), &query).await?; // Search for similar daily summaries with time-based weighting // This prioritizes summaries temporally close to the query date @@ -575,6 +585,67 @@ impl InsightGenerator { Ok(formatted) } + /// Semantic search over daily summaries for the agentic `search_rag` + /// tool. Embeds the caller's query as-is (no metadata boilerplate) and + /// only applies time weighting when an anchor date is provided — + /// without one, results rank purely by similarity across all time. + async fn search_summaries_semantic( + &self, + query: &str, + date: Option<chrono::NaiveDate>, + limit: usize, + ) -> Result<Vec<String>> { + let tracer = global_tracer(); + let current_cx = opentelemetry::Context::current(); + let mut span = tracer.start_with_context("ai.rag.search_daily_summaries", ¤t_cx); + span.set_attribute(KeyValue::new("query", query.to_string())); + span.set_attribute(KeyValue::new("limit", limit as i64)); + span.set_attribute(KeyValue::new("time_weighted", date.is_some())); + if let Some(d) = date { + span.set_attribute(KeyValue::new("date", d.to_string())); + } + let search_cx = current_cx.with_span(span); + + log::info!("RAG QUERY: {} (anchor date: {:?})", query, date); + + // Must use the same backend that populated the daily-summary + // embeddings or similarity search is garbage (see embed_one docs). + let query_embedding = + crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), query).await?; + + let mut summary_dao = self + .daily_summary_dao + .lock() + .expect("Unable to lock DailySummaryDao"); + + let similar_summaries = match date { + Some(d) => summary_dao.find_similar_summaries_with_time_weight( + &search_cx, + &query_embedding, + &d.format("%Y-%m-%d").to_string(), + limit, + ), + None => summary_dao.find_similar_summaries(&search_cx, &query_embedding, limit), + } + .map_err(|e| anyhow::anyhow!("Failed to find similar summaries: {:?}", e))?; + + search_cx.span().set_attribute(KeyValue::new( + "results_count", + similar_summaries.len() as i64, + )); + search_cx.span().set_status(Status::Ok); + + Ok(similar_summaries + .into_iter() + .map(|s| { + format!( + "[{}] {} ({} messages):\n{}", + s.date, s.contact, s.message_count, s.summary + ) + }) + .collect()) + } + /// Build a metadata-based query (fallback when no topics available) fn build_metadata_query( date: chrono::NaiveDate, @@ -626,7 +697,7 @@ impl InsightGenerator { let calendar_cx = parent_cx.with_span(span); let query_embedding = if let Some(loc) = location { - match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), loc).await { + match crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), loc).await { Ok(emb) => Some(emb), Err(e) => { log::warn!("Failed to generate embedding for location '{}': {}", loc, e); @@ -798,7 +869,8 @@ impl InsightGenerator { }; let query_embedding = - match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &query_text).await { + match crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), &query_text).await + { Ok(emb) => emb, Err(e) => { log::warn!("Failed to generate search embedding: {}", e); @@ -1737,13 +1809,12 @@ Return ONLY the summary, nothing else."#, Some(q) => q.to_string(), None => return "Error: missing required parameter 'query'".to_string(), }; - let date_str = match args.get("date").and_then(|v| v.as_str()) { - Some(d) => d, - None => return "Error: missing required parameter 'date'".to_string(), - }; - let date = match NaiveDate::parse_from_str(date_str, "%Y-%m-%d") { - Ok(d) => d, - Err(e) => return format!("Error: failed to parse date '{}': {}", date_str, e), + let date = match args.get("date").and_then(|v| v.as_str()) { + Some(d) => match NaiveDate::parse_from_str(d, "%Y-%m-%d") { + Ok(d) => Some(d), + Err(e) => return format!("Error: failed to parse date '{}': {}", d, e), + }, + None => None, }; let contact = args .get("contact") @@ -1756,7 +1827,7 @@ Return ONLY the summary, nothing else."#, .clamp(1, 25) as usize; log::info!( - "tool_search_rag: query='{}', date={}, contact={:?}, limit={}", + "tool_search_rag: query='{}', date={:?}, contact={:?}, limit={}", query, date, contact, @@ -1777,15 +1848,17 @@ Return ONLY the summary, nothing else."#, limit }; + // Embed the model's query verbatim — a soft contact bias is the + // only decoration. The metadata boilerplate ("On <date>, it was a + // <weekday>") that find_relevant_messages_rag prepends drowns the + // semantic signal, so the tool path deliberately bypasses it. + let search_query = match contact.as_deref() { + Some(c) => format!("{} (conversation with {})", query, c), + None => query.clone(), + }; + let results = match self - .find_relevant_messages_rag( - date, - None, - contact.as_deref(), - None, - candidate_limit, - Some(&query), - ) + .search_summaries_semantic(&search_query, date, candidate_limit) .await { Ok(results) if !results.is_empty() => results, @@ -2062,12 +2135,15 @@ Return ONLY the summary, nothing else."#, /// Render a list of [`SmsSearchHit`] for the LLM. Prefers the SMS-API /// snippet (which already excerpts the matched span and is the only /// preview MMS-attachment-only matches have) over the full body, and - /// strips the `<mark>` tags the snippet ships with. + /// strips the `<mark>` tags the snippet ships with. Each line names + /// both parties (`sender → recipient`) — results can span multiple + /// conversations, and a sender-only label leaves sent messages + /// unattributable to a thread. fn format_search_hits(hits: &[SmsSearchHit], mode: &str, date_filtered: bool) -> String { let user_name = user_display_name(); let mut out = String::new(); out.push_str(&format!( - "Found {} messages (mode: {}{}):\n\n", + "Found {} messages (mode: {}{}, sender → recipient):\n\n", hits.len(), mode, if date_filtered { ", date-filtered" } else { "" } @@ -2076,10 +2152,10 @@ Return ONLY the summary, nothing else."#, let date = chrono::DateTime::from_timestamp(h.date, 0) .map(|dt| dt.format("%Y-%m-%d").to_string()) .unwrap_or_else(|| h.date.to_string()); - let direction: &str = if h.type_ == 2 { - &user_name + let direction = if h.type_ == 2 { + format!("{} → {}", user_name, h.contact_name) } else { - &h.contact_name + format!("{} → {}", h.contact_name, user_name) }; let score = h .similarity_score @@ -2150,11 +2226,18 @@ Return ONLY the summary, nothing else."#, { Ok(messages) if !messages.is_empty() => { let user_name = user_display_name(); + // Name both parties — without a contact filter the window + // spans every conversation, and a sender-only label leaves + // sent messages unattributable to a thread. let formatted: Vec<String> = messages .iter() .take(limit) .map(|m| { - let sender: &str = if m.is_sent { &user_name } else { &m.contact }; + let direction = if m.is_sent { + format!("{} → {}", user_name, m.contact) + } else { + format!("{} → {}", m.contact, user_name) + }; let ts = DateTime::from_timestamp(m.timestamp, 0) .map(|dt| { dt.with_timezone(&Local) @@ -2162,7 +2245,7 @@ Return ONLY the summary, nothing else."#, .to_string() }) .unwrap_or_else(|| "unknown".to_string()); - format!("[{}] {}: {}", ts, sender, m.body) + format!("[{}] {}: {}", ts, direction, m.body) }) .collect(); format!( @@ -2870,17 +2953,34 @@ Return ONLY the summary, nothing else."#, // Generate embedding for name + description (best-effort) via the // configured local backend. let embed_text = format!("{} {}", name, description); - let embedding: Option<Vec<u8>> = - match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &embed_text).await { - Ok(vec) => { - let bytes: Vec<u8> = vec.iter().flat_map(|f| f.to_le_bytes()).collect(); - Some(bytes) - } - Err(e) => { - log::warn!("Embedding generation failed for entity '{}': {}", name, e); - None - } - }; + let embedding: Option<Vec<u8>> = match crate::ai::embed_document( + &self.ollama, + self.llamacpp.as_deref(), + &embed_text, + ) + .await + { + // The entities table has no dim check at the DAO layer, and a + // wrong-dim vector silently kills dedup/recall (cosine over + // mismatched lengths is 0) — guard here, store None instead. + Ok(vec) if vec.len() == crate::ai::embedding_dim() => { + let bytes: Vec<u8> = vec.iter().flat_map(|f| f.to_le_bytes()).collect(); + Some(bytes) + } + Ok(vec) => { + log::warn!( + "Entity '{}' embedding has {} dims (expected {}) — storing without embedding", + name, + vec.len(), + crate::ai::embedding_dim() + ); + None + } + Err(e) => { + log::warn!("Embedding generation failed for entity '{}': {}", name, e); + None + } + }; let now = chrono::Utc::now().timestamp(); let insert = InsertEntity { @@ -3206,21 +3306,25 @@ Return ONLY the summary, nothing else."#, if opts.daily_summaries_present { tools.push(Tool::function( "search_rag", - "Date-anchored semantic search over the user's daily-summary corpus. \ - Returns up to `limit` summaries most semantically similar to `query`, \ - weighted toward summaries near `date`. For raw message text across all \ - time, prefer `search_messages`. \ - Examples: `{query: \"family dinner\", date: \"2018-12-24\"}` — what \ + "Semantic search over the user's daily-summary corpus. Returns up to \ + `limit` summaries most semantically similar to `query`. Pass `date` \ + to anchor in time: summaries near that date rank higher and matches \ + months away decay sharply. Omit `date` to rank purely by semantic \ + similarity across all time — do this for \"when did X happen?\" \ + questions where the date is unknown. For raw message text, prefer \ + `search_messages`. \ + Examples: `{query: \"family dinner\"}` — best matches across all \ + time. `{query: \"family dinner\", date: \"2018-12-24\"}` — what \ daily summaries near Christmas Eve mention family / dinner / gathering. \ `{query: \"work travel\", date: \"2019-06-15\", contact: \"Alice\"}` — \ - narrowed to summaries that involve Alice.", + biased toward summaries that involve Alice.", serde_json::json!({ "type": "object", - "required": ["query", "date"], + "required": ["query"], "properties": { "query": { "type": "string", "description": "Free-text query, semantically matched." }, - "date": { "type": "string", "description": "Anchor date, YYYY-MM-DD. Summaries near this date rank higher." }, - "contact": { "type": "string", "description": "Optional contact name to bias toward conversations with that person." }, + "date": { "type": "string", "description": "Optional anchor date, YYYY-MM-DD. When set, summaries near this date rank higher; omit to search all time evenly." }, + "contact": { "type": "string", "description": "Optional contact name to bias toward conversations with that person (soft semantic bias, not a hard filter)." }, "limit": { "type": "integer", "description": "Max summaries to return (default 10, max 25)." } } }), @@ -4763,12 +4867,22 @@ mod tests { let hit = make_search_hit(1, "Sarah", "see you at the lake tomorrow", None, 1); let out = InsightGenerator::format_search_hits(&[hit], "fts5", false); - assert!(out.starts_with("Found 1 messages (mode: fts5):")); + assert!(out.starts_with("Found 1 messages (mode: fts5")); assert!(out.contains("see you at the lake tomorrow")); - assert!(out.contains("Sarah —")); + // Received message: contact is the sender. + assert!(out.contains("Sarah →")); assert!(!out.contains("date-filtered")); } + #[test] + fn format_search_hits_labels_sent_direction() { + // Sent messages must name the recipient — results can span multiple + // conversations, and a sender-only label left them unattributable. + let hit = make_search_hit(5, "Sarah", "on my way", None, 2); + let out = InsightGenerator::format_search_hits(&[hit], "fts5", false); + assert!(out.contains("→ Sarah —")); + } + #[test] fn format_search_hits_prefers_snippet_over_body_and_strips_marks() { let hit = make_search_hit( @@ -4799,7 +4913,7 @@ mod tests { assert!(out.contains("birthday_cake.jpg")); assert!(!out.contains("<mark>")); - assert!(out.contains("Mom —")); + assert!(out.contains("Mom →")); } #[test] @@ -5022,6 +5136,28 @@ mod tests { assert_eq!(b, "Everyone gathered..."); } + #[test] + fn parse_title_body_strips_bold_wrapper() { + let (t, b) = parse_title_body("**Title: A Day in the Woods**\n\nWe hiked the ridge trail."); + assert_eq!(t, "A Day in the Woods"); + assert_eq!(b, "We hiked the ridge trail."); + } + + #[test] + fn parse_title_body_strips_bold_label_only() { + // Bold around just the label: "**Title:** X" + let (t, b) = parse_title_body("**Title:** Garden Party\n\nEveryone gathered..."); + assert_eq!(t, "Garden Party"); + assert_eq!(b, "Everyone gathered..."); + } + + #[test] + fn parse_title_body_strips_heading_hashes() { + let (t, b) = parse_title_body("## Title: Morning Walk\nThe sun was rising..."); + assert_eq!(t, "Morning Walk"); + assert_eq!(b, "The sun was rising..."); + } + #[test] fn parse_title_body_fallback_first_sentence() { let (t, b) = parse_title_body("A warm summer day. We gathered at the park for a picnic."); diff --git a/src/ai/llamacpp.rs b/src/ai/llamacpp.rs index 6227e2f..8a7c898 100644 --- a/src/ai/llamacpp.rs +++ b/src/ai/llamacpp.rs @@ -142,6 +142,11 @@ impl LlamaCppClient { /// Chatterbox generation knobs are forwarded when set (caller is expected /// to have range-clamped them): `exaggeration` (0.25–2.0, emotion), /// `cfg_weight` (0.0–1.0, pace), `temperature` (0.05–5.0, randomness). + /// + /// Callers must hold the GPU write lease (`ai::gpu::tts_lease`) across + /// this call. It is taken at the call sites in `ai::tts` rather than here + /// so the speech-job path can flip its job to `running` between acquiring + /// the GPU and sending the request. pub async fn text_to_speech( &self, input: &str, @@ -204,6 +209,9 @@ impl LlamaCppClient { /// List voices in the Chatterbox voice library (raw JSON passthrough). pub async fn list_voices(&self) -> Result<Value> { let url = format!("{}/upstream/{}/voices", self.swap_root(), self.tts_model); + // The /upstream passthrough spins Chatterbox up (evicting the LLM), + // so it takes the exclusive GPU lease like synthesis does. + let _gpu = crate::ai::gpu::tts_lease().await; let resp = self .client .get(&url) @@ -237,6 +245,9 @@ impl LlamaCppClient { .text("voice_name", voice_name.to_string()) .part("voice_file", part); + // The /upstream passthrough spins Chatterbox up (evicting the LLM), + // so it takes the exclusive GPU lease like synthesis does. + let _gpu = crate::ai::gpu::tts_lease().await; let resp = self .client .post(&url) @@ -253,6 +264,37 @@ impl LlamaCppClient { resp.json().await.context("parsing create_voice response") } + /// Delete a cloned voice from the Chatterbox voice library + /// (`DELETE /voices/{name}` on the upstream, via llama-swap passthrough). + pub async fn delete_voice(&self, voice_name: &str) -> Result<Value> { + let url = format!( + "{}/upstream/{}/voices/{}", + self.swap_root(), + self.tts_model, + voice_name + ); + // The /upstream passthrough spins Chatterbox up (evicting the LLM), + // so it takes the exclusive GPU lease like synthesis does. + let _gpu = crate::ai::gpu::tts_lease().await; + let resp = self + .client + .delete(&url) + .send() + .await + .with_context(|| format!("DELETE {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("llama-swap delete_voice failed: {} — {}", status, text); + } + // Some upstreams reply with an empty body on delete. + Ok(resp + .json() + .await + .unwrap_or_else(|_| json!({ "status": "deleted" }))) + } + /// Translate canonical messages to the OpenAI-compatible wire shape. /// Behaviorally identical to `OpenRouterClient::messages_to_openai` — /// stringify tool-call arguments, rewrite images into content-parts, attach @@ -453,6 +495,9 @@ impl LlamaCppClient { body.insert(k.into(), v); } + // Wait for any TTS synthesis to release the GPU before the request + // timeout starts (see ai::gpu). + let _gpu = crate::ai::gpu::llm_lease().await; let resp = self .client .post(&url) @@ -571,6 +616,10 @@ impl LlmClient for LlamaCppClient { body.insert(k.into(), v); } + // Wait for any TTS synthesis to release the GPU before the request + // timeout starts (see ai::gpu). The guard is moved into the stream + // below so the lease spans the whole generation, not just the send. + let gpu = crate::ai::gpu::llm_lease().await; let resp = self .client .post(&url) @@ -587,6 +636,7 @@ impl LlmClient for LlamaCppClient { let byte_stream = resp.bytes_stream(); let stream = async_stream::stream! { + let _gpu = gpu; let mut byte_stream = byte_stream; let mut buf: Vec<u8> = Vec::new(); let mut accumulated_content = String::new(); @@ -702,6 +752,9 @@ impl LlmClient for LlamaCppClient { } async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> { + // Deliberately NO GPU lease: the embed slot sits in llama-swap's + // always-resident group and never participates in a model swap, so + // leasing here would only stall searches behind a queued synthesis. let url = format!("{}/embeddings", self.base_url); let body = json!({ "model": self.embedding_model, diff --git a/src/ai/local_llm.rs b/src/ai/local_llm.rs new file mode 100644 index 0000000..bf3510c --- /dev/null +++ b/src/ai/local_llm.rs @@ -0,0 +1,88 @@ +//! Bundle of the local LLM pair (Ollama + optional llama-swap) with the +//! `LLM_BACKEND` dispatch baked in. +//! +//! Exists because passing the pair around as loose values invited the same +//! bug three times: import/backfill tooling embedded corpora via +//! `OllamaClient` directly while the query side dispatched through +//! `embed_one`, so flipping `LLM_BACKEND=llamacpp` silently split queries +//! and corpus into different vector spaces. Anything that writes or reads +//! embeddings should go through this type (or `embed_one`/`embed_many`), +//! never a concrete client. +//! +//! Deliberately knows nothing about chat policy — hybrid/OpenRouter routing +//! is request-scoped and stays in `ResolvedBackend`. This is only the +//! local stack: embeddings and offline single-shot generation. + +// Constructed by binaries, not the server — dead code from main.rs's view. +#![allow(dead_code)] + +use std::sync::Arc; + +use anyhow::Result; + +use super::llamacpp::LlamaCppClient; +use super::llm_client::LlmClient; +use super::ollama::{EMBEDDING_MODEL, OllamaClient}; + +#[derive(Clone)] +pub struct LocalLlm { + ollama: OllamaClient, + llamacpp: Option<Arc<LlamaCppClient>>, +} + +impl LocalLlm { + pub fn new(ollama: OllamaClient, llamacpp: Option<Arc<LlamaCppClient>>) -> Self { + Self { ollama, llamacpp } + } + + /// Construct from the canonical env wiring shared with `AppState`. + pub fn from_env() -> Self { + Self::new( + crate::state::build_ollama_from_env(), + crate::state::build_llamacpp_from_env(), + ) + } + + /// Embed a search query (applies `EMBED_QUERY_PREFIX`). Callers must + /// pick query vs document — retrieval models treat the two sides + /// differently and an unmarked embed invites prefix-mismatch bugs. + pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> { + super::embed_query(&self.ollama, self.llamacpp.as_deref(), text).await + } + + /// Embed corpus text (applies `EMBED_DOCUMENT_PREFIX`). + pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> { + super::embed_document(&self.ollama, self.llamacpp.as_deref(), text).await + } + + /// Single-shot local text generation via the `LLM_BACKEND`-selected + /// client (offline tooling; chat turns belong to `ResolvedBackend`). + pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> { + if super::local_backend_is_llamacpp() { + if let Some(lc) = self.llamacpp.as_deref() { + return <LlamaCppClient as LlmClient>::generate(lc, prompt, system, None).await; + } + anyhow::bail!( + "LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \ + set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama" + ); + } + self.ollama.generate(prompt, system).await + } + + /// Label identifying which backend + model produces embeddings right + /// now. Store it alongside vectors (`model_version` columns) so a + /// backend flip is detectable in the data, not just in env history. + pub fn embedding_model_version(&self) -> String { + if super::local_backend_is_llamacpp() { + let slot = self + .llamacpp + .as_deref() + .map(|c| c.embedding_model.as_str()) + .unwrap_or("embed"); + format!("llama-swap:{}", slot) + } else { + EMBEDDING_MODEL.to_string() + } + } +} diff --git a/src/ai/mod.rs b/src/ai/mod.rs index 40a3f21..c5302fb 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -3,13 +3,16 @@ pub mod backend; pub mod clip_client; pub mod daily_summary_job; pub mod face_client; +pub mod gpu; pub mod handlers; pub mod insight_chat; pub mod insight_generator; pub mod llamacpp; pub mod llm_client; +pub mod local_llm; pub mod ollama; pub mod openrouter; +pub mod pronunciation; pub mod sms_client; pub mod tts; pub mod turn_registry; @@ -34,11 +37,15 @@ pub use llamacpp::LlamaCppClient; pub use llm_client::{ ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction, ToolFunction, }; +// LocalLlm is constructed by binaries (reembed_embeddings, importers), not the server +#[allow(unused_imports)] +pub use local_llm::LocalLlm; pub use ollama::{EMBEDDING_MODEL, OllamaClient}; pub use sms_client::{SmsApiClient, SmsMessage}; pub use tts::{ - create_voice_from_library_handler, create_voice_upload_handler, list_voices_handler, - tts_speech_handler, + cancel_speech_job_handler, create_speech_job_handler, create_voice_from_library_handler, + create_voice_upload_handler, delete_voice_handler, list_voices_handler, + speech_job_status_handler, tts_speech_handler, }; /// Display name used for the user in message transcripts and first-person @@ -69,35 +76,100 @@ pub fn local_backend_is_llamacpp() -> bool { ) } -/// Embed one string via the configured local backend. Routes through -/// llama-swap when `LLM_BACKEND=llamacpp` (and a client is configured), -/// else Ollama. Returns the single embedding vector. See -/// [`local_backend_is_llamacpp`] for the rationale on consistency. -pub async fn embed_one( +/// Expected embedding dimensionality, env-overridable via `EMBEDDING_DIM` +/// (default 768, nomic-embed-text). Every store/query dim check reads this — +/// swapping to a different-dim model (e.g. Qwen3-Embedding-0.6B at 1024) is +/// then a config flip plus a `reembed_embeddings` run, not a code change. +/// Cached for the process lifetime; a flip requires a restart anyway since +/// the corpus must be re-embedded with it. +pub fn embedding_dim() -> usize { + static DIM: std::sync::OnceLock<usize> = std::sync::OnceLock::new(); + *DIM.get_or_init(|| { + std::env::var("EMBEDDING_DIM") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(768) + }) +} + +/// Read an embedding prefix from the environment. `.env` values can't hold +/// real newlines, so a literal `\n` in the value is expanded — Qwen3-style +/// query instructions need one ("Instruct: ...\nQuery: "). +fn embed_prefix(key: &str) -> String { + std::env::var(key) + .map(|v| v.replace("\\n", "\n")) + .unwrap_or_default() +} + +/// Embed a search query. Applies `EMBED_QUERY_PREFIX` (default empty) — +/// retrieval models distinguish query-side from document-side text: +/// nomic v1.5 wants `search_query: `, Qwen3-Embedding wants +/// `Instruct: <task>\nQuery: `. Must pair with the document prefix the +/// corpus was embedded with or similarity degrades. +pub async fn embed_query( ollama: &OllamaClient, llamacpp: Option<&LlamaCppClient>, text: &str, ) -> anyhow::Result<Vec<f32>> { + let prefixed = format!("{}{}", embed_prefix("EMBED_QUERY_PREFIX"), text); + embed_one(ollama, llamacpp, &prefixed).await +} + +/// Embed corpus text (the stored side of retrieval). Applies +/// `EMBED_DOCUMENT_PREFIX` (default empty; nomic v1.5 wants +/// `search_document: `, Qwen3-Embedding wants none). +pub async fn embed_document( + ollama: &OllamaClient, + llamacpp: Option<&LlamaCppClient>, + text: &str, +) -> anyhow::Result<Vec<f32>> { + let prefixed = format!("{}{}", embed_prefix("EMBED_DOCUMENT_PREFIX"), text); + embed_one(ollama, llamacpp, &prefixed).await +} + +/// Embed a batch of strings via the configured local backend. Routes +/// through llama-swap when `LLM_BACKEND=llamacpp` (and a client is +/// configured), else Ollama. See [`local_backend_is_llamacpp`] for the +/// rationale on consistency. +pub async fn embed_many( + ollama: &OllamaClient, + llamacpp: Option<&LlamaCppClient>, + texts: &[&str], +) -> anyhow::Result<Vec<Vec<f32>>> { if local_backend_is_llamacpp() { if let Some(lc) = llamacpp { - let mut vecs = <LlamaCppClient as LlmClient>::generate_embeddings(lc, &[text]).await?; - return vecs - .pop() - .ok_or_else(|| anyhow::anyhow!("llama-swap returned no embeddings")); + return <LlamaCppClient as LlmClient>::generate_embeddings(lc, texts).await; } anyhow::bail!( "LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \ set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama" ); } - ollama.generate_embedding(text).await + ollama.generate_embeddings(texts).await +} + +/// Embed one string via the configured local backend. Single-text +/// convenience over [`embed_many`]. +pub async fn embed_one( + ollama: &OllamaClient, + llamacpp: Option<&LlamaCppClient>, + text: &str, +) -> anyhow::Result<Vec<f32>> { + let mut vecs = embed_many(ollama, llamacpp, &[text]).await?; + vecs.pop() + .ok_or_else(|| anyhow::anyhow!("embedding backend returned no embeddings")) } #[cfg(test)] mod env_dispatch_tests { use super::*; + /// Env vars are process-global, and the test harness runs in parallel — + /// without this lock the `LLM_BACKEND` tests race each other and flake. + static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + fn with_env<F: FnOnce()>(key: &str, val: Option<&str>, f: F) { + let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner()); let prev = std::env::var(key).ok(); match val { Some(v) => unsafe { std::env::set_var(key, v) }, diff --git a/src/ai/ollama.rs b/src/ai/ollama.rs index 75c8a02..5316208 100644 --- a/src/ai/ollama.rs +++ b/src/ai/ollama.rs @@ -548,7 +548,16 @@ Capture the key moment or theme. Return ONLY the title, nothing else."#, let title = self .generate_with_images(&prompt, Some(system), None) .await?; - Ok(title.trim().trim_matches('"').to_string()) + // Models decorate despite "Return ONLY the title": quotes, bold + // markers, sometimes a "Title:" label. + use crate::ai::insight_generator::strip_title_markdown; + let cleaned = strip_title_markdown(title.trim()); + let cleaned = cleaned + .strip_prefix("Title:") + .or_else(|| cleaned.strip_prefix("title:")) + .map(strip_title_markdown) + .unwrap_or(cleaned); + Ok(cleaned.to_string()) } /// Generate a summary for a single photo based on its context @@ -1046,13 +1055,14 @@ Analyze the image and use specific details from both the visual content and the } }; - // Validate embedding dimensions (should be 768 for nomic-embed-text:v1.5) + // Validate embedding dimensions (EMBEDDING_DIM; 768 for nomic-embed-text:v1.5) for (i, embedding) in embeddings.iter().enumerate() { - if embedding.len() != 768 { + if embedding.len() != crate::ai::embedding_dim() { log::warn!( - "Unexpected embedding dimensions for item {}: {} (expected 768)", + "Unexpected embedding dimensions for item {}: {} (expected {})", i, - embedding.len() + embedding.len(), + crate::ai::embedding_dim() ); } } diff --git a/src/ai/pronunciation.rs b/src/ai/pronunciation.rs new file mode 100644 index 0000000..b9d7f6e --- /dev/null +++ b/src/ai/pronunciation.rs @@ -0,0 +1,282 @@ +// User-configurable pronunciation overrides for TTS. Chatterbox mispronounces +// place names ("Worcester"), initialisms ("WSL"), and clipped abbreviations +// ("blvd"), so we rewrite them to phonetic spellings before synthesis. +// +// The map lives in a JSON file on the server — a flat object of +// `"written form": "spoken form"` pairs, e.g.: +// +// { +// "Worcester": "Wuster", +// "WSL": "W S L", +// "blvd": "boulevard", +// "Dr.": "Doctor" +// } +// +// Path comes from `TTS_PRONUNCIATIONS_PATH` (default `tts_pronunciations.json` +// in the working directory). A missing file simply disables the feature. The +// file is re-read whenever its mtime changes, so edits apply to the next +// synthesis without a restart; a malformed edit keeps the last good map and +// logs the parse error instead of silently dropping all overrides. +// +// Matching rules: +// - Whole words only — `cat` never rewrites `category`. (Boundaries are only +// asserted next to word characters, so keys like `Dr.` still work.) +// - Smartcase: an all-lowercase key matches case-insensitively; a key with +// any uppercase matches exactly. That lets `worcester` catch every casing +// while `US` (the country) leaves the pronoun `us` alone. +// - Longer keys win over shorter ones (`New York Times` before `New York`). + +use regex::Regex; +use std::collections::HashMap; +use std::path::Path; +use std::sync::{Arc, LazyLock, Mutex as StdMutex}; +use std::time::SystemTime; + +/// A compiled pronunciation map: one alternation regex over every key plus +/// the lookup tables the replacement closure resolves matches against. +#[derive(Default)] +struct CompiledMap { + /// `None` when the map is empty — apply() is then a no-op. + regex: Option<Regex>, + /// Case-sensitive entries, keyed verbatim. + exact: HashMap<String, String>, + /// Case-insensitive entries, keyed lowercased. + folded: HashMap<String, String>, +} + +impl CompiledMap { + fn from_entries(entries: &HashMap<String, String>) -> Self { + let mut keys: Vec<&str> = entries + .keys() + .map(|k| k.as_str()) + .filter(|k| !k.trim().is_empty()) + .collect(); + if keys.is_empty() { + return Self::default(); + } + // Longest key first so overlapping entries prefer the more specific + // one (regex alternation is first-match-wins, not longest-match). + keys.sort_by(|a, b| b.len().cmp(&a.len()).then(a.cmp(b))); + + let mut exact = HashMap::new(); + let mut folded = HashMap::new(); + let alternatives: Vec<String> = keys + .iter() + .map(|key| { + let escaped = regex::escape(key); + // Only assert a word boundary where the key edge is a word + // character — `\b` adjacent to punctuation (e.g. the dot in + // `Dr.`) would otherwise never match. + let lead = if key + .chars() + .next() + .is_some_and(|c| c.is_alphanumeric() || c == '_') + { + r"\b" + } else { + "" + }; + let trail = if key + .chars() + .last() + .is_some_and(|c| c.is_alphanumeric() || c == '_') + { + r"\b" + } else { + "" + }; + let case_sensitive = key.chars().any(|c| c.is_uppercase()); + if case_sensitive { + exact.insert(key.to_string(), entries[*key].clone()); + format!("{lead}{escaped}{trail}") + } else { + folded.insert(key.to_lowercase(), entries[*key].clone()); + format!("{lead}(?i:{escaped}){trail}") + } + }) + .collect(); + + // Escaped fixed strings can't produce an invalid pattern; if one ever + // does, treat the whole map as empty rather than panicking a handler. + let pattern = alternatives.join("|"); + let regex = match Regex::new(&pattern) { + Ok(r) => Some(r), + Err(e) => { + log::error!("pronunciation map failed to compile: {e}"); + None + } + }; + Self { + regex, + exact, + folded, + } + } + + fn apply(&self, text: &str) -> String { + let Some(re) = &self.regex else { + return text.to_string(); + }; + re.replace_all(text, |caps: ®ex::Captures| { + let m = &caps[0]; + self.exact + .get(m) + .or_else(|| self.folded.get(&m.to_lowercase())) + .cloned() + // Unreachable in practice — every alternative came from one + // of the two maps — but never drop the user's text. + .unwrap_or_else(|| m.to_string()) + }) + .into_owned() + } +} + +struct CacheEntry { + mtime: Option<SystemTime>, + compiled: Arc<CompiledMap>, +} + +static CACHE: LazyLock<StdMutex<Option<CacheEntry>>> = LazyLock::new(|| StdMutex::new(None)); + +fn config_path() -> String { + std::env::var("TTS_PRONUNCIATIONS_PATH") + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "tts_pronunciations.json".to_string()) +} + +/// Load the compiled map, re-reading the file only when its mtime changed +/// since the last call (or it appeared/disappeared). Synthesis is serialized +/// on a single GPU permit, so a stat per call is noise. +fn current_map() -> Arc<CompiledMap> { + let path_s = config_path(); + let path = Path::new(&path_s); + let mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok(); + + let mut cache = CACHE.lock().unwrap(); + if let Some(entry) = cache.as_ref() + && entry.mtime == mtime + { + return entry.compiled.clone(); + } + + let compiled = match mtime { + None => Arc::new(CompiledMap::default()), // no file → no overrides + Some(_) => match std::fs::read_to_string(path) + .map_err(anyhow::Error::from) + .and_then(|s| Ok(serde_json::from_str::<HashMap<String, String>>(&s)?)) + { + Ok(entries) => { + log::info!( + "loaded {} pronunciation override(s) from {path_s}", + entries.len() + ); + Arc::new(CompiledMap::from_entries(&entries)) + } + Err(e) => { + log::error!("failed to load pronunciation map {path_s}: {e}"); + // Keep serving the previous map rather than regressing to + // none mid-edit; still record the new mtime so the error + // logs once per bad save, not once per synthesis. + cache + .as_ref() + .map(|c| c.compiled.clone()) + .unwrap_or_default() + } + }, + }; + *cache = Some(CacheEntry { + mtime, + compiled: compiled.clone(), + }); + compiled +} + +/// Rewrite configured words/abbreviations to their phonetic spellings. +/// Call on cleaned (post-markdown-strip) text, right before synthesis. +pub fn apply_pronunciations(text: &str) -> String { + current_map().apply(text) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn compile(pairs: &[(&str, &str)]) -> CompiledMap { + let entries = pairs + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + CompiledMap::from_entries(&entries) + } + + #[test] + fn empty_map_is_a_noop() { + let m = compile(&[]); + assert_eq!(m.apply("nothing changes"), "nothing changes"); + } + + #[test] + fn replaces_whole_words_only() { + let m = compile(&[("cat", "kitty")]); + assert_eq!(m.apply("the cat sat"), "the kitty sat"); + // No substring rewrites. + assert_eq!(m.apply("the category"), "the category"); + assert_eq!(m.apply("concatenate"), "concatenate"); + } + + #[test] + fn lowercase_keys_match_any_casing() { + let m = compile(&[("worcester", "Wuster")]); + assert_eq!(m.apply("Worcester is nice"), "Wuster is nice"); + assert_eq!(m.apply("in WORCESTER today"), "in Wuster today"); + assert_eq!(m.apply("worcester sauce"), "Wuster sauce"); + } + + #[test] + fn uppercase_keys_match_case_sensitively() { + let m = compile(&[("US", "U S")]); + assert_eq!(m.apply("the US economy"), "the U S economy"); + // The pronoun survives. + assert_eq!(m.apply("join us today"), "join us today"); + } + + #[test] + fn keys_with_punctuation_work() { + // `\b` is only asserted next to word characters, so the trailing dot + // doesn't break matching. + let m = compile(&[("Dr.", "Doctor"), ("blvd", "boulevard")]); + assert_eq!( + m.apply("Dr. Smith on Sunset blvd"), + "Doctor Smith on Sunset boulevard" + ); + } + + #[test] + fn longer_keys_win_over_shorter() { + let m = compile(&[("new york", "Noo York"), ("new york times", "the Times")]); + assert_eq!(m.apply("read the new york times"), "read the the Times"); + assert_eq!(m.apply("visit new york soon"), "visit Noo York soon"); + } + + #[test] + fn multiple_occurrences_all_rewrite() { + let m = compile(&[("wsl", "W S L")]); + assert_eq!(m.apply("WSL and wsl and Wsl"), "W S L and W S L and W S L"); + } + + #[test] + fn replacement_text_is_verbatim() { + // Replacements aren't re-scanned — a value containing another key + // doesn't cascade. + let m = compile(&[("a1", "b2"), ("b2", "c3")]); + assert_eq!(m.apply("a1"), "b2"); + } + + #[test] + fn blank_keys_are_ignored() { + let m = compile(&[("", "x"), (" ", "y"), ("ok", "fine")]); + assert_eq!(m.apply("ok then"), "fine then"); + } +} diff --git a/src/ai/tts.rs b/src/ai/tts.rs index b94be36..08d9dcd 100644 --- a/src/ai/tts.rs +++ b/src/ai/tts.rs @@ -6,7 +6,7 @@ // (audio read directly; video has its audio track extracted via ffmpeg). use actix_multipart::Multipart; -use actix_web::{HttpRequest, HttpResponse, Responder, get, post, web}; +use actix_web::{HttpRequest, HttpResponse, Responder, delete, get, post, web}; use anyhow::Context; use base64::Engine; use bytes::{BufMut, BytesMut}; @@ -15,10 +15,13 @@ use opentelemetry::KeyValue; use opentelemetry::trace::{Span, Status, Tracer}; use regex::Regex; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{Value, json}; +use std::collections::HashMap; use std::path::Path; -use std::sync::LazyLock; +use std::sync::{LazyLock, Mutex as StdMutex}; +use std::time::{Duration, Instant}; use tokio::sync::Semaphore; +use uuid::Uuid; use crate::data::Claims; use crate::file_types::{is_audio_file, is_video_file}; @@ -40,6 +43,105 @@ const MAX_VOICE_UPLOAD_BYTES: usize = 25 * 1024 * 1024; // 25 MB /// finishes — that's a wrapper limitation; the chunked-queue plan fixes it.) static TTS_PERMIT: LazyLock<Semaphore> = LazyLock::new(|| Semaphore::new(1)); +// --- Voice-list cache -------------------------------------------------------- + +/// Cached raw voice-library JSON. llama-swap's `/upstream/<model>/voices` +/// passthrough spins the TTS model up just to answer a listing — which can +/// evict the resident LLM — so we serve a cached copy and only hit upstream on +/// a cold cache, an explicit `?refresh=1`, or after a voice create/delete +/// invalidates it (the TTS model is already loaded right then anyway). +static VOICES_CACHE: LazyLock<StdMutex<Option<Value>>> = LazyLock::new(|| StdMutex::new(None)); + +fn cached_voices() -> Option<Value> { + VOICES_CACHE.lock().unwrap().clone() +} + +fn store_voices_cache(v: &Value) { + *VOICES_CACHE.lock().unwrap() = Some(v.clone()); +} + +fn invalidate_voices_cache() { + *VOICES_CACHE.lock().unwrap() = None; +} + +// --- Async speech jobs ------------------------------------------------------- +// +// Synthesizing a long insight can take minutes — too long to hang one HTTP +// request from a phone that may background the app or drop the connection. +// Durable variant: POST /tts/speech/jobs returns a job id immediately, the +// synth runs in a spawned task (queuing on TTS_PERMIT instead of fast-failing +// 429), and the client polls GET /tts/speech/jobs/{id} until it collects the +// audio. State is in-memory only (deliberately lighter than the chat +// TurnRegistry): a restart loses jobs, the client surfaces that and retries. + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum TtsJobStatus { + Queued, + Running, + Done, + Error, + Cancelled, +} + +impl TtsJobStatus { + fn is_terminal(self) -> bool { + matches!(self, Self::Done | Self::Error | Self::Cancelled) + } +} + +struct TtsJob { + status: TtsJobStatus, + format: String, + audio_base64: Option<String>, + error: Option<String>, + created_at: Instant, + finished_at: Option<Instant>, + abort: Option<tokio::task::AbortHandle>, +} + +/// Finished jobs linger so a client that lost connectivity can still collect +/// the result on a later poll; anything older than MAX_AGE is dropped outright +/// (aborted first if somehow still running). Swept lazily on each dispatch. +const TTS_JOB_RESULT_TTL: Duration = Duration::from_secs(10 * 60); +const TTS_JOB_MAX_AGE: Duration = Duration::from_secs(30 * 60); + +static TTS_JOBS: LazyLock<StdMutex<HashMap<Uuid, TtsJob>>> = + LazyLock::new(|| StdMutex::new(HashMap::new())); + +fn sweep_stale_jobs(jobs: &mut HashMap<Uuid, TtsJob>, now: Instant) { + jobs.retain(|_, job| { + let result_expired = job + .finished_at + .is_some_and(|t| now.duration_since(t) >= TTS_JOB_RESULT_TTL); + let too_old = now.duration_since(job.created_at) >= TTS_JOB_MAX_AGE; + if too_old && let Some(h) = job.abort.take() { + h.abort(); + } + !(result_expired || too_old) + }); +} + +/// Run `f` against a job, if it still exists. +fn with_job<R>(id: Uuid, f: impl FnOnce(&mut TtsJob) -> R) -> Option<R> { + TTS_JOBS.lock().unwrap().get_mut(&id).map(f) +} + +/// Move a job to a terminal state (first terminal write wins — a cancel that +/// raced a completion keeps the cancel). +fn finish_job(id: Uuid, status: TtsJobStatus, audio_base64: Option<String>, error: Option<String>) { + with_job(id, |job| { + if job.status.is_terminal() { + return; + } + job.status = status; + job.audio_base64 = audio_base64; + job.error = error; + job.finished_at = Some(Instant::now()); + job.abort = None; + }); +} + /// Sanitize a user-supplied voice name. The name is forwarded to Chatterbox /// where it becomes a filename in the voice-library directory, so we restrict /// it to a safe charset (alphanumerics, dash, underscore) — no path @@ -64,6 +166,66 @@ fn sanitize_voice_name(raw: &str) -> Option<String> { Some(cleaned.chars().take(64).collect()) } +/// Reference-clip cap in seconds for voice cloning. Chatterbox is zero-shot — +/// a clean ~10–20s sample is the sweet spot and more rarely helps. Tune via +/// `LLAMA_SWAP_TTS_REF_SECONDS` (default 30). +fn tts_ref_seconds() -> u32 { + std::env::var("LLAMA_SWAP_TTS_REF_SECONDS") + .ok() + .and_then(|s| s.trim().parse::<u32>().ok()) + .filter(|n| *n > 0) + .unwrap_or(30) +} + +/// Tag a (sanitized) voice name with the reference window used to create it: +/// `grandma` → `grandma-30s` (from the start), or `grandma-at1m32s-30s` (30s +/// window starting at 1:32). The tag makes the window visible in the voice +/// list so clones of the same source from different sections can be compared. +/// Skips the append when the name already ends in the same tag; keeps the +/// 64-char bound by truncating the base name, never the tag. +fn append_ref_window(name: &str, start: f64, secs: u32) -> String { + let start_whole = start.round().max(0.0) as u64; + let suffix = if start_whole > 0 { + // ':' isn't in the safe voice-name charset, so 1:32 becomes 1m32s. + let at = if start_whole >= 60 { + format!("at{}m{:02}s", start_whole / 60, start_whole % 60) + } else { + format!("at{start_whole}s") + }; + format!("-{at}-{secs}s") + } else { + format!("-{secs}s") + }; + if name.ends_with(&suffix) { + return name.to_string(); + } + let max_base = 64usize.saturating_sub(suffix.len()); + let base: String = name.chars().take(max_base).collect(); + let base = base.trim_end_matches('-'); + format!("{base}{suffix}") +} + +/// Resolve a caller-supplied reference window into concrete `(start, duration)` +/// seconds for ffmpeg. Start defaults to 0; duration defaults to the +/// `tts_ref_seconds` cap and is clamped to it (the cap is the most audio the +/// TTS backend benefits from, so longer requests are quietly bounded rather +/// than rejected). Non-finite or negative values are the caller's bug → Err. +fn resolve_ref_window( + start_seconds: Option<f64>, + duration_seconds: Option<f64>, +) -> Result<(f64, f64), String> { + let cap = f64::from(tts_ref_seconds()); + let start = start_seconds.unwrap_or(0.0); + if !start.is_finite() || start < 0.0 { + return Err("start_seconds must be a non-negative number".to_string()); + } + let duration = duration_seconds.unwrap_or(cap); + if !duration.is_finite() || duration <= 0.0 { + return Err("duration_seconds must be a positive number".to_string()); + } + Ok((start, duration.min(cap))) +} + /// Optional default voice for synthesis when the request doesn't name one. /// Set `LLAMA_SWAP_TTS_VOICE=m` to read insights in a cloned voice by default. fn default_voice() -> Option<String> { @@ -125,33 +287,42 @@ fn clean_for_tts(input: &str) -> String { s.trim().to_string() } +/// Full text-preparation pipeline for synthesis: markdown/emoji cleanup, then +/// the user's pronunciation overrides (see [`crate::ai::pronunciation`]) on +/// the resulting plain text — after cleanup so word boundaries aren't +/// obscured by `**WSL**`-style markup. +fn prepare_for_tts(input: &str) -> String { + crate::ai::pronunciation::apply_pronunciations(&clean_for_tts(input)) +} + /// Decode an audio/video file to mono 24 kHz WAV via ffmpeg, returning the WAV /// bytes. Chatterbox validates the reference clip by file *extension* and /// rejects several formats (e.g. `.aac`, `.opus`), so we always normalize to -/// WAV regardless of the source container. Capped at 30s — references only need -/// a few seconds of clean speech. -async fn run_ffmpeg_to_wav(input_path: &str) -> anyhow::Result<Vec<u8>> { +/// WAV regardless of the source container. Extracts `duration` seconds starting +/// at `start` (see resolve_ref_window) — references only need a few seconds of +/// clean speech, which may sit anywhere in a long recording. +async fn run_ffmpeg_to_wav(input_path: &str, start: f64, duration: f64) -> anyhow::Result<Vec<u8>> { let out = tempfile::Builder::new() .suffix(".wav") .tempfile() .context("creating temp wav")?; let out_s = out.path().to_string_lossy().to_string(); - // Cap the reference clip length. Chatterbox is zero-shot — a clean ~10–20s - // sample is the sweet spot and more rarely helps — so we use the first N - // seconds. Tune via LLAMA_SWAP_TTS_REF_SECONDS (default 30). - let secs = std::env::var("LLAMA_SWAP_TTS_REF_SECONDS") - .ok() - .and_then(|s| s.trim().parse::<u32>().ok()) - .filter(|n| *n > 0) - .unwrap_or(30) - .to_string(); + let start_s = format!("{start}"); + let secs = format!("{duration}"); + + // -ss before -i is input seeking: fast, and frame accuracy doesn't matter + // for picking a speech window. + let mut args: Vec<&str> = vec!["-y"]; + if start > 0.0 { + args.extend(["-ss", &start_s]); + } + args.extend([ + "-i", input_path, "-vn", "-ac", "1", "-ar", "24000", "-t", &secs, "-f", "wav", &out_s, + ]); let output = tokio::process::Command::new("ffmpeg") - .args([ - "-y", "-i", input_path, "-vn", "-ac", "1", "-ar", "24000", "-t", &secs, "-f", "wav", - &out_s, - ]) + .args(&args) .output() .await .context("spawning ffmpeg")?; @@ -164,7 +335,12 @@ async fn run_ffmpeg_to_wav(input_path: &str) -> anyhow::Result<Vec<u8>> { /// Normalize in-memory upload bytes to WAV: write to a temp file (keeping the /// source extension as an ffmpeg probe hint) then transcode. -async fn transcode_bytes_to_wav(input: &[u8], src_ext: Option<&str>) -> anyhow::Result<Vec<u8>> { +async fn transcode_bytes_to_wav( + input: &[u8], + src_ext: Option<&str>, + start: f64, + duration: f64, +) -> anyhow::Result<Vec<u8>> { let suffix = src_ext .filter(|e| !e.is_empty()) .map(|e| format!(".{e}")) @@ -174,7 +350,7 @@ async fn transcode_bytes_to_wav(input: &[u8], src_ext: Option<&str>) -> anyhow:: .tempfile() .context("creating temp input")?; std::fs::write(in_tmp.path(), input).context("writing temp input")?; - run_ffmpeg_to_wav(&in_tmp.path().to_string_lossy()).await + run_ffmpeg_to_wav(&in_tmp.path().to_string_lossy(), start, duration).await } #[derive(Debug, Deserialize)] @@ -214,7 +390,7 @@ pub async fn tts_speech_handler( let parent_context = extract_context_from_request(&http_request); let mut span = global_tracer().start_with_context("http.tts.speech", &parent_context); - let text = clean_for_tts(&req.text); + let text = prepare_for_tts(&req.text); if text.is_empty() { span.set_status(Status::error("text is required")); return HttpResponse::BadRequest().json(json!({ "error": "text is required" })); @@ -255,6 +431,10 @@ pub async fn tts_speech_handler( })); }; + // Wait for the LLM side to release the GPU before sending — the synthesis + // timeout starts at send, not here (see ai::gpu). + let _gpu = crate::ai::gpu::tts_lease().await; + match client .text_to_speech(&text, voice, format, exaggeration, cfg_weight, temperature) .await @@ -276,16 +456,283 @@ pub async fn tts_speech_handler( } } -/// GET /tts/voices — list the Chatterbox voice library (raw passthrough). +#[derive(Debug, Serialize)] +pub struct TtsJobCreatedResponse { + pub job_id: String, + pub status: TtsJobStatus, +} + +#[derive(Debug, Serialize)] +pub struct TtsJobStatusResponse { + pub job_id: String, + pub status: TtsJobStatus, + pub format: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub audio_base64: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, +} + +/// POST /tts/speech/jobs — durable variant of /tts/speech for long syntheses. +/// Returns 202 + a job id immediately; the synth queues on the single GPU +/// permit (instead of fast-failing 429) and the client polls the job until +/// the audio is ready. +#[post("/tts/speech/jobs")] +pub async fn create_speech_job_handler( + http_request: HttpRequest, + _claims: Claims, + req: web::Json<TtsSpeechRequest>, + app_state: web::Data<AppState>, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let mut span = + global_tracer().start_with_context("http.tts.speech_job.create", &parent_context); + + let text = prepare_for_tts(&req.text); + if text.is_empty() { + span.set_status(Status::error("text is required")); + return HttpResponse::BadRequest().json(json!({ "error": "text is required" })); + } + if app_state.llamacpp.is_none() { + span.set_status(Status::error("tts backend not configured")); + return HttpResponse::ServiceUnavailable() + .json(json!({ "error": "TTS backend not configured (set LLAMA_SWAP_URL)" })); + } + + let format = req + .format + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or("mp3") + .to_string(); + let voice = req + .voice + .clone() + .filter(|s| !s.is_empty()) + .or_else(default_voice); + // Clamp generation knobs to Chatterbox's documented ranges before forwarding. + let exaggeration = req.exaggeration.map(|x| x.clamp(0.25, 2.0)); + let cfg_weight = req.cfg_weight.map(|x| x.clamp(0.0, 1.0)); + let temperature = req.temperature.map(|x| x.clamp(0.05, 5.0)); + + span.set_attribute(KeyValue::new("tts.format", format.clone())); + span.set_attribute(KeyValue::new("tts.has_voice", voice.is_some())); + span.set_attribute(KeyValue::new("tts.text_len", text.len() as i64)); + + let job_id = Uuid::new_v4(); + { + let mut jobs = TTS_JOBS.lock().unwrap(); + sweep_stale_jobs(&mut jobs, Instant::now()); + jobs.insert( + job_id, + TtsJob { + status: TtsJobStatus::Queued, + format: format.clone(), + audio_base64: None, + error: None, + created_at: Instant::now(), + finished_at: None, + abort: None, + }, + ); + } + + let state = app_state.clone(); + let handle = tokio::spawn(async move { + // Queue rather than fast-fail: jobs wait their turn for the GPU. + let _permit = match TTS_PERMIT.acquire().await { + Ok(p) => p, + Err(_) => { + finish_job( + job_id, + TtsJobStatus::Error, + None, + Some("TTS queue closed".to_string()), + ); + return; + } + }; + // Wait for the LLM side to release the GPU too (see ai::gpu) — only + // then does the job count as running. The synthesis timeout starts at + // the HTTP send below, so neither wait burns it, and the client can + // anchor its own deadline to the queued→running transition. + let _gpu = crate::ai::gpu::tts_lease().await; + + // Cancelled while queued — release the permits without synthesizing. + let cancelled = with_job(job_id, |job| { + if job.status == TtsJobStatus::Queued { + job.status = TtsJobStatus::Running; + false + } else { + true + } + }) + .unwrap_or(true); + if cancelled { + return; + } + + let Some(client) = state.llamacpp.as_ref() else { + finish_job( + job_id, + TtsJobStatus::Error, + None, + Some("TTS backend not configured".to_string()), + ); + return; + }; + match client + .text_to_speech( + &text, + voice.as_deref(), + &format, + exaggeration, + cfg_weight, + temperature, + ) + .await + { + Ok(bytes) => { + let audio = base64::engine::general_purpose::STANDARD.encode(&bytes); + finish_job(job_id, TtsJobStatus::Done, Some(audio), None); + } + Err(e) => { + log::error!("TTS job {job_id} failed: {:?}", e); + finish_job( + job_id, + TtsJobStatus::Error, + None, + Some(format!("TTS failed: {e}")), + ); + } + } + }); + // Aborting an already-finished task is a no-op, so this late install is + // safe even if the job raced to completion. + with_job(job_id, |job| { + if !job.status.is_terminal() { + job.abort = Some(handle.abort_handle()); + } + }); + + span.set_status(Status::Ok); + HttpResponse::Accepted().json(TtsJobCreatedResponse { + job_id: job_id.to_string(), + status: TtsJobStatus::Queued, + }) +} + +/// GET /tts/speech/jobs/{id} — poll a speech job; returns the audio once done. +/// 404s after the job expires (results are kept ~10 min past completion). +#[get("/tts/speech/jobs/{id}")] +pub async fn speech_job_status_handler( + http_request: HttpRequest, + _claims: Claims, + path: web::Path<String>, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let mut span = + global_tracer().start_with_context("http.tts.speech_job.status", &parent_context); + + let Ok(id) = Uuid::parse_str(&path.into_inner()) else { + span.set_status(Status::error("invalid job id")); + return HttpResponse::BadRequest().json(json!({ "error": "invalid job id" })); + }; + let resp = { + let jobs = TTS_JOBS.lock().unwrap(); + jobs.get(&id).map(|job| TtsJobStatusResponse { + job_id: id.to_string(), + status: job.status, + format: job.format.clone(), + audio_base64: job.audio_base64.clone(), + error: job.error.clone(), + }) + }; + match resp { + Some(r) => { + span.set_status(Status::Ok); + HttpResponse::Ok().json(r) + } + None => { + span.set_status(Status::error("job not found")); + HttpResponse::NotFound() + .json(json!({ "error": "TTS job not found (it may have expired)" })) + } + } +} + +/// DELETE /tts/speech/jobs/{id} — cancel a queued/running speech job. Once the +/// upstream GPU job has started it can't be interrupted (same wrapper +/// limitation as the sync path); cancelling stops the wait and discards the +/// result. Cancelling an already-finished job leaves it terminal. +#[delete("/tts/speech/jobs/{id}")] +pub async fn cancel_speech_job_handler( + http_request: HttpRequest, + _claims: Claims, + path: web::Path<String>, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let mut span = + global_tracer().start_with_context("http.tts.speech_job.cancel", &parent_context); + + let Ok(id) = Uuid::parse_str(&path.into_inner()) else { + span.set_status(Status::error("invalid job id")); + return HttpResponse::BadRequest().json(json!({ "error": "invalid job id" })); + }; + let status = with_job(id, |job| { + if !job.status.is_terminal() { + if let Some(h) = job.abort.take() { + h.abort(); + } + job.status = TtsJobStatus::Cancelled; + job.finished_at = Some(Instant::now()); + } + job.status + }); + match status { + Some(s) => { + span.set_status(Status::Ok); + HttpResponse::Ok().json(json!({ "job_id": id.to_string(), "status": s })) + } + None => { + span.set_status(Status::error("job not found")); + HttpResponse::NotFound() + .json(json!({ "error": "TTS job not found (it may have expired)" })) + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ListVoicesQuery { + /// `?refresh=1` bypasses the voice-list cache and re-queries upstream + /// (which may spin up the TTS model). + #[serde(default)] + pub refresh: Option<String>, +} + +/// GET /tts/voices — list the Chatterbox voice library. Served from an +/// in-memory cache when possible so browsing settings doesn't make llama-swap +/// load the TTS model (and evict the resident LLM); see VOICES_CACHE. #[get("/tts/voices")] pub async fn list_voices_handler( http_request: HttpRequest, _claims: Claims, + query: web::Query<ListVoicesQuery>, app_state: web::Data<AppState>, ) -> impl Responder { let parent_context = extract_context_from_request(&http_request); let mut span = global_tracer().start_with_context("http.tts.voices.list", &parent_context); + let force = query + .refresh + .as_deref() + .is_some_and(|v| matches!(v, "1" | "true" | "yes")); + if !force && let Some(v) = cached_voices() { + span.set_attribute(KeyValue::new("tts.voices_cache_hit", true)); + span.set_status(Status::Ok); + return HttpResponse::Ok().json(v); + } + let Some(client) = app_state.llamacpp.as_ref() else { span.set_status(Status::error("tts backend not configured")); return HttpResponse::ServiceUnavailable() @@ -293,6 +740,8 @@ pub async fn list_voices_handler( }; match client.list_voices().await { Ok(v) => { + store_voices_cache(&v); + span.set_attribute(KeyValue::new("tts.voices_cache_hit", false)); span.set_status(Status::Ok); HttpResponse::Ok().json(v) } @@ -304,8 +753,52 @@ pub async fn list_voices_handler( } } +/// DELETE /tts/voices/{name} — remove a cloned voice from the library. +#[delete("/tts/voices/{name}")] +pub async fn delete_voice_handler( + http_request: HttpRequest, + _claims: Claims, + path: web::Path<String>, + app_state: web::Data<AppState>, +) -> impl Responder { + let parent_context = extract_context_from_request(&http_request); + let mut span = global_tracer().start_with_context("http.tts.voices.delete", &parent_context); + + let Some(client) = app_state.llamacpp.as_ref() else { + span.set_status(Status::error("tts backend not configured")); + return HttpResponse::ServiceUnavailable() + .json(json!({ "error": "TTS backend not configured" })); + }; + // Same charset rule as creation — a name that sanitizes differently was + // never a voice we created, and must not reach the upstream URL. + let raw = path.into_inner(); + let name = match sanitize_voice_name(&raw) { + Some(n) if n == raw => n, + _ => { + span.set_status(Status::error("invalid voice name")); + return HttpResponse::BadRequest().json(json!({ "error": "invalid voice name" })); + } + }; + span.set_attribute(KeyValue::new("tts.voice_name", name.clone())); + + match client.delete_voice(&name).await { + Ok(v) => { + invalidate_voices_cache(); + span.set_status(Status::Ok); + HttpResponse::Ok().json(v) + } + Err(e) => { + span.set_status(Status::error("delete_voice failed")); + log::error!("delete_voice failed: {:?}", e); + HttpResponse::BadGateway().json(json!({ "error": format!("{e}") })) + } + } +} + /// POST /tts/voices/upload — register a cloned voice from an uploaded audio -/// clip. Multipart fields: `voice_name` (text) + a file part (`voice_file`). +/// clip. Multipart fields: `voice_name` (text) + a file part (`voice_file`), +/// plus optional `start_seconds` / `duration_seconds` (text) selecting which +/// window of a longer recording becomes the reference clip. #[post("/tts/voices/upload")] pub async fn create_voice_upload_handler( http_request: HttpRequest, @@ -323,6 +816,8 @@ pub async fn create_voice_upload_handler( }; let mut voice_name: Option<String> = None; + let mut start_field: Option<String> = None; + let mut duration_field: Option<String> = None; let mut file_bytes = BytesMut::new(); let mut filename = "voice.wav".to_string(); @@ -347,22 +842,57 @@ pub async fn create_voice_upload_handler( } file_bytes.put(data); } - } else if name_opt.as_deref() == Some("voice_name") { + } else if matches!( + name_opt.as_deref(), + Some("voice_name" | "start_seconds" | "duration_seconds") + ) { + let field = name_opt.as_deref().unwrap().to_string(); let mut buf = BytesMut::new(); while let Some(Ok(data)) = part.next().await { buf.put(data); } - voice_name = Some(String::from_utf8_lossy(&buf).trim().to_string()); + let text = String::from_utf8_lossy(&buf).trim().to_string(); + match field.as_str() { + "voice_name" => voice_name = Some(text), + "start_seconds" => start_field = Some(text), + _ => duration_field = Some(text), + } } else { while let Some(Ok(_)) = part.next().await {} } } + // Empty text parts are treated as absent; anything else must parse, so a + // client bug ("abc") fails loudly instead of silently cloning from 0s. + let parse_secs = |field: Option<&String>, name: &str| -> Result<Option<f64>, String> { + match field.map(|s| s.as_str()).filter(|s| !s.is_empty()) { + None => Ok(None), + Some(s) => s + .parse::<f64>() + .map(Some) + .map_err(|_| format!("{name} must be a number of seconds")), + } + }; + let window = parse_secs(start_field.as_ref(), "start_seconds").and_then(|start| { + parse_secs(duration_field.as_ref(), "duration_seconds") + .and_then(|dur| resolve_ref_window(start, dur)) + }); + let (ref_start, ref_duration) = match window { + Ok(w) => w, + Err(msg) => { + span.set_status(Status::error("invalid reference window")); + return HttpResponse::BadRequest().json(json!({ "error": msg })); + } + }; + let Some(name) = voice_name.as_deref().and_then(sanitize_voice_name) else { span.set_status(Status::error("voice_name is required")); return HttpResponse::BadRequest() .json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" })); }; + // Tag the name with the ref-clip length (e.g. `grandma-30s`) so the + // library shows which reference length produced each clone. + let name = append_ref_window(&name, ref_start, ref_duration.round().max(1.0) as u32); if file_bytes.is_empty() { span.set_status(Status::error("voice_file is required")); return HttpResponse::BadRequest().json(json!({ "error": "voice_file is required" })); @@ -373,21 +903,23 @@ pub async fn create_voice_upload_handler( // Normalize to WAV so any device format (e.g. .aac / .opus, which Chatterbox // rejects by extension) is accepted. let src_ext = Path::new(&filename).extension().and_then(|e| e.to_str()); - let wav = match transcode_bytes_to_wav(file_bytes.as_ref(), src_ext).await { - Ok(w) => w, - Err(e) => { - span.set_status(Status::error("audio decode failed")); - log::error!("voice upload transcode failed: {:?}", e); - return HttpResponse::BadRequest() - .json(json!({ "error": "couldn't decode that audio file" })); - } - }; + let wav = + match transcode_bytes_to_wav(file_bytes.as_ref(), src_ext, ref_start, ref_duration).await { + Ok(w) => w, + Err(e) => { + span.set_status(Status::error("audio decode failed")); + log::error!("voice upload transcode failed: {:?}", e); + return HttpResponse::BadRequest() + .json(json!({ "error": "couldn't decode that audio file" })); + } + }; match client .create_voice(&name, wav, "reference.wav", "audio/wav") .await { Ok(v) => { + invalidate_voices_cache(); span.set_status(Status::Ok); HttpResponse::Ok().json(v) } @@ -406,11 +938,19 @@ pub struct CreateVoiceFromLibraryRequest { pub path: String, #[serde(default)] pub library: Option<String>, + /// Offset into the source where the reference window begins (default 0) — + /// lets the client pick the clean-speech section of a long recording. + #[serde(default)] + pub start_seconds: Option<f64>, + /// Reference window length; clamped to LLAMA_SWAP_TTS_REF_SECONDS. + #[serde(default)] + pub duration_seconds: Option<f64>, } /// POST /tts/voices/from-library — register a cloned voice from a file already /// in a library. Audio and video alike are ffmpeg-normalized to a mono 24 kHz -/// WAV reference clip (length capped by LLAMA_SWAP_TTS_REF_SECONDS). +/// WAV reference clip (window selected by start/duration_seconds, length +/// capped by LLAMA_SWAP_TTS_REF_SECONDS). #[post("/tts/voices/from-library")] pub async fn create_voice_from_library_handler( http_request: HttpRequest, @@ -432,6 +972,18 @@ pub async fn create_voice_from_library_handler( return HttpResponse::BadRequest() .json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" })); }; + let (ref_start, ref_duration) = + match resolve_ref_window(req.start_seconds, req.duration_seconds) { + Ok(w) => w, + Err(msg) => { + span.set_status(Status::error("invalid reference window")); + return HttpResponse::BadRequest().json(json!({ "error": msg })); + } + }; + // Tag the name with the ref-clip length (e.g. `grandma-30s`) so the + // library shows which reference length produced each clone. + let voice_name = + append_ref_window(&voice_name, ref_start, ref_duration.round().max(1.0) as u32); let library = match libraries::resolve_library_param(&app_state, req.library.as_deref()) { Ok(Some(l)) => l, @@ -460,7 +1012,7 @@ pub async fn create_voice_from_library_handler( } span.set_attribute(KeyValue::new("tts.voice_name", voice_name.clone())); - let wav = match prepare_reference_audio(&abs).await { + let wav = match prepare_reference_audio(&abs, ref_start, ref_duration).await { Ok(b) => b, Err(e) => { span.set_status(Status::error("audio decode failed")); @@ -475,6 +1027,7 @@ pub async fn create_voice_from_library_handler( .await { Ok(v) => { + invalidate_voices_cache(); span.set_status(Status::Ok); HttpResponse::Ok().json(v) } @@ -489,8 +1042,8 @@ pub async fn create_voice_from_library_handler( /// Read a library file (audio or video) as a Chatterbox-ready reference: ffmpeg /// decodes/extracts its audio to mono 24 kHz WAV. Reading straight from the /// library path avoids slurping a (possibly large) video into memory. -async fn prepare_reference_audio(abs: &Path) -> anyhow::Result<Vec<u8>> { - run_ffmpeg_to_wav(&abs.to_string_lossy()).await +async fn prepare_reference_audio(abs: &Path, start: f64, duration: f64) -> anyhow::Result<Vec<u8>> { + run_ffmpeg_to_wav(&abs.to_string_lossy(), start, duration).await } #[cfg(test)] @@ -534,6 +1087,151 @@ mod tests { assert_eq!(sanitize_voice_name(&long).unwrap().len(), 64); } + #[test] + fn append_ref_window_tags_name() { + assert_eq!(append_ref_window("grandma", 0.0, 30), "grandma-30s"); + assert_eq!(append_ref_window("voice_01", 0.0, 15), "voice_01-15s"); + } + + #[test] + fn append_ref_window_includes_nonzero_start() { + // Sub-minute starts stay in seconds; longer ones read as XmYYs since + // ':' isn't allowed in voice names. + assert_eq!(append_ref_window("grandma", 45.0, 30), "grandma-at45s-30s"); + assert_eq!( + append_ref_window("grandma", 92.4, 30), + "grandma-at1m32s-30s" + ); + assert_eq!( + append_ref_window("grandma", 600.0, 12), + "grandma-at10m00s-12s" + ); + // A start that rounds to zero is "from the start". + assert_eq!(append_ref_window("grandma", 0.3, 30), "grandma-30s"); + } + + #[test] + fn append_ref_window_is_idempotent_for_same_window() { + assert_eq!(append_ref_window("grandma-30s", 0.0, 30), "grandma-30s"); + assert_eq!( + append_ref_window("grandma-at45s-30s", 45.0, 30), + "grandma-at45s-30s" + ); + // A different window still appends — that's the comparison use-case. + assert_eq!(append_ref_window("grandma-15s", 0.0, 30), "grandma-15s-30s"); + assert_eq!( + append_ref_window("grandma-30s", 45.0, 30), + "grandma-30s-at45s-30s" + ); + } + + #[test] + fn append_ref_window_keeps_64_char_bound() { + let long = "a".repeat(64); + let tagged = append_ref_window(&long, 0.0, 30); + assert_eq!(tagged.len(), 64); + assert!(tagged.ends_with("-30s")); + + let tagged = append_ref_window(&long, 92.0, 30); + assert_eq!(tagged.len(), 64); + assert!(tagged.ends_with("-at1m32s-30s")); + } + + #[test] + fn resolve_ref_window_defaults_to_start_of_clip_at_cap_length() { + // Reads the live cap rather than mutating LLAMA_SWAP_TTS_REF_SECONDS: + // env mutation flakes under the parallel suite (see env_dispatch). + let cap = f64::from(tts_ref_seconds()); + assert_eq!(resolve_ref_window(None, None), Ok((0.0, cap))); + } + + #[test] + fn resolve_ref_window_accepts_offset_and_clamps_duration() { + let cap = f64::from(tts_ref_seconds()); + assert_eq!(resolve_ref_window(Some(92.5), None), Ok((92.5, cap))); + assert_eq!(resolve_ref_window(Some(10.0), Some(12.0)), Ok((10.0, 12.0))); + // Longer-than-cap windows are bounded, not rejected. + assert_eq!(resolve_ref_window(None, Some(cap + 100.0)), Ok((0.0, cap))); + } + + #[test] + fn resolve_ref_window_rejects_garbage() { + assert!(resolve_ref_window(Some(-1.0), None).is_err()); + assert!(resolve_ref_window(Some(f64::NAN), None).is_err()); + assert!(resolve_ref_window(Some(f64::INFINITY), None).is_err()); + assert!(resolve_ref_window(None, Some(0.0)).is_err()); + assert!(resolve_ref_window(None, Some(-5.0)).is_err()); + assert!(resolve_ref_window(None, Some(f64::NAN)).is_err()); + } + + #[test] + fn sweep_drops_expired_results_and_keeps_live_jobs() { + let now = Instant::now(); + let mk = |status: TtsJobStatus, created: Instant, finished: Option<Instant>| TtsJob { + status, + format: "mp3".into(), + audio_base64: None, + error: None, + created_at: created, + finished_at: finished, + abort: None, + }; + let mut jobs = HashMap::new(); + let live = Uuid::new_v4(); + let fresh_done = Uuid::new_v4(); + let stale_done = Uuid::new_v4(); + jobs.insert(live, mk(TtsJobStatus::Running, now, None)); + jobs.insert( + fresh_done, + mk(TtsJobStatus::Done, now, Some(now - Duration::from_secs(60))), + ); + jobs.insert( + stale_done, + mk( + TtsJobStatus::Done, + now - TTS_JOB_MAX_AGE / 2, + Some(now - TTS_JOB_RESULT_TTL), + ), + ); + + sweep_stale_jobs(&mut jobs, now); + assert!(jobs.contains_key(&live)); + assert!(jobs.contains_key(&fresh_done)); + assert!(!jobs.contains_key(&stale_done)); + } + + #[test] + fn sweep_drops_jobs_past_max_age_even_if_unfinished() { + let now = Instant::now(); + let mut jobs = HashMap::new(); + let ancient = Uuid::new_v4(); + jobs.insert( + ancient, + TtsJob { + status: TtsJobStatus::Running, + format: "mp3".into(), + audio_base64: None, + error: None, + created_at: now - TTS_JOB_MAX_AGE, + finished_at: None, + abort: None, + }, + ); + sweep_stale_jobs(&mut jobs, now); + assert!(jobs.is_empty()); + } + + #[test] + fn voices_cache_roundtrip_and_invalidation() { + invalidate_voices_cache(); + assert!(cached_voices().is_none()); + let v = json!({ "voices": [{ "name": "m-30s" }], "count": 1 }); + store_voices_cache(&v); + assert_eq!(cached_voices(), Some(v)); + invalidate_voices_cache(); + assert!(cached_voices().is_none()); + } + #[test] fn clean_for_tts_strips_markdown() { assert_eq!( diff --git a/src/bin/import_calendar.rs b/src/bin/import_calendar.rs index c8f941b..98b3f37 100644 --- a/src/bin/import_calendar.rs +++ b/src/bin/import_calendar.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; use chrono::Utc; use clap::Parser; -use image_api::ai::ollama::OllamaClient; +use image_api::ai::LocalLlm; use image_api::bin_progress; use image_api::database::calendar_dao::{InsertCalendarEvent, SqliteCalendarEventDao}; use image_api::parsers::ical_parser::parse_ics_file; @@ -44,22 +44,10 @@ async fn main() -> Result<()> { let context = opentelemetry::Context::current(); - let ollama = if args.generate_embeddings { - let primary_url = dotenv::var("OLLAMA_PRIMARY_URL") - .or_else(|_| dotenv::var("OLLAMA_URL")) - .unwrap_or_else(|_| "http://localhost:11434".to_string()); - let fallback_url = dotenv::var("OLLAMA_FALLBACK_URL").ok(); - let primary_model = dotenv::var("OLLAMA_PRIMARY_MODEL") - .or_else(|_| dotenv::var("OLLAMA_MODEL")) - .unwrap_or_else(|_| "nomic-embed-text:v1.5".to_string()); - let fallback_model = dotenv::var("OLLAMA_FALLBACK_MODEL").ok(); - - Some(OllamaClient::new( - primary_url, - fallback_url, - primary_model, - fallback_model, - )) + // LocalLlm dispatches per LLM_BACKEND, so embeddings written here land + // in the same vector space the query side searches. + let llm = if args.generate_embeddings { + Some(LocalLlm::from_env()) } else { None }; @@ -90,7 +78,7 @@ async fn main() -> Result<()> { } // Generate embedding if requested (blocking call) - let embedding = if let Some(ref ollama_client) = ollama { + let embedding = if let Some(ref llm) = llm { let text = format!( "{} {} {}", event.summary, @@ -100,7 +88,7 @@ async fn main() -> Result<()> { match tokio::task::block_in_place(|| { tokio::runtime::Handle::current() - .block_on(async { ollama_client.generate_embedding(&text).await }) + .block_on(async { llm.embed_document(&text).await }) }) { Ok(emb) => Some(emb), Err(e) => { diff --git a/src/bin/import_search_history.rs b/src/bin/import_search_history.rs index 21af659..93605cc 100644 --- a/src/bin/import_search_history.rs +++ b/src/bin/import_search_history.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; use chrono::Utc; use clap::Parser; -use image_api::ai::ollama::OllamaClient; +use image_api::ai::LocalLlm; use image_api::bin_progress; use image_api::database::search_dao::{InsertSearchRecord, SqliteSearchHistoryDao}; use image_api::parsers::search_html_parser::parse_search_html; @@ -38,16 +38,9 @@ async fn main() -> Result<()> { info!("Found {} search records", searches.len()); - let primary_url = dotenv::var("OLLAMA_PRIMARY_URL") - .or_else(|_| dotenv::var("OLLAMA_URL")) - .unwrap_or_else(|_| "http://localhost:11434".to_string()); - let fallback_url = dotenv::var("OLLAMA_FALLBACK_URL").ok(); - let primary_model = dotenv::var("OLLAMA_PRIMARY_MODEL") - .or_else(|_| dotenv::var("OLLAMA_MODEL")) - .unwrap_or_else(|_| "nomic-embed-text:v1.5".to_string()); - let fallback_model = dotenv::var("OLLAMA_FALLBACK_MODEL").ok(); - - let ollama = OllamaClient::new(primary_url, fallback_url, primary_model, fallback_model); + // LocalLlm dispatches per LLM_BACKEND, so embeddings written here land + // in the same vector space the query side searches. + let llm = LocalLlm::from_env(); let context = opentelemetry::Context::current(); let mut inserted_count = 0usize; @@ -67,12 +60,11 @@ async fn main() -> Result<()> { let pb_for_warn = pb.clone(); let embeddings_result = tokio::task::spawn({ - let ollama_client = ollama.clone(); + let llm = llm.clone(); async move { - // Generate embeddings in parallel for the batch let mut embeddings = Vec::new(); for query in &queries { - match ollama_client.generate_embedding(query).await { + match llm.embed_document(query).await { Ok(emb) => embeddings.push(Some(emb)), Err(e) => { pb_for_warn.println(format!("embedding failed for '{}': {}", query, e)); diff --git a/src/bin/reembed_embeddings.rs b/src/bin/reembed_embeddings.rs new file mode 100644 index 0000000..a2fdd4c --- /dev/null +++ b/src/bin/reembed_embeddings.rs @@ -0,0 +1,465 @@ +//! Re-embed stored corpora through `LocalLlm`, i.e. the same +//! `LLM_BACKEND` dispatch the query side uses. The original import / +//! backfill tools always embedded via Ollama, so a deploy running +//! `LLM_BACKEND=llamacpp` queries vector spaces the corpora may not live +//! in. Three tables share the problem and are all covered here: +//! +//! - `daily_conversation_summaries` — re-embeds +//! `strip_summary_boilerplate(summary)` (what the original job fed the +//! embedder); also rewrites `model_version`. +//! - `calendar_events` — re-embeds "summary description location" exactly +//! as `import_calendar` does; rows without an embedding are skipped (the +//! import only embeds under `--generate-embeddings`). +//! - `search_history` — re-embeds the raw query text. +//! - `entities` (knowledge graph) — re-embeds "name description" exactly as +//! `tool_store_entity` does; embedding-less rows are skipped (embedding +//! is best-effort at store time). +//! +//! Source text is untouched — only vectors are rewritten. The old↔new +//! cosine report doubles as a diagnostic: ~1.0 means both backends already +//! shared a space (re-embedding was a no-op); low values confirm the +//! mismatch this tool exists to fix. + +use anyhow::{Context, Result}; +use clap::Parser; +use diesel::prelude::*; +use diesel::sql_query; +use diesel::sqlite::SqliteConnection; +use image_api::ai::{LocalLlm, strip_summary_boilerplate}; +use image_api::bin_progress; +use std::env; + +#[derive(Parser, Debug)] +#[command(author, version, about = "Re-embed stored corpora via the configured LLM_BACKEND", long_about = None)] +struct Args { + /// Comma-separated tables to process: summaries, calendar, search, entities + #[arg(long, default_value = "summaries,calendar,search,entities")] + tables: String, + + /// Only process the first N rows per table (smoke test) + #[arg(long)] + limit: Option<usize>, + + /// Compute embeddings and report old↔new similarity without writing + #[arg(long, default_value_t = false)] + dry_run: bool, +} + +#[derive(QueryableByName)] +struct SummaryRow { + #[diesel(sql_type = diesel::sql_types::Integer)] + id: i32, + #[diesel(sql_type = diesel::sql_types::Text)] + summary: String, + #[diesel(sql_type = diesel::sql_types::Binary)] + embedding: Vec<u8>, + #[diesel(sql_type = diesel::sql_types::Text)] + model_version: String, +} + +#[derive(QueryableByName)] +struct CalendarRow { + #[diesel(sql_type = diesel::sql_types::Integer)] + id: i32, + #[diesel(sql_type = diesel::sql_types::Text)] + summary: String, + #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)] + description: Option<String>, + #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)] + location: Option<String>, + #[diesel(sql_type = diesel::sql_types::Binary)] + embedding: Vec<u8>, +} + +#[derive(QueryableByName)] +struct SearchRow { + #[diesel(sql_type = diesel::sql_types::BigInt)] + id: i64, + #[diesel(sql_type = diesel::sql_types::Text)] + query: String, + #[diesel(sql_type = diesel::sql_types::Binary)] + embedding: Vec<u8>, +} + +#[derive(QueryableByName)] +struct EntityRow { + #[diesel(sql_type = diesel::sql_types::Integer)] + id: i32, + #[diesel(sql_type = diesel::sql_types::Text)] + name: String, + #[diesel(sql_type = diesel::sql_types::Text)] + description: String, + #[diesel(sql_type = diesel::sql_types::Binary)] + embedding: Vec<u8>, +} + +/// One unit of re-embed work, normalized across tables. +struct WorkItem { + /// Row key, as i64 so both i32 ids and rowids fit. + id: i64, + /// Text fed to the embedder — must match what the original writer used. + text: String, + /// Existing vector bytes, for the old↔new similarity report. + old_embedding: Vec<u8>, +} + +fn deserialize_vector(bytes: &[u8]) -> Option<Vec<f32>> { + if !bytes.len().is_multiple_of(4) { + return None; + } + Some( + bytes + .chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(), + ) +} + +fn serialize_vector(vec: &[f32]) -> Vec<u8> { + vec.iter().flat_map(|f| f.to_le_bytes()).collect() +} + +fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + if a.len() != b.len() { + return 0.0; + } + let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum(); + let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt(); + let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt(); + if mag_a == 0.0 || mag_b == 0.0 { + return 0.0; + } + dot / (mag_a * mag_b) +} + +/// Embed `text`, halving it on "input too large" errors until it fits the +/// server's physical batch (`--ubatch-size`). Mirrors the silent truncation +/// Ollama applied when these corpora were first embedded — llama-server +/// returns a 500 instead — except here it's surfaced via the returned flag. +/// Returns `(embedding, truncated)`. +async fn embed_with_truncation(llm: &LocalLlm, text: &str) -> Result<(Vec<f32>, bool)> { + let mut text = text.to_string(); + let mut truncated = false; + loop { + match llm.embed_document(&text).await { + Ok(emb) => return Ok((emb, truncated)), + Err(e) + if e.to_string().contains("too large to process") && text.chars().count() > 64 => + { + let keep = text.chars().count() / 2; + text = text.chars().take(keep).collect(); + truncated = true; + } + Err(e) => return Err(e), + } + } +} + +/// Re-embed `items`, writing each new vector via `update`. Returns the +/// old↔new cosines for the similarity report. +async fn reembed_table( + conn: &mut SqliteConnection, + llm: &LocalLlm, + label: &str, + items: Vec<WorkItem>, + dry_run: bool, + update: impl Fn(&mut SqliteConnection, i64, Vec<u8>) -> Result<()>, +) -> Result<Vec<f32>> { + println!("\n[{}] re-embedding {} rows...", label, items.len()); + let pb = bin_progress::determinate(items.len() as u64, format!("re-embedding {}", label)); + + let mut sims: Vec<f32> = Vec::with_capacity(items.len()); + let mut updated = 0usize; + let mut failed = 0usize; + let mut truncated_count = 0usize; + + for item in &items { + let new_emb = match embed_with_truncation(llm, &item.text).await { + Ok((e, truncated)) => { + if truncated { + truncated_count += 1; + pb.println(format!( + "⚠ {} id={}: input exceeded the embed server's batch size, \ + truncated before embedding", + label, item.id + )); + } + e + } + Err(e) => { + pb.inc(1); + failed += 1; + eprintln!("✗ {} id={}: {}", label, item.id, e); + continue; + } + }; + + // The whole pipeline (DAO checks, stored corpora) assumes + // EMBEDDING_DIM dims. A mismatch means the active embed slot is not + // serving the configured model — stop rather than corrupt the table. + anyhow::ensure!( + new_emb.len() == image_api::ai::embedding_dim(), + "backend returned {}-dim embedding (expected {}) — '{}' does not \ + match the configured EMBEDDING_DIM", + new_emb.len(), + image_api::ai::embedding_dim(), + llm.embedding_model_version() + ); + + if let Some(old_emb) = deserialize_vector(&item.old_embedding) { + sims.push(cosine_similarity(&old_emb, &new_emb)); + } + + if !dry_run { + update(conn, item.id, serialize_vector(&new_emb)) + .with_context(|| format!("updating {} id={}", label, item.id))?; + } + updated += 1; + pb.inc(1); + } + pb.finish_and_clear(); + + println!( + "[{}] {} re-embedded ({} truncated), {} failed", + label, updated, truncated_count, failed + ); + Ok(sims) +} + +fn report_similarity(label: &str, mut sims: Vec<f32>) { + if sims.is_empty() { + println!("[{}] no old↔new pairs to compare", label); + return; + } + sims.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32; + let median = sims[sims.len() / 2]; + println!( + "[{}] old↔new cosine over identical text: min={:.3} median={:.3} mean={:.3} max={:.3}", + label, + sims.first().unwrap(), + median, + mean, + sims.last().unwrap() + ); + if median > 0.98 { + println!( + "[{}] → old and new backends agree (~same vector space); poor search \ + results are coming from something else (prefixes, thresholds, corpus).", + label + ); + } else if median > 0.9 { + println!( + "[{}] → same model family but measurably different vectors \ + (quantization / runtime drift); re-embedding was worthwhile.", + label + ); + } else { + println!( + "[{}] → vector-space mismatch confirmed — queries were searching a \ + different space than the corpus. This re-embed should fix it.", + label + ); + } +} + +#[tokio::main] +async fn main() -> Result<()> { + dotenv::dotenv().ok(); + env_logger::init(); + let args = Args::parse(); + + let tables: Vec<&str> = args.tables.split(',').map(|t| t.trim()).collect(); + for t in &tables { + anyhow::ensure!( + matches!(*t, "summaries" | "calendar" | "search" | "entities"), + "unknown table '{}' — expected summaries, calendar, search, entities", + t + ); + } + + let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string()); + println!("Database: {}", database_url); + + let mut conn = SqliteConnection::establish(&database_url) + .with_context(|| format!("connecting to {}", database_url))?; + + let llm = LocalLlm::from_env(); + let model_version = llm.embedding_model_version(); + println!("Embedding via '{}'", model_version); + if args.dry_run { + println!("DRY RUN — no rows will be written"); + } + + if tables.contains(&"summaries") { + let mut rows: Vec<SummaryRow> = sql_query( + "SELECT id, summary, embedding, model_version + FROM daily_conversation_summaries ORDER BY date", + ) + .load(&mut conn) + .context("loading daily summaries")?; + if let Some(limit) = args.limit { + rows.truncate(limit); + } + if let Some(first) = rows.first() { + println!( + "\n[summaries] previous model_version '{}' → '{}'", + first.model_version, model_version + ); + } + let items = rows + .into_iter() + .map(|r| WorkItem { + id: r.id as i64, + text: strip_summary_boilerplate(&r.summary), + old_embedding: r.embedding, + }) + .collect(); + let mv = model_version.clone(); + let sims = reembed_table( + &mut conn, + &llm, + "summaries", + items, + args.dry_run, + move |conn, id, emb| { + sql_query( + "UPDATE daily_conversation_summaries + SET embedding = ?1, model_version = ?2 WHERE id = ?3", + ) + .bind::<diesel::sql_types::Binary, _>(emb) + .bind::<diesel::sql_types::Text, _>(&mv) + .bind::<diesel::sql_types::Integer, _>(id as i32) + .execute(conn)?; + Ok(()) + }, + ) + .await?; + report_similarity("summaries", sims); + } + + if tables.contains(&"calendar") { + let mut rows: Vec<CalendarRow> = sql_query( + "SELECT id, summary, description, location, embedding + FROM calendar_events WHERE embedding IS NOT NULL ORDER BY id", + ) + .load(&mut conn) + .context("loading calendar events")?; + if let Some(limit) = args.limit { + rows.truncate(limit); + } + let items = rows + .into_iter() + .map(|r| WorkItem { + id: r.id as i64, + // Same text construction as import_calendar. + text: format!( + "{} {} {}", + r.summary, + r.description.as_deref().unwrap_or(""), + r.location.as_deref().unwrap_or("") + ), + old_embedding: r.embedding, + }) + .collect(); + let sims = reembed_table( + &mut conn, + &llm, + "calendar", + items, + args.dry_run, + |conn, id, emb| { + sql_query("UPDATE calendar_events SET embedding = ?1 WHERE id = ?2") + .bind::<diesel::sql_types::Binary, _>(emb) + .bind::<diesel::sql_types::Integer, _>(id as i32) + .execute(conn)?; + Ok(()) + }, + ) + .await?; + report_similarity("calendar", sims); + } + + if tables.contains(&"search") { + let mut rows: Vec<SearchRow> = sql_query( + "SELECT rowid AS id, query, embedding + FROM search_history ORDER BY rowid", + ) + .load(&mut conn) + .context("loading search history")?; + if let Some(limit) = args.limit { + rows.truncate(limit); + } + let items = rows + .into_iter() + .map(|r| WorkItem { + id: r.id, + text: r.query, + old_embedding: r.embedding, + }) + .collect(); + let sims = reembed_table( + &mut conn, + &llm, + "search", + items, + args.dry_run, + |conn, id, emb| { + sql_query("UPDATE search_history SET embedding = ?1 WHERE rowid = ?2") + .bind::<diesel::sql_types::Binary, _>(emb) + .bind::<diesel::sql_types::BigInt, _>(id) + .execute(conn)?; + Ok(()) + }, + ) + .await?; + report_similarity("search", sims); + } + + if tables.contains(&"entities") { + let mut rows: Vec<EntityRow> = sql_query( + "SELECT id, name, description, embedding + FROM entities WHERE embedding IS NOT NULL ORDER BY id", + ) + .load(&mut conn) + .context("loading knowledge entities")?; + if let Some(limit) = args.limit { + rows.truncate(limit); + } + let items = rows + .into_iter() + .map(|r| WorkItem { + id: r.id as i64, + // Same text construction as tool_store_entity. + text: format!("{} {}", r.name, r.description), + old_embedding: r.embedding, + }) + .collect(); + let sims = reembed_table( + &mut conn, + &llm, + "entities", + items, + args.dry_run, + |conn, id, emb| { + sql_query("UPDATE entities SET embedding = ?1 WHERE id = ?2") + .bind::<diesel::sql_types::Binary, _>(emb) + .bind::<diesel::sql_types::Integer, _>(id as i32) + .execute(conn)?; + Ok(()) + }, + ) + .await?; + report_similarity("entities", sims); + } + + println!( + "\n{}", + if args.dry_run { + "Dry run complete" + } else { + "Done" + } + ); + Ok(()) +} diff --git a/src/database/calendar_dao.rs b/src/database/calendar_dao.rs index 4ebd21c..f739d87 100644 --- a/src/database/calendar_dao.rs +++ b/src/database/calendar_dao.rs @@ -222,11 +222,12 @@ impl CalendarEventDao for SqliteCalendarEventDao { // Validate embedding dimensions if provided if let Some(ref emb) = event.embedding - && emb.len() != 768 + && emb.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid embedding dimensions: {} (expected 768)", - emb.len() + "Invalid embedding dimensions: {} (expected {})", + emb.len(), + crate::ai::embedding_dim() )); } @@ -293,7 +294,7 @@ impl CalendarEventDao for SqliteCalendarEventDao { for event in events { // Validate embedding if provided if let Some(ref emb) = event.embedding - && emb.len() != 768 + && emb.len() != crate::ai::embedding_dim() { log::warn!( "Skipping event with invalid embedding dimensions: {}", @@ -385,10 +386,11 @@ impl CalendarEventDao for SqliteCalendarEventDao { trace_db_call(context, "query", "find_similar_events", |_span| { let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao"); - if query_embedding.len() != 768 { + if query_embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_embedding.len() + "Invalid query embedding dimensions: {} (expected {})", + query_embedding.len(), + crate::ai::embedding_dim() )); } @@ -461,10 +463,11 @@ impl CalendarEventDao for SqliteCalendarEventDao { // Step 2: If query embedding provided, rank by semantic similarity if let Some(query_emb) = query_embedding { - if query_emb.len() != 768 { + if query_emb.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_emb.len() + "Invalid query embedding dimensions: {} (expected {})", + query_emb.len(), + crate::ai::embedding_dim() )); } diff --git a/src/database/daily_summary_dao.rs b/src/database/daily_summary_dao.rs index 521c1a5..af1d16f 100644 --- a/src/database/daily_summary_dao.rs +++ b/src/database/daily_summary_dao.rs @@ -150,10 +150,11 @@ impl DailySummaryDao for SqliteDailySummaryDao { .expect("Unable to get DailySummaryDao"); // Validate embedding dimensions - if summary.embedding.len() != 768 { + if summary.embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid embedding dimensions: {} (expected 768)", - summary.embedding.len() + "Invalid embedding dimensions: {} (expected {})", + summary.embedding.len(), + crate::ai::embedding_dim() )); } @@ -202,10 +203,11 @@ impl DailySummaryDao for SqliteDailySummaryDao { trace_db_call(context, "query", "find_similar_summaries", |_span| { let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao"); - if query_embedding.len() != 768 { + if query_embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_embedding.len() + "Invalid query embedding dimensions: {} (expected {})", + query_embedding.len(), + crate::ai::embedding_dim() )); } @@ -299,10 +301,11 @@ impl DailySummaryDao for SqliteDailySummaryDao { trace_db_call(context, "query", "find_similar_summaries_with_time_weight", |_span| { let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao"); - if query_embedding.len() != 768 { + if query_embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_embedding.len() + "Invalid query embedding dimensions: {} (expected {})", + query_embedding.len(), + crate::ai::embedding_dim() )); } diff --git a/src/database/location_dao.rs b/src/database/location_dao.rs index 8bb0ac4..9840279 100644 --- a/src/database/location_dao.rs +++ b/src/database/location_dao.rs @@ -216,11 +216,12 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { // Validate embedding dimensions if provided (rare for location data) if let Some(ref emb) = location.embedding - && emb.len() != 768 + && emb.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid embedding dimensions: {} (expected 768)", - emb.len() + "Invalid embedding dimensions: {} (expected {})", + emb.len(), + crate::ai::embedding_dim() )); } @@ -292,7 +293,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao { for location in locations { // Validate embedding if provided (rare) if let Some(ref emb) = location.embedding - && emb.len() != 768 + && emb.len() != crate::ai::embedding_dim() { log::warn!( "Skipping location with invalid embedding dimensions: {}", diff --git a/src/database/search_dao.rs b/src/database/search_dao.rs index ee7d0ad..a73c9fb 100644 --- a/src/database/search_dao.rs +++ b/src/database/search_dao.rs @@ -189,10 +189,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { .expect("Unable to get SearchHistoryDao"); // Validate embedding dimensions (REQUIRED for searches) - if search.embedding.len() != 768 { + if search.embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid embedding dimensions: {} (expected 768)", - search.embedding.len() + "Invalid embedding dimensions: {} (expected {})", + search.embedding.len(), + crate::ai::embedding_dim() )); } @@ -245,7 +246,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { conn.transaction::<_, anyhow::Error, _>(|conn| { for search in searches { // Validate embedding (REQUIRED) - if search.embedding.len() != 768 { + if search.embedding.len() != crate::ai::embedding_dim() { log::warn!( "Skipping search with invalid embedding dimensions: {}", search.embedding.len() @@ -325,10 +326,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { .lock() .expect("Unable to get SearchHistoryDao"); - if query_embedding.len() != 768 { + if query_embedding.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_embedding.len() + "Invalid query embedding dimensions: {} (expected {})", + query_embedding.len(), + crate::ai::embedding_dim() )); } @@ -406,10 +408,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao { // Step 2: If query embedding provided, rank by semantic similarity if let Some(query_emb) = query_embedding { - if query_emb.len() != 768 { + if query_emb.len() != crate::ai::embedding_dim() { return Err(anyhow::anyhow!( - "Invalid query embedding dimensions: {} (expected 768)", - query_emb.len() + "Invalid query embedding dimensions: {} (expected {})", + query_emb.len(), + crate::ai::embedding_dim() )); } diff --git a/src/main.rs b/src/main.rs index f27cf8f..8b56efd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -364,9 +364,13 @@ fn main() -> std::io::Result<()> { .service(ai::rate_insight_handler) .service(ai::export_training_data_handler) .service(ai::tts_speech_handler) + .service(ai::create_speech_job_handler) + .service(ai::speech_job_status_handler) + .service(ai::cancel_speech_job_handler) .service(ai::list_voices_handler) .service(ai::create_voice_upload_handler) .service(ai::create_voice_from_library_handler) + .service(ai::delete_voice_handler) .service(libraries::list_libraries) .service(libraries::patch_library) .add_feature(add_tag_services::<_, SqliteTagDao>) diff --git a/src/state.rs b/src/state.rs index ef071a8..e678ad1 100644 --- a/src/state.rs +++ b/src/state.rs @@ -186,21 +186,7 @@ impl AppState { impl Default for AppState { fn default() -> Self { // Initialize AI clients - let ollama_primary_url = env::var("OLLAMA_PRIMARY_URL").unwrap_or_else(|_| { - env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".to_string()) - }); - let ollama_fallback_url = env::var("OLLAMA_FALLBACK_URL").ok(); - let ollama_primary_model = env::var("OLLAMA_PRIMARY_MODEL") - .or_else(|_| env::var("OLLAMA_MODEL")) - .unwrap_or_else(|_| "nemotron-3-nano:30b".to_string()); - let ollama_fallback_model = env::var("OLLAMA_FALLBACK_MODEL").ok(); - - let ollama = OllamaClient::new( - ollama_primary_url, - ollama_fallback_url, - ollama_primary_model, - ollama_fallback_model, - ); + let ollama = build_ollama_from_env(); let openrouter = build_openrouter_from_env(); let openrouter_allowed_models = parse_openrouter_allowed_models(); @@ -375,13 +361,29 @@ fn parse_openrouter_allowed_models() -> Vec<String> { .collect() } +/// Build the `OllamaClient` from environment variables — the canonical +/// `OLLAMA_*` wiring shared by the server (`AppState::default`) and the +/// standalone binaries (which predate this helper and used to copy it). +pub fn build_ollama_from_env() -> OllamaClient { + let primary_url = env::var("OLLAMA_PRIMARY_URL").unwrap_or_else(|_| { + env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".to_string()) + }); + let fallback_url = env::var("OLLAMA_FALLBACK_URL").ok(); + let primary_model = env::var("OLLAMA_PRIMARY_MODEL") + .or_else(|_| env::var("OLLAMA_MODEL")) + .unwrap_or_else(|_| "nemotron-3-nano:30b".to_string()); + let fallback_model = env::var("OLLAMA_FALLBACK_MODEL").ok(); + + OllamaClient::new(primary_url, fallback_url, primary_model, fallback_model) +} + /// Build a `LlamaCppClient` from environment variables. Returns `None` when /// `LLAMA_SWAP_URL` is unset. The client is constructed unconditionally /// when the URL is set (so it's available even under `LLM_BACKEND=ollama` /// for ad-hoc tooling), but the agentic / chat paths only route through it /// when `LLM_BACKEND=llamacpp`. Slot ids default to the names the bundled /// `llama-swap/config.yaml` uses — `chat` / `vision` / `embed`. -fn build_llamacpp_from_env() -> Option<Arc<LlamaCppClient>> { +pub fn build_llamacpp_from_env() -> Option<Arc<LlamaCppClient>> { let base_url = env::var("LLAMA_SWAP_URL").ok()?; let primary_model = env::var("LLAMA_SWAP_PRIMARY_MODEL").ok(); let mut client = LlamaCppClient::new(Some(base_url), primary_model); diff --git a/tts_pronunciations.example.json b/tts_pronunciations.example.json new file mode 100644 index 0000000..9bc9df9 --- /dev/null +++ b/tts_pronunciations.example.json @@ -0,0 +1,13 @@ +{ + "Worcester": "Wuster", + "Spokane": "Spo can", + "wsl": "W S L", + "sql": "sequel", + "api": "A P I", + "US": "U S", + "Dr.": "Doctor", + "St.": "Saint", + "blvd": "boulevard", + "vs.": "versus", + "etc.": "et cetera" +}