Merge pull request 'Feature/tts voice management' (#105) from feature/tts-voice-management into master
Reviewed-on: #105
This commit was merged in pull request #105.
This commit is contained in:
@@ -5,6 +5,8 @@ database/target
|
|||||||
*.db-shm
|
*.db-shm
|
||||||
*.db-wal
|
*.db-wal
|
||||||
.env
|
.env
|
||||||
|
# Server-local TTS pronunciation overrides (tts_pronunciations.example.json is the template)
|
||||||
|
/tts_pronunciations.json
|
||||||
/tmp
|
/tmp
|
||||||
/docs
|
/docs
|
||||||
/specs
|
/specs
|
||||||
|
|||||||
@@ -645,6 +645,14 @@ OPENROUTER_APP_TITLE=ImageApi # Optional attribution header
|
|||||||
# re-embedding — mixed vector spaces break similarity search.
|
# re-embedding — mixed vector spaces break similarity search.
|
||||||
LLM_BACKEND=ollama
|
LLM_BACKEND=ollama
|
||||||
|
|
||||||
|
# Embedding model contract. Corpus and queries must be embedded by the same
|
||||||
|
# model with matching prefixes — after changing the embed model or any of
|
||||||
|
# these, run `cargo run --bin reembed_embeddings` (all tables) or search is
|
||||||
|
# garbage. Prefix values may contain a literal \n (expanded to a newline).
|
||||||
|
EMBEDDING_DIM=768 # 768 = nomic-embed-text v1.5; 1024 = Qwen3-Embedding-0.6B
|
||||||
|
EMBED_QUERY_PREFIX= # nomic: "search_query: " | Qwen3: "Instruct: <task>\nQuery: "
|
||||||
|
EMBED_DOCUMENT_PREFIX= # nomic: "search_document: " | Qwen3: leave empty
|
||||||
|
|
||||||
# llama.cpp / llama-swap (used when LLM_BACKEND=llamacpp). OpenAI-compatible
|
# llama.cpp / llama-swap (used when LLM_BACKEND=llamacpp). OpenAI-compatible
|
||||||
# proxy hosting one or more llama-server processes. Chat models receive
|
# proxy hosting one or more llama-server processes. Chat models receive
|
||||||
# images directly via content-parts (all models assumed vision-capable).
|
# images directly via content-parts (all models assumed vision-capable).
|
||||||
@@ -668,6 +676,8 @@ LLAMA_SWAP_TTS_REF_SECONDS=30 # Max voice-clone reference clip
|
|||||||
# (Chatterbox is zero-shot; ~10-20s clean ref is ideal)
|
# (Chatterbox is zero-shot; ~10-20s clean ref is ideal)
|
||||||
LLAMA_SWAP_TTS_REQUEST_TIMEOUT_SECONDS=600 # Per-request synth timeout (long chunked insights take
|
LLAMA_SWAP_TTS_REQUEST_TIMEOUT_SECONDS=600 # Per-request synth timeout (long chunked insights take
|
||||||
# minutes); overrides the shared client timeout for /tts/speech
|
# minutes); overrides the shared client timeout for /tts/speech
|
||||||
|
TTS_PRONUNCIATIONS_PATH=tts_pronunciations.json # JSON map of pronunciation overrides applied before synth
|
||||||
|
# (see tts_pronunciations.example.json); hot-reloaded on change
|
||||||
|
|
||||||
# Insight Chat Continuation
|
# Insight Chat Continuation
|
||||||
AGENTIC_CHAT_MAX_ITERATIONS=6 # Cap on tool-calling iterations per chat turn (default 6)
|
AGENTIC_CHAT_MAX_ITERATIONS=6 # Cap on tool-calling iterations per chat turn (default 6)
|
||||||
|
|||||||
@@ -153,17 +153,39 @@ behind the same llama-swap proxy. Only requires `LLAMA_SWAP_URL` (the TTS client
|
|||||||
is built whenever that's set — independent of `LLM_BACKEND`). Endpoints:
|
is built whenever that's set — independent of `LLM_BACKEND`). Endpoints:
|
||||||
- `POST /tts/speech` — body `{ text, voice?, format?, exaggeration?, cfg_weight?,
|
- `POST /tts/speech` — body `{ text, voice?, format?, exaggeration?, cfg_weight?,
|
||||||
temperature? }`; returns `{ audio_base64, format }`. Input is cleaned
|
temperature? }`; returns `{ audio_base64, format }`. Input is cleaned
|
||||||
server-side (markdown + emoji stripped) and the generation knobs are clamped
|
server-side (markdown + emoji stripped, then pronunciation overrides applied —
|
||||||
|
see below) and the generation knobs are clamped
|
||||||
to Chatterbox's ranges. Synthesis is serialized (one at a time — the upstream
|
to Chatterbox's ranges. Synthesis is serialized (one at a time — the upstream
|
||||||
has no GPU lock of its own); a concurrent request gets a fast `429`.
|
has no GPU lock of its own); a concurrent request gets a fast `429`.
|
||||||
- `GET /tts/voices` — list the voice library.
|
- `POST /tts/speech/jobs` — durable variant for long syntheses: same body as
|
||||||
|
`/tts/speech`, returns `202 { job_id, status }` immediately. Jobs queue on the
|
||||||
|
GPU permit instead of fast-failing `429`.
|
||||||
|
- `GET /tts/speech/jobs/{id}` — poll a job: `{ job_id, status, format,
|
||||||
|
audio_base64?, error? }` with status `queued|running|done|error|cancelled`.
|
||||||
|
Results are kept in memory ~10 min after completion, then the job 404s.
|
||||||
|
- `DELETE /tts/speech/jobs/{id}` — cancel a queued/running job.
|
||||||
|
- `GET /tts/voices` — list the voice library. Served from an in-memory cache
|
||||||
|
(so the listing doesn't make llama-swap spin up the TTS model and evict the
|
||||||
|
resident LLM); pass `?refresh=1` to force an upstream re-query. The cache is
|
||||||
|
invalidated by voice create/delete.
|
||||||
- `POST /tts/voices/upload` — multipart `voice_name` + `voice_file`; clone a
|
- `POST /tts/voices/upload` — multipart `voice_name` + `voice_file`; clone a
|
||||||
voice from an uploaded clip (≤25 MB).
|
voice from an uploaded clip (≤25 MB).
|
||||||
- `POST /tts/voices/from-library` — body `{ voice_name, path, library? }`; clone
|
- `POST /tts/voices/from-library` — body `{ voice_name, path, library? }`; clone
|
||||||
from a library file (audio forwarded as-is; video has its audio extracted via
|
from a library file (audio forwarded as-is; video has its audio extracted via
|
||||||
ffmpeg).
|
ffmpeg).
|
||||||
|
- `DELETE /tts/voices/{name}` — remove a cloned voice from the library.
|
||||||
|
|
||||||
|
Created voice names are tagged with the ref-clip cap in effect (e.g.
|
||||||
|
`grandma-30s`) so the library shows which reference length produced each clone.
|
||||||
|
|
||||||
|
Words the model mispronounces (place names, initialisms) can be rewritten
|
||||||
|
before synthesis via a JSON map — copy `tts_pronunciations.example.json` to
|
||||||
|
`tts_pronunciations.json` and edit; changes apply without a restart. Full
|
||||||
|
matching rules are documented in `src/ai/pronunciation.rs`.
|
||||||
|
|
||||||
Env:
|
Env:
|
||||||
|
- `TTS_PRONUNCIATIONS_PATH` - pronunciation-override JSON file
|
||||||
|
[default: `tts_pronunciations.json` in the working directory]
|
||||||
- `LLAMA_SWAP_TTS_MODEL` - TTS model id in llama-swap's `config.yaml` [default: `chatterbox`]
|
- `LLAMA_SWAP_TTS_MODEL` - TTS model id in llama-swap's `config.yaml` [default: `chatterbox`]
|
||||||
- `LLAMA_SWAP_TTS_VOICE` - default voice used when a `/tts/speech` request omits `voice` (optional)
|
- `LLAMA_SWAP_TTS_VOICE` - default voice used when a `/tts/speech` request omits `voice` (optional)
|
||||||
- `LLAMA_SWAP_TTS_REF_SECONDS` - max voice-clone reference clip length in seconds
|
- `LLAMA_SWAP_TTS_REF_SECONDS` - max voice-clone reference clip length in seconds
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
// GPU lease — in-process coordination for llama-swap model contention.
|
||||||
|
//
|
||||||
|
// llama-swap runs the heavyweight models (chat / vision / Chatterbox TTS) as
|
||||||
|
// a mutually-exclusive set on one GPU (matrix DSL `(q27 | … | tts) & e`): a
|
||||||
|
// request for a non-resident model is HELD by llama-swap until the resident
|
||||||
|
// model's in-flight requests drain, then the models swap. That hold counts
|
||||||
|
// against the *holder's* reqwest timeout — measured live: a queued TTS burned
|
||||||
|
// 77s of its budget behind a single LLM turn, and an LLM request behind a
|
||||||
|
// running synthesis waited the entire remaining synth. Uncoordinated
|
||||||
|
// cross-model traffic therefore times out instead of queueing.
|
||||||
|
//
|
||||||
|
// The lease moves that wait into this process, BEFORE the HTTP request is
|
||||||
|
// sent and before its timeout starts:
|
||||||
|
// - chat/vision requests (the LLM-side slots) share the READ lease;
|
||||||
|
// - TTS synthesis and voice-library ops (anything that spins Chatterbox up
|
||||||
|
// and evicts the LLM) take the WRITE lease;
|
||||||
|
// - embeddings take NO lease: the `embed` slot is in llama-swap's
|
||||||
|
// always-resident group (the `& e` term) and never participates in a swap,
|
||||||
|
// so leasing it would only stall searches behind a queued synthesis.
|
||||||
|
//
|
||||||
|
// tokio's RwLock is fair (FIFO, write-preferring): a queued TTS gets the GPU
|
||||||
|
// right after the current LLM request drains, and later LLM requests queue
|
||||||
|
// behind it — bounded waits in both directions, no starvation, no timeout
|
||||||
|
// budget burned while waiting.
|
||||||
|
//
|
||||||
|
// RULES: hold a lease for exactly one HTTP request (for streaming, the
|
||||||
|
// stream's lifetime) and NEVER acquire one while already holding one — once a
|
||||||
|
// writer is queued, new read acquisitions block, so nested acquisition can
|
||||||
|
// deadlock.
|
||||||
|
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||||
|
|
||||||
|
static GPU_LEASE: LazyLock<RwLock<()>> = LazyLock::new(|| RwLock::new(()));
|
||||||
|
|
||||||
|
/// Waits longer than this are logged — they mean a cross-model swap was
|
||||||
|
/// avoided and quantify what the request *would* have burned of its timeout.
|
||||||
|
const SLOW_WAIT_LOG_SECS: f64 = 2.0;
|
||||||
|
|
||||||
|
/// Shared lease for LLM-side requests (chat / vision slots).
|
||||||
|
pub async fn llm_lease() -> RwLockReadGuard<'static, ()> {
|
||||||
|
let started = Instant::now();
|
||||||
|
let guard = GPU_LEASE.read().await;
|
||||||
|
log_slow_wait("llm", started);
|
||||||
|
guard
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exclusive lease for TTS-side requests (speech synthesis + voice-library
|
||||||
|
/// ops that spin up Chatterbox).
|
||||||
|
pub async fn tts_lease() -> RwLockWriteGuard<'static, ()> {
|
||||||
|
let started = Instant::now();
|
||||||
|
let guard = GPU_LEASE.write().await;
|
||||||
|
log_slow_wait("tts", started);
|
||||||
|
guard
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_slow_wait(kind: &str, started: Instant) {
|
||||||
|
let waited = started.elapsed().as_secs_f64();
|
||||||
|
if waited > SLOW_WAIT_LOG_SECS {
|
||||||
|
log::info!("GPU lease ({kind}): waited {waited:.1}s for the other model class to drain");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// One sequential test, not several: the lease is a single global, so
|
||||||
|
// parallel tests interleaving reads and writes on it can hit the very
|
||||||
|
// nested-acquisition deadlock the module comment warns about.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn write_lease_excludes_readers_then_reads_share() {
|
||||||
|
let w = tts_lease().await;
|
||||||
|
// A reader must not acquire while the writer is held.
|
||||||
|
let pending = tokio::spawn(async { drop(llm_lease().await) });
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
assert!(!pending.is_finished());
|
||||||
|
drop(w);
|
||||||
|
pending.await.expect("reader acquires after writer drops");
|
||||||
|
|
||||||
|
// With no writer queued, read leases are shared.
|
||||||
|
let a = llm_lease().await;
|
||||||
|
let b = llm_lease().await;
|
||||||
|
drop(a);
|
||||||
|
drop(b);
|
||||||
|
}
|
||||||
|
}
|
||||||
+16
-2
@@ -468,6 +468,13 @@ pub async fn generate_insight_handler(
|
|||||||
let path_for_task = path.clone();
|
let path_for_task = path.clone();
|
||||||
let generator_for_task = generator.clone();
|
let generator_for_task = generator.clone();
|
||||||
let result = tokio::task::spawn(async move {
|
let result = tokio::task::spawn(async move {
|
||||||
|
// Cross-model barrier: if a TTS synthesis holds the GPU, wait it
|
||||||
|
// out BEFORE the generation wall-clock starts. The per-request
|
||||||
|
// lease keeps reqwest budgets honest, but this job-level timeout
|
||||||
|
// would otherwise burn while the first chat call queues behind a
|
||||||
|
// multi-minute synthesis. Dropped immediately — holding it across
|
||||||
|
// the generation would deadlock the chat calls' own leases.
|
||||||
|
drop(crate::ai::gpu::llm_lease().await);
|
||||||
tokio::time::timeout(
|
tokio::time::timeout(
|
||||||
std::time::Duration::from_secs(timeout_secs),
|
std::time::Duration::from_secs(timeout_secs),
|
||||||
generator_for_task.generate_insight_for_photo_with_config(
|
generator_for_task.generate_insight_for_photo_with_config(
|
||||||
@@ -510,7 +517,9 @@ pub async fn generate_insight_handler(
|
|||||||
}
|
}
|
||||||
Ok(Ok(Err(e))) => {
|
Ok(Ok(Err(e))) => {
|
||||||
log::error!("Insight generation failed for {}: {:?}", path, e);
|
log::error!("Insight generation failed for {}: {:?}", path, e);
|
||||||
if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) {
|
// `{:#}` = one-line context chain; the job's error_message is
|
||||||
|
// returned to the client verbatim, so no Debug/backtrace here.
|
||||||
|
if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:#}", e)) {
|
||||||
log::error!("Failed to mark job {} as failed: {:?}", job_id, err);
|
log::error!("Failed to mark job {} as failed: {:?}", job_id, err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -844,6 +853,9 @@ pub async fn generate_agentic_insight_handler(
|
|||||||
let path_for_task = path.clone();
|
let path_for_task = path.clone();
|
||||||
let generator_for_task = generator.clone();
|
let generator_for_task = generator.clone();
|
||||||
let result = tokio::task::spawn(async move {
|
let result = tokio::task::spawn(async move {
|
||||||
|
// Cross-model barrier — see generate_insight_handler: wait out any
|
||||||
|
// running TTS synthesis before the generation wall-clock starts.
|
||||||
|
drop(crate::ai::gpu::llm_lease().await);
|
||||||
tokio::time::timeout(
|
tokio::time::timeout(
|
||||||
std::time::Duration::from_secs(timeout_secs),
|
std::time::Duration::from_secs(timeout_secs),
|
||||||
generator_for_task.generate_agentic_insight_for_photo(
|
generator_for_task.generate_agentic_insight_for_photo(
|
||||||
@@ -884,7 +896,9 @@ pub async fn generate_agentic_insight_handler(
|
|||||||
}
|
}
|
||||||
Ok(Ok(Err(e))) => {
|
Ok(Ok(Err(e))) => {
|
||||||
log::error!("Agentic insight generation failed for {}: {:?}", path, e);
|
log::error!("Agentic insight generation failed for {}: {:?}", path, e);
|
||||||
if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:?}", e)) {
|
// `{:#}` = one-line context chain; the job's error_message is
|
||||||
|
// returned to the client verbatim, so no Debug/backtrace here.
|
||||||
|
if let Err(err) = dao.fail_job(&ctx, job_id, &format!("{:#}", e)) {
|
||||||
log::error!("Failed to mark job {} as failed: {:?}", job_id, err);
|
log::error!("Failed to mark job {} as failed: {:?}", job_id, err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+197
-61
@@ -33,30 +33,40 @@ use crate::utils::{earliest_fs_time, normalize_path};
|
|||||||
/// and labels the truncation via `found_header`.
|
/// and labels the truncation via `found_header`.
|
||||||
const LOCATION_HISTORY_DISPLAY_LIMIT: usize = 20;
|
const LOCATION_HISTORY_DISPLAY_LIMIT: usize = 20;
|
||||||
|
|
||||||
|
/// Strip common markdown decoration (bold/italic markers, heading hashes,
|
||||||
|
/// backticks, quotes) from both ends of a model-emitted title. Models wrap
|
||||||
|
/// the line despite the prompt: `**Title: A Day in the Woods**`,
|
||||||
|
/// `## Title: ...`, `"..."`.
|
||||||
|
pub(crate) fn strip_title_markdown(s: &str) -> &str {
|
||||||
|
s.trim_matches(|c: char| matches!(c, '*' | '_' | '`' | '#' | '"') || c.is_whitespace())
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse a "Title: ...\n\n<body>" response into (title, body).
|
/// Parse a "Title: ...\n\n<body>" response into (title, body).
|
||||||
/// Falls back to the first sentence as the title if the model didn't
|
/// Falls back to the first sentence as the title if the model didn't
|
||||||
/// follow the format.
|
/// follow the format.
|
||||||
pub(crate) fn parse_title_body(raw: &str) -> (String, String) {
|
pub(crate) fn parse_title_body(raw: &str) -> (String, String) {
|
||||||
let trimmed = raw.trim();
|
let trimmed = raw.trim();
|
||||||
|
|
||||||
// Try "Title: <title>\n\n<body>" or "Title: <title>\n<body>"
|
// Try "Title: <title>\n<body>", tolerating markdown decoration around
|
||||||
if let Some(rest) = trimmed
|
// the title line.
|
||||||
|
let (first_line, rest) = match trimmed.find('\n') {
|
||||||
|
Some(pos) => (&trimmed[..pos], trimmed[pos..].trim()),
|
||||||
|
None => (trimmed, ""),
|
||||||
|
};
|
||||||
|
let first_line = strip_title_markdown(first_line);
|
||||||
|
if let Some(t) = first_line
|
||||||
.strip_prefix("Title:")
|
.strip_prefix("Title:")
|
||||||
.or_else(|| trimmed.strip_prefix("title:"))
|
.or_else(|| first_line.strip_prefix("title:"))
|
||||||
{
|
{
|
||||||
let rest = rest.trim_start();
|
let title = strip_title_markdown(t);
|
||||||
if let Some(split_pos) = rest.find("\n\n").or_else(|| rest.find('\n')) {
|
if !title.is_empty() && !rest.is_empty() {
|
||||||
let title = rest[..split_pos].trim();
|
return (title.to_string(), rest.to_string());
|
||||||
let body = rest[split_pos..].trim();
|
|
||||||
if !title.is_empty() && !body.is_empty() {
|
|
||||||
return (title.to_string(), body.to_string());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: first sentence (up to first `. ` or `.\n`) becomes the title
|
// Fallback: first sentence (up to first `. ` or `.\n`) becomes the title
|
||||||
if let Some(pos) = trimmed.find(". ").or_else(|| trimmed.find(".\n")) {
|
if let Some(pos) = trimmed.find(". ").or_else(|| trimmed.find(".\n")) {
|
||||||
let title = &trimmed[..pos];
|
let title = strip_title_markdown(&trimmed[..pos]);
|
||||||
let body = trimmed[pos + 1..].trim();
|
let body = trimmed[pos + 1..].trim();
|
||||||
if title.len() <= 100 && !body.is_empty() {
|
if title.len() <= 100 && !body.is_empty() {
|
||||||
return (title.to_string(), body.to_string());
|
return (title.to_string(), body.to_string());
|
||||||
@@ -65,7 +75,7 @@ pub(crate) fn parse_title_body(raw: &str) -> (String, String) {
|
|||||||
|
|
||||||
// Last resort: truncate to 60 chars for title, full text as body
|
// Last resort: truncate to 60 chars for title, full text as body
|
||||||
let title: String = trimmed.chars().take(60).collect();
|
let title: String = trimmed.chars().take(60).collect();
|
||||||
let title = title.trim_end().to_string();
|
let title = strip_title_markdown(title.trim_end()).to_string();
|
||||||
(title, trimmed.to_string())
|
(title, trimmed.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,7 +545,7 @@ impl InsightGenerator {
|
|||||||
// (`LLM_BACKEND` switch). Must match the backend that populated the
|
// (`LLM_BACKEND` switch). Must match the backend that populated the
|
||||||
// daily-summary embeddings or similarity search will be garbage.
|
// daily-summary embeddings or similarity search will be garbage.
|
||||||
let query_embedding =
|
let query_embedding =
|
||||||
crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &query).await?;
|
crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), &query).await?;
|
||||||
|
|
||||||
// Search for similar daily summaries with time-based weighting
|
// Search for similar daily summaries with time-based weighting
|
||||||
// This prioritizes summaries temporally close to the query date
|
// This prioritizes summaries temporally close to the query date
|
||||||
@@ -575,6 +585,67 @@ impl InsightGenerator {
|
|||||||
Ok(formatted)
|
Ok(formatted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Semantic search over daily summaries for the agentic `search_rag`
|
||||||
|
/// tool. Embeds the caller's query as-is (no metadata boilerplate) and
|
||||||
|
/// only applies time weighting when an anchor date is provided —
|
||||||
|
/// without one, results rank purely by similarity across all time.
|
||||||
|
async fn search_summaries_semantic(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
date: Option<chrono::NaiveDate>,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<String>> {
|
||||||
|
let tracer = global_tracer();
|
||||||
|
let current_cx = opentelemetry::Context::current();
|
||||||
|
let mut span = tracer.start_with_context("ai.rag.search_daily_summaries", ¤t_cx);
|
||||||
|
span.set_attribute(KeyValue::new("query", query.to_string()));
|
||||||
|
span.set_attribute(KeyValue::new("limit", limit as i64));
|
||||||
|
span.set_attribute(KeyValue::new("time_weighted", date.is_some()));
|
||||||
|
if let Some(d) = date {
|
||||||
|
span.set_attribute(KeyValue::new("date", d.to_string()));
|
||||||
|
}
|
||||||
|
let search_cx = current_cx.with_span(span);
|
||||||
|
|
||||||
|
log::info!("RAG QUERY: {} (anchor date: {:?})", query, date);
|
||||||
|
|
||||||
|
// Must use the same backend that populated the daily-summary
|
||||||
|
// embeddings or similarity search is garbage (see embed_one docs).
|
||||||
|
let query_embedding =
|
||||||
|
crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), query).await?;
|
||||||
|
|
||||||
|
let mut summary_dao = self
|
||||||
|
.daily_summary_dao
|
||||||
|
.lock()
|
||||||
|
.expect("Unable to lock DailySummaryDao");
|
||||||
|
|
||||||
|
let similar_summaries = match date {
|
||||||
|
Some(d) => summary_dao.find_similar_summaries_with_time_weight(
|
||||||
|
&search_cx,
|
||||||
|
&query_embedding,
|
||||||
|
&d.format("%Y-%m-%d").to_string(),
|
||||||
|
limit,
|
||||||
|
),
|
||||||
|
None => summary_dao.find_similar_summaries(&search_cx, &query_embedding, limit),
|
||||||
|
}
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to find similar summaries: {:?}", e))?;
|
||||||
|
|
||||||
|
search_cx.span().set_attribute(KeyValue::new(
|
||||||
|
"results_count",
|
||||||
|
similar_summaries.len() as i64,
|
||||||
|
));
|
||||||
|
search_cx.span().set_status(Status::Ok);
|
||||||
|
|
||||||
|
Ok(similar_summaries
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| {
|
||||||
|
format!(
|
||||||
|
"[{}] {} ({} messages):\n{}",
|
||||||
|
s.date, s.contact, s.message_count, s.summary
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
/// Build a metadata-based query (fallback when no topics available)
|
/// Build a metadata-based query (fallback when no topics available)
|
||||||
fn build_metadata_query(
|
fn build_metadata_query(
|
||||||
date: chrono::NaiveDate,
|
date: chrono::NaiveDate,
|
||||||
@@ -626,7 +697,7 @@ impl InsightGenerator {
|
|||||||
let calendar_cx = parent_cx.with_span(span);
|
let calendar_cx = parent_cx.with_span(span);
|
||||||
|
|
||||||
let query_embedding = if let Some(loc) = location {
|
let query_embedding = if let Some(loc) = location {
|
||||||
match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), loc).await {
|
match crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), loc).await {
|
||||||
Ok(emb) => Some(emb),
|
Ok(emb) => Some(emb),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::warn!("Failed to generate embedding for location '{}': {}", loc, e);
|
log::warn!("Failed to generate embedding for location '{}': {}", loc, e);
|
||||||
@@ -798,7 +869,8 @@ impl InsightGenerator {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let query_embedding =
|
let query_embedding =
|
||||||
match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &query_text).await {
|
match crate::ai::embed_query(&self.ollama, self.llamacpp.as_deref(), &query_text).await
|
||||||
|
{
|
||||||
Ok(emb) => emb,
|
Ok(emb) => emb,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::warn!("Failed to generate search embedding: {}", e);
|
log::warn!("Failed to generate search embedding: {}", e);
|
||||||
@@ -1737,13 +1809,12 @@ Return ONLY the summary, nothing else."#,
|
|||||||
Some(q) => q.to_string(),
|
Some(q) => q.to_string(),
|
||||||
None => return "Error: missing required parameter 'query'".to_string(),
|
None => return "Error: missing required parameter 'query'".to_string(),
|
||||||
};
|
};
|
||||||
let date_str = match args.get("date").and_then(|v| v.as_str()) {
|
let date = match args.get("date").and_then(|v| v.as_str()) {
|
||||||
Some(d) => d,
|
Some(d) => match NaiveDate::parse_from_str(d, "%Y-%m-%d") {
|
||||||
None => return "Error: missing required parameter 'date'".to_string(),
|
Ok(d) => Some(d),
|
||||||
};
|
Err(e) => return format!("Error: failed to parse date '{}': {}", d, e),
|
||||||
let date = match NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
|
},
|
||||||
Ok(d) => d,
|
None => None,
|
||||||
Err(e) => return format!("Error: failed to parse date '{}': {}", date_str, e),
|
|
||||||
};
|
};
|
||||||
let contact = args
|
let contact = args
|
||||||
.get("contact")
|
.get("contact")
|
||||||
@@ -1756,7 +1827,7 @@ Return ONLY the summary, nothing else."#,
|
|||||||
.clamp(1, 25) as usize;
|
.clamp(1, 25) as usize;
|
||||||
|
|
||||||
log::info!(
|
log::info!(
|
||||||
"tool_search_rag: query='{}', date={}, contact={:?}, limit={}",
|
"tool_search_rag: query='{}', date={:?}, contact={:?}, limit={}",
|
||||||
query,
|
query,
|
||||||
date,
|
date,
|
||||||
contact,
|
contact,
|
||||||
@@ -1777,15 +1848,17 @@ Return ONLY the summary, nothing else."#,
|
|||||||
limit
|
limit
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Embed the model's query verbatim — a soft contact bias is the
|
||||||
|
// only decoration. The metadata boilerplate ("On <date>, it was a
|
||||||
|
// <weekday>") that find_relevant_messages_rag prepends drowns the
|
||||||
|
// semantic signal, so the tool path deliberately bypasses it.
|
||||||
|
let search_query = match contact.as_deref() {
|
||||||
|
Some(c) => format!("{} (conversation with {})", query, c),
|
||||||
|
None => query.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
let results = match self
|
let results = match self
|
||||||
.find_relevant_messages_rag(
|
.search_summaries_semantic(&search_query, date, candidate_limit)
|
||||||
date,
|
|
||||||
None,
|
|
||||||
contact.as_deref(),
|
|
||||||
None,
|
|
||||||
candidate_limit,
|
|
||||||
Some(&query),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(results) if !results.is_empty() => results,
|
Ok(results) if !results.is_empty() => results,
|
||||||
@@ -2062,12 +2135,15 @@ Return ONLY the summary, nothing else."#,
|
|||||||
/// Render a list of [`SmsSearchHit`] for the LLM. Prefers the SMS-API
|
/// Render a list of [`SmsSearchHit`] for the LLM. Prefers the SMS-API
|
||||||
/// snippet (which already excerpts the matched span and is the only
|
/// snippet (which already excerpts the matched span and is the only
|
||||||
/// preview MMS-attachment-only matches have) over the full body, and
|
/// preview MMS-attachment-only matches have) over the full body, and
|
||||||
/// strips the `<mark>` tags the snippet ships with.
|
/// strips the `<mark>` tags the snippet ships with. Each line names
|
||||||
|
/// both parties (`sender → recipient`) — results can span multiple
|
||||||
|
/// conversations, and a sender-only label leaves sent messages
|
||||||
|
/// unattributable to a thread.
|
||||||
fn format_search_hits(hits: &[SmsSearchHit], mode: &str, date_filtered: bool) -> String {
|
fn format_search_hits(hits: &[SmsSearchHit], mode: &str, date_filtered: bool) -> String {
|
||||||
let user_name = user_display_name();
|
let user_name = user_display_name();
|
||||||
let mut out = String::new();
|
let mut out = String::new();
|
||||||
out.push_str(&format!(
|
out.push_str(&format!(
|
||||||
"Found {} messages (mode: {}{}):\n\n",
|
"Found {} messages (mode: {}{}, sender → recipient):\n\n",
|
||||||
hits.len(),
|
hits.len(),
|
||||||
mode,
|
mode,
|
||||||
if date_filtered { ", date-filtered" } else { "" }
|
if date_filtered { ", date-filtered" } else { "" }
|
||||||
@@ -2076,10 +2152,10 @@ Return ONLY the summary, nothing else."#,
|
|||||||
let date = chrono::DateTime::from_timestamp(h.date, 0)
|
let date = chrono::DateTime::from_timestamp(h.date, 0)
|
||||||
.map(|dt| dt.format("%Y-%m-%d").to_string())
|
.map(|dt| dt.format("%Y-%m-%d").to_string())
|
||||||
.unwrap_or_else(|| h.date.to_string());
|
.unwrap_or_else(|| h.date.to_string());
|
||||||
let direction: &str = if h.type_ == 2 {
|
let direction = if h.type_ == 2 {
|
||||||
&user_name
|
format!("{} → {}", user_name, h.contact_name)
|
||||||
} else {
|
} else {
|
||||||
&h.contact_name
|
format!("{} → {}", h.contact_name, user_name)
|
||||||
};
|
};
|
||||||
let score = h
|
let score = h
|
||||||
.similarity_score
|
.similarity_score
|
||||||
@@ -2150,11 +2226,18 @@ Return ONLY the summary, nothing else."#,
|
|||||||
{
|
{
|
||||||
Ok(messages) if !messages.is_empty() => {
|
Ok(messages) if !messages.is_empty() => {
|
||||||
let user_name = user_display_name();
|
let user_name = user_display_name();
|
||||||
|
// Name both parties — without a contact filter the window
|
||||||
|
// spans every conversation, and a sender-only label leaves
|
||||||
|
// sent messages unattributable to a thread.
|
||||||
let formatted: Vec<String> = messages
|
let formatted: Vec<String> = messages
|
||||||
.iter()
|
.iter()
|
||||||
.take(limit)
|
.take(limit)
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
let sender: &str = if m.is_sent { &user_name } else { &m.contact };
|
let direction = if m.is_sent {
|
||||||
|
format!("{} → {}", user_name, m.contact)
|
||||||
|
} else {
|
||||||
|
format!("{} → {}", m.contact, user_name)
|
||||||
|
};
|
||||||
let ts = DateTime::from_timestamp(m.timestamp, 0)
|
let ts = DateTime::from_timestamp(m.timestamp, 0)
|
||||||
.map(|dt| {
|
.map(|dt| {
|
||||||
dt.with_timezone(&Local)
|
dt.with_timezone(&Local)
|
||||||
@@ -2162,7 +2245,7 @@ Return ONLY the summary, nothing else."#,
|
|||||||
.to_string()
|
.to_string()
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|| "unknown".to_string());
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
format!("[{}] {}: {}", ts, sender, m.body)
|
format!("[{}] {}: {}", ts, direction, m.body)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
format!(
|
format!(
|
||||||
@@ -2870,17 +2953,34 @@ Return ONLY the summary, nothing else."#,
|
|||||||
// Generate embedding for name + description (best-effort) via the
|
// Generate embedding for name + description (best-effort) via the
|
||||||
// configured local backend.
|
// configured local backend.
|
||||||
let embed_text = format!("{} {}", name, description);
|
let embed_text = format!("{} {}", name, description);
|
||||||
let embedding: Option<Vec<u8>> =
|
let embedding: Option<Vec<u8>> = match crate::ai::embed_document(
|
||||||
match crate::ai::embed_one(&self.ollama, self.llamacpp.as_deref(), &embed_text).await {
|
&self.ollama,
|
||||||
Ok(vec) => {
|
self.llamacpp.as_deref(),
|
||||||
let bytes: Vec<u8> = vec.iter().flat_map(|f| f.to_le_bytes()).collect();
|
&embed_text,
|
||||||
Some(bytes)
|
)
|
||||||
}
|
.await
|
||||||
Err(e) => {
|
{
|
||||||
log::warn!("Embedding generation failed for entity '{}': {}", name, e);
|
// The entities table has no dim check at the DAO layer, and a
|
||||||
None
|
// wrong-dim vector silently kills dedup/recall (cosine over
|
||||||
}
|
// mismatched lengths is 0) — guard here, store None instead.
|
||||||
};
|
Ok(vec) if vec.len() == crate::ai::embedding_dim() => {
|
||||||
|
let bytes: Vec<u8> = vec.iter().flat_map(|f| f.to_le_bytes()).collect();
|
||||||
|
Some(bytes)
|
||||||
|
}
|
||||||
|
Ok(vec) => {
|
||||||
|
log::warn!(
|
||||||
|
"Entity '{}' embedding has {} dims (expected {}) — storing without embedding",
|
||||||
|
name,
|
||||||
|
vec.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Embedding generation failed for entity '{}': {}", name, e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let now = chrono::Utc::now().timestamp();
|
let now = chrono::Utc::now().timestamp();
|
||||||
let insert = InsertEntity {
|
let insert = InsertEntity {
|
||||||
@@ -3206,21 +3306,25 @@ Return ONLY the summary, nothing else."#,
|
|||||||
if opts.daily_summaries_present {
|
if opts.daily_summaries_present {
|
||||||
tools.push(Tool::function(
|
tools.push(Tool::function(
|
||||||
"search_rag",
|
"search_rag",
|
||||||
"Date-anchored semantic search over the user's daily-summary corpus. \
|
"Semantic search over the user's daily-summary corpus. Returns up to \
|
||||||
Returns up to `limit` summaries most semantically similar to `query`, \
|
`limit` summaries most semantically similar to `query`. Pass `date` \
|
||||||
weighted toward summaries near `date`. For raw message text across all \
|
to anchor in time: summaries near that date rank higher and matches \
|
||||||
time, prefer `search_messages`. \
|
months away decay sharply. Omit `date` to rank purely by semantic \
|
||||||
Examples: `{query: \"family dinner\", date: \"2018-12-24\"}` — what \
|
similarity across all time — do this for \"when did X happen?\" \
|
||||||
|
questions where the date is unknown. For raw message text, prefer \
|
||||||
|
`search_messages`. \
|
||||||
|
Examples: `{query: \"family dinner\"}` — best matches across all \
|
||||||
|
time. `{query: \"family dinner\", date: \"2018-12-24\"}` — what \
|
||||||
daily summaries near Christmas Eve mention family / dinner / gathering. \
|
daily summaries near Christmas Eve mention family / dinner / gathering. \
|
||||||
`{query: \"work travel\", date: \"2019-06-15\", contact: \"Alice\"}` — \
|
`{query: \"work travel\", date: \"2019-06-15\", contact: \"Alice\"}` — \
|
||||||
narrowed to summaries that involve Alice.",
|
biased toward summaries that involve Alice.",
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["query", "date"],
|
"required": ["query"],
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": { "type": "string", "description": "Free-text query, semantically matched." },
|
"query": { "type": "string", "description": "Free-text query, semantically matched." },
|
||||||
"date": { "type": "string", "description": "Anchor date, YYYY-MM-DD. Summaries near this date rank higher." },
|
"date": { "type": "string", "description": "Optional anchor date, YYYY-MM-DD. When set, summaries near this date rank higher; omit to search all time evenly." },
|
||||||
"contact": { "type": "string", "description": "Optional contact name to bias toward conversations with that person." },
|
"contact": { "type": "string", "description": "Optional contact name to bias toward conversations with that person (soft semantic bias, not a hard filter)." },
|
||||||
"limit": { "type": "integer", "description": "Max summaries to return (default 10, max 25)." }
|
"limit": { "type": "integer", "description": "Max summaries to return (default 10, max 25)." }
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
@@ -4763,12 +4867,22 @@ mod tests {
|
|||||||
let hit = make_search_hit(1, "Sarah", "see you at the lake tomorrow", None, 1);
|
let hit = make_search_hit(1, "Sarah", "see you at the lake tomorrow", None, 1);
|
||||||
let out = InsightGenerator::format_search_hits(&[hit], "fts5", false);
|
let out = InsightGenerator::format_search_hits(&[hit], "fts5", false);
|
||||||
|
|
||||||
assert!(out.starts_with("Found 1 messages (mode: fts5):"));
|
assert!(out.starts_with("Found 1 messages (mode: fts5"));
|
||||||
assert!(out.contains("see you at the lake tomorrow"));
|
assert!(out.contains("see you at the lake tomorrow"));
|
||||||
assert!(out.contains("Sarah —"));
|
// Received message: contact is the sender.
|
||||||
|
assert!(out.contains("Sarah →"));
|
||||||
assert!(!out.contains("date-filtered"));
|
assert!(!out.contains("date-filtered"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn format_search_hits_labels_sent_direction() {
|
||||||
|
// Sent messages must name the recipient — results can span multiple
|
||||||
|
// conversations, and a sender-only label left them unattributable.
|
||||||
|
let hit = make_search_hit(5, "Sarah", "on my way", None, 2);
|
||||||
|
let out = InsightGenerator::format_search_hits(&[hit], "fts5", false);
|
||||||
|
assert!(out.contains("→ Sarah —"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn format_search_hits_prefers_snippet_over_body_and_strips_marks() {
|
fn format_search_hits_prefers_snippet_over_body_and_strips_marks() {
|
||||||
let hit = make_search_hit(
|
let hit = make_search_hit(
|
||||||
@@ -4799,7 +4913,7 @@ mod tests {
|
|||||||
|
|
||||||
assert!(out.contains("birthday_cake.jpg"));
|
assert!(out.contains("birthday_cake.jpg"));
|
||||||
assert!(!out.contains("<mark>"));
|
assert!(!out.contains("<mark>"));
|
||||||
assert!(out.contains("Mom —"));
|
assert!(out.contains("Mom →"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -5022,6 +5136,28 @@ mod tests {
|
|||||||
assert_eq!(b, "Everyone gathered...");
|
assert_eq!(b, "Everyone gathered...");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_title_body_strips_bold_wrapper() {
|
||||||
|
let (t, b) = parse_title_body("**Title: A Day in the Woods**\n\nWe hiked the ridge trail.");
|
||||||
|
assert_eq!(t, "A Day in the Woods");
|
||||||
|
assert_eq!(b, "We hiked the ridge trail.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_title_body_strips_bold_label_only() {
|
||||||
|
// Bold around just the label: "**Title:** X"
|
||||||
|
let (t, b) = parse_title_body("**Title:** Garden Party\n\nEveryone gathered...");
|
||||||
|
assert_eq!(t, "Garden Party");
|
||||||
|
assert_eq!(b, "Everyone gathered...");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_title_body_strips_heading_hashes() {
|
||||||
|
let (t, b) = parse_title_body("## Title: Morning Walk\nThe sun was rising...");
|
||||||
|
assert_eq!(t, "Morning Walk");
|
||||||
|
assert_eq!(b, "The sun was rising...");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_title_body_fallback_first_sentence() {
|
fn parse_title_body_fallback_first_sentence() {
|
||||||
let (t, b) = parse_title_body("A warm summer day. We gathered at the park for a picnic.");
|
let (t, b) = parse_title_body("A warm summer day. We gathered at the park for a picnic.");
|
||||||
|
|||||||
@@ -142,6 +142,11 @@ impl LlamaCppClient {
|
|||||||
/// Chatterbox generation knobs are forwarded when set (caller is expected
|
/// Chatterbox generation knobs are forwarded when set (caller is expected
|
||||||
/// to have range-clamped them): `exaggeration` (0.25–2.0, emotion),
|
/// to have range-clamped them): `exaggeration` (0.25–2.0, emotion),
|
||||||
/// `cfg_weight` (0.0–1.0, pace), `temperature` (0.05–5.0, randomness).
|
/// `cfg_weight` (0.0–1.0, pace), `temperature` (0.05–5.0, randomness).
|
||||||
|
///
|
||||||
|
/// Callers must hold the GPU write lease (`ai::gpu::tts_lease`) across
|
||||||
|
/// this call. It is taken at the call sites in `ai::tts` rather than here
|
||||||
|
/// so the speech-job path can flip its job to `running` between acquiring
|
||||||
|
/// the GPU and sending the request.
|
||||||
pub async fn text_to_speech(
|
pub async fn text_to_speech(
|
||||||
&self,
|
&self,
|
||||||
input: &str,
|
input: &str,
|
||||||
@@ -204,6 +209,9 @@ impl LlamaCppClient {
|
|||||||
/// List voices in the Chatterbox voice library (raw JSON passthrough).
|
/// List voices in the Chatterbox voice library (raw JSON passthrough).
|
||||||
pub async fn list_voices(&self) -> Result<Value> {
|
pub async fn list_voices(&self) -> Result<Value> {
|
||||||
let url = format!("{}/upstream/{}/voices", self.swap_root(), self.tts_model);
|
let url = format!("{}/upstream/{}/voices", self.swap_root(), self.tts_model);
|
||||||
|
// The /upstream passthrough spins Chatterbox up (evicting the LLM),
|
||||||
|
// so it takes the exclusive GPU lease like synthesis does.
|
||||||
|
let _gpu = crate::ai::gpu::tts_lease().await;
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
@@ -237,6 +245,9 @@ impl LlamaCppClient {
|
|||||||
.text("voice_name", voice_name.to_string())
|
.text("voice_name", voice_name.to_string())
|
||||||
.part("voice_file", part);
|
.part("voice_file", part);
|
||||||
|
|
||||||
|
// The /upstream passthrough spins Chatterbox up (evicting the LLM),
|
||||||
|
// so it takes the exclusive GPU lease like synthesis does.
|
||||||
|
let _gpu = crate::ai::gpu::tts_lease().await;
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
@@ -253,6 +264,37 @@ impl LlamaCppClient {
|
|||||||
resp.json().await.context("parsing create_voice response")
|
resp.json().await.context("parsing create_voice response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Delete a cloned voice from the Chatterbox voice library
|
||||||
|
/// (`DELETE /voices/{name}` on the upstream, via llama-swap passthrough).
|
||||||
|
pub async fn delete_voice(&self, voice_name: &str) -> Result<Value> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/upstream/{}/voices/{}",
|
||||||
|
self.swap_root(),
|
||||||
|
self.tts_model,
|
||||||
|
voice_name
|
||||||
|
);
|
||||||
|
// The /upstream passthrough spins Chatterbox up (evicting the LLM),
|
||||||
|
// so it takes the exclusive GPU lease like synthesis does.
|
||||||
|
let _gpu = crate::ai::gpu::tts_lease().await;
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.delete(&url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("DELETE {} failed", url))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let status = resp.status();
|
||||||
|
let text = resp.text().await.unwrap_or_default();
|
||||||
|
bail!("llama-swap delete_voice failed: {} — {}", status, text);
|
||||||
|
}
|
||||||
|
// Some upstreams reply with an empty body on delete.
|
||||||
|
Ok(resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| json!({ "status": "deleted" })))
|
||||||
|
}
|
||||||
|
|
||||||
/// Translate canonical messages to the OpenAI-compatible wire shape.
|
/// Translate canonical messages to the OpenAI-compatible wire shape.
|
||||||
/// Behaviorally identical to `OpenRouterClient::messages_to_openai` —
|
/// Behaviorally identical to `OpenRouterClient::messages_to_openai` —
|
||||||
/// stringify tool-call arguments, rewrite images into content-parts, attach
|
/// stringify tool-call arguments, rewrite images into content-parts, attach
|
||||||
@@ -453,6 +495,9 @@ impl LlamaCppClient {
|
|||||||
body.insert(k.into(), v);
|
body.insert(k.into(), v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait for any TTS synthesis to release the GPU before the request
|
||||||
|
// timeout starts (see ai::gpu).
|
||||||
|
let _gpu = crate::ai::gpu::llm_lease().await;
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
@@ -571,6 +616,10 @@ impl LlmClient for LlamaCppClient {
|
|||||||
body.insert(k.into(), v);
|
body.insert(k.into(), v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait for any TTS synthesis to release the GPU before the request
|
||||||
|
// timeout starts (see ai::gpu). The guard is moved into the stream
|
||||||
|
// below so the lease spans the whole generation, not just the send.
|
||||||
|
let gpu = crate::ai::gpu::llm_lease().await;
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
@@ -587,6 +636,7 @@ impl LlmClient for LlamaCppClient {
|
|||||||
|
|
||||||
let byte_stream = resp.bytes_stream();
|
let byte_stream = resp.bytes_stream();
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
|
let _gpu = gpu;
|
||||||
let mut byte_stream = byte_stream;
|
let mut byte_stream = byte_stream;
|
||||||
let mut buf: Vec<u8> = Vec::new();
|
let mut buf: Vec<u8> = Vec::new();
|
||||||
let mut accumulated_content = String::new();
|
let mut accumulated_content = String::new();
|
||||||
@@ -702,6 +752,9 @@ impl LlmClient for LlamaCppClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
||||||
|
// Deliberately NO GPU lease: the embed slot sits in llama-swap's
|
||||||
|
// always-resident group and never participates in a model swap, so
|
||||||
|
// leasing here would only stall searches behind a queued synthesis.
|
||||||
let url = format!("{}/embeddings", self.base_url);
|
let url = format!("{}/embeddings", self.base_url);
|
||||||
let body = json!({
|
let body = json!({
|
||||||
"model": self.embedding_model,
|
"model": self.embedding_model,
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
//! Bundle of the local LLM pair (Ollama + optional llama-swap) with the
|
||||||
|
//! `LLM_BACKEND` dispatch baked in.
|
||||||
|
//!
|
||||||
|
//! Exists because passing the pair around as loose values invited the same
|
||||||
|
//! bug three times: import/backfill tooling embedded corpora via
|
||||||
|
//! `OllamaClient` directly while the query side dispatched through
|
||||||
|
//! `embed_one`, so flipping `LLM_BACKEND=llamacpp` silently split queries
|
||||||
|
//! and corpus into different vector spaces. Anything that writes or reads
|
||||||
|
//! embeddings should go through this type (or `embed_one`/`embed_many`),
|
||||||
|
//! never a concrete client.
|
||||||
|
//!
|
||||||
|
//! Deliberately knows nothing about chat policy — hybrid/OpenRouter routing
|
||||||
|
//! is request-scoped and stays in `ResolvedBackend`. This is only the
|
||||||
|
//! local stack: embeddings and offline single-shot generation.
|
||||||
|
|
||||||
|
// Constructed by binaries, not the server — dead code from main.rs's view.
|
||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
use super::llamacpp::LlamaCppClient;
|
||||||
|
use super::llm_client::LlmClient;
|
||||||
|
use super::ollama::{EMBEDDING_MODEL, OllamaClient};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct LocalLlm {
|
||||||
|
ollama: OllamaClient,
|
||||||
|
llamacpp: Option<Arc<LlamaCppClient>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalLlm {
|
||||||
|
pub fn new(ollama: OllamaClient, llamacpp: Option<Arc<LlamaCppClient>>) -> Self {
|
||||||
|
Self { ollama, llamacpp }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct from the canonical env wiring shared with `AppState`.
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
Self::new(
|
||||||
|
crate::state::build_ollama_from_env(),
|
||||||
|
crate::state::build_llamacpp_from_env(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed a search query (applies `EMBED_QUERY_PREFIX`). Callers must
|
||||||
|
/// pick query vs document — retrieval models treat the two sides
|
||||||
|
/// differently and an unmarked embed invites prefix-mismatch bugs.
|
||||||
|
pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
|
||||||
|
super::embed_query(&self.ollama, self.llamacpp.as_deref(), text).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed corpus text (applies `EMBED_DOCUMENT_PREFIX`).
|
||||||
|
pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
|
||||||
|
super::embed_document(&self.ollama, self.llamacpp.as_deref(), text).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single-shot local text generation via the `LLM_BACKEND`-selected
|
||||||
|
/// client (offline tooling; chat turns belong to `ResolvedBackend`).
|
||||||
|
pub async fn generate(&self, prompt: &str, system: Option<&str>) -> Result<String> {
|
||||||
|
if super::local_backend_is_llamacpp() {
|
||||||
|
if let Some(lc) = self.llamacpp.as_deref() {
|
||||||
|
return <LlamaCppClient as LlmClient>::generate(lc, prompt, system, None).await;
|
||||||
|
}
|
||||||
|
anyhow::bail!(
|
||||||
|
"LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \
|
||||||
|
set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
self.ollama.generate(prompt, system).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Label identifying which backend + model produces embeddings right
|
||||||
|
/// now. Store it alongside vectors (`model_version` columns) so a
|
||||||
|
/// backend flip is detectable in the data, not just in env history.
|
||||||
|
pub fn embedding_model_version(&self) -> String {
|
||||||
|
if super::local_backend_is_llamacpp() {
|
||||||
|
let slot = self
|
||||||
|
.llamacpp
|
||||||
|
.as_deref()
|
||||||
|
.map(|c| c.embedding_model.as_str())
|
||||||
|
.unwrap_or("embed");
|
||||||
|
format!("llama-swap:{}", slot)
|
||||||
|
} else {
|
||||||
|
EMBEDDING_MODEL.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
+84
-12
@@ -3,13 +3,16 @@ pub mod backend;
|
|||||||
pub mod clip_client;
|
pub mod clip_client;
|
||||||
pub mod daily_summary_job;
|
pub mod daily_summary_job;
|
||||||
pub mod face_client;
|
pub mod face_client;
|
||||||
|
pub mod gpu;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
pub mod insight_chat;
|
pub mod insight_chat;
|
||||||
pub mod insight_generator;
|
pub mod insight_generator;
|
||||||
pub mod llamacpp;
|
pub mod llamacpp;
|
||||||
pub mod llm_client;
|
pub mod llm_client;
|
||||||
|
pub mod local_llm;
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
pub mod openrouter;
|
pub mod openrouter;
|
||||||
|
pub mod pronunciation;
|
||||||
pub mod sms_client;
|
pub mod sms_client;
|
||||||
pub mod tts;
|
pub mod tts;
|
||||||
pub mod turn_registry;
|
pub mod turn_registry;
|
||||||
@@ -34,11 +37,15 @@ pub use llamacpp::LlamaCppClient;
|
|||||||
pub use llm_client::{
|
pub use llm_client::{
|
||||||
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction, ToolFunction,
|
ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction, ToolFunction,
|
||||||
};
|
};
|
||||||
|
// LocalLlm is constructed by binaries (reembed_embeddings, importers), not the server
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
pub use local_llm::LocalLlm;
|
||||||
pub use ollama::{EMBEDDING_MODEL, OllamaClient};
|
pub use ollama::{EMBEDDING_MODEL, OllamaClient};
|
||||||
pub use sms_client::{SmsApiClient, SmsMessage};
|
pub use sms_client::{SmsApiClient, SmsMessage};
|
||||||
pub use tts::{
|
pub use tts::{
|
||||||
create_voice_from_library_handler, create_voice_upload_handler, list_voices_handler,
|
cancel_speech_job_handler, create_speech_job_handler, create_voice_from_library_handler,
|
||||||
tts_speech_handler,
|
create_voice_upload_handler, delete_voice_handler, list_voices_handler,
|
||||||
|
speech_job_status_handler, tts_speech_handler,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Display name used for the user in message transcripts and first-person
|
/// Display name used for the user in message transcripts and first-person
|
||||||
@@ -69,35 +76,100 @@ pub fn local_backend_is_llamacpp() -> bool {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Embed one string via the configured local backend. Routes through
|
/// Expected embedding dimensionality, env-overridable via `EMBEDDING_DIM`
|
||||||
/// llama-swap when `LLM_BACKEND=llamacpp` (and a client is configured),
|
/// (default 768, nomic-embed-text). Every store/query dim check reads this —
|
||||||
/// else Ollama. Returns the single embedding vector. See
|
/// swapping to a different-dim model (e.g. Qwen3-Embedding-0.6B at 1024) is
|
||||||
/// [`local_backend_is_llamacpp`] for the rationale on consistency.
|
/// then a config flip plus a `reembed_embeddings` run, not a code change.
|
||||||
pub async fn embed_one(
|
/// Cached for the process lifetime; a flip requires a restart anyway since
|
||||||
|
/// the corpus must be re-embedded with it.
|
||||||
|
pub fn embedding_dim() -> usize {
|
||||||
|
static DIM: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
|
||||||
|
*DIM.get_or_init(|| {
|
||||||
|
std::env::var("EMBEDDING_DIM")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(768)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read an embedding prefix from the environment. `.env` values can't hold
|
||||||
|
/// real newlines, so a literal `\n` in the value is expanded — Qwen3-style
|
||||||
|
/// query instructions need one ("Instruct: ...\nQuery: ").
|
||||||
|
fn embed_prefix(key: &str) -> String {
|
||||||
|
std::env::var(key)
|
||||||
|
.map(|v| v.replace("\\n", "\n"))
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed a search query. Applies `EMBED_QUERY_PREFIX` (default empty) —
|
||||||
|
/// retrieval models distinguish query-side from document-side text:
|
||||||
|
/// nomic v1.5 wants `search_query: `, Qwen3-Embedding wants
|
||||||
|
/// `Instruct: <task>\nQuery: `. Must pair with the document prefix the
|
||||||
|
/// corpus was embedded with or similarity degrades.
|
||||||
|
pub async fn embed_query(
|
||||||
ollama: &OllamaClient,
|
ollama: &OllamaClient,
|
||||||
llamacpp: Option<&LlamaCppClient>,
|
llamacpp: Option<&LlamaCppClient>,
|
||||||
text: &str,
|
text: &str,
|
||||||
) -> anyhow::Result<Vec<f32>> {
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
let prefixed = format!("{}{}", embed_prefix("EMBED_QUERY_PREFIX"), text);
|
||||||
|
embed_one(ollama, llamacpp, &prefixed).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed corpus text (the stored side of retrieval). Applies
|
||||||
|
/// `EMBED_DOCUMENT_PREFIX` (default empty; nomic v1.5 wants
|
||||||
|
/// `search_document: `, Qwen3-Embedding wants none).
|
||||||
|
pub async fn embed_document(
|
||||||
|
ollama: &OllamaClient,
|
||||||
|
llamacpp: Option<&LlamaCppClient>,
|
||||||
|
text: &str,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
let prefixed = format!("{}{}", embed_prefix("EMBED_DOCUMENT_PREFIX"), text);
|
||||||
|
embed_one(ollama, llamacpp, &prefixed).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed a batch of strings via the configured local backend. Routes
|
||||||
|
/// through llama-swap when `LLM_BACKEND=llamacpp` (and a client is
|
||||||
|
/// configured), else Ollama. See [`local_backend_is_llamacpp`] for the
|
||||||
|
/// rationale on consistency.
|
||||||
|
pub async fn embed_many(
|
||||||
|
ollama: &OllamaClient,
|
||||||
|
llamacpp: Option<&LlamaCppClient>,
|
||||||
|
texts: &[&str],
|
||||||
|
) -> anyhow::Result<Vec<Vec<f32>>> {
|
||||||
if local_backend_is_llamacpp() {
|
if local_backend_is_llamacpp() {
|
||||||
if let Some(lc) = llamacpp {
|
if let Some(lc) = llamacpp {
|
||||||
let mut vecs = <LlamaCppClient as LlmClient>::generate_embeddings(lc, &[text]).await?;
|
return <LlamaCppClient as LlmClient>::generate_embeddings(lc, texts).await;
|
||||||
return vecs
|
|
||||||
.pop()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("llama-swap returned no embeddings"));
|
|
||||||
}
|
}
|
||||||
anyhow::bail!(
|
anyhow::bail!(
|
||||||
"LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \
|
"LLM_BACKEND=llamacpp but LlamaCppClient is unconfigured — \
|
||||||
set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama"
|
set LLAMA_SWAP_URL or switch to LLM_BACKEND=ollama"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
ollama.generate_embedding(text).await
|
ollama.generate_embeddings(texts).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed one string via the configured local backend. Single-text
|
||||||
|
/// convenience over [`embed_many`].
|
||||||
|
pub async fn embed_one(
|
||||||
|
ollama: &OllamaClient,
|
||||||
|
llamacpp: Option<&LlamaCppClient>,
|
||||||
|
text: &str,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
let mut vecs = embed_many(ollama, llamacpp, &[text]).await?;
|
||||||
|
vecs.pop()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("embedding backend returned no embeddings"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod env_dispatch_tests {
|
mod env_dispatch_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
/// Env vars are process-global, and the test harness runs in parallel —
|
||||||
|
/// without this lock the `LLM_BACKEND` tests race each other and flake.
|
||||||
|
static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||||
|
|
||||||
fn with_env<F: FnOnce()>(key: &str, val: Option<&str>, f: F) {
|
fn with_env<F: FnOnce()>(key: &str, val: Option<&str>, f: F) {
|
||||||
|
let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
|
||||||
let prev = std::env::var(key).ok();
|
let prev = std::env::var(key).ok();
|
||||||
match val {
|
match val {
|
||||||
Some(v) => unsafe { std::env::set_var(key, v) },
|
Some(v) => unsafe { std::env::set_var(key, v) },
|
||||||
|
|||||||
+15
-5
@@ -548,7 +548,16 @@ Capture the key moment or theme. Return ONLY the title, nothing else."#,
|
|||||||
let title = self
|
let title = self
|
||||||
.generate_with_images(&prompt, Some(system), None)
|
.generate_with_images(&prompt, Some(system), None)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(title.trim().trim_matches('"').to_string())
|
// Models decorate despite "Return ONLY the title": quotes, bold
|
||||||
|
// markers, sometimes a "Title:" label.
|
||||||
|
use crate::ai::insight_generator::strip_title_markdown;
|
||||||
|
let cleaned = strip_title_markdown(title.trim());
|
||||||
|
let cleaned = cleaned
|
||||||
|
.strip_prefix("Title:")
|
||||||
|
.or_else(|| cleaned.strip_prefix("title:"))
|
||||||
|
.map(strip_title_markdown)
|
||||||
|
.unwrap_or(cleaned);
|
||||||
|
Ok(cleaned.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a summary for a single photo based on its context
|
/// Generate a summary for a single photo based on its context
|
||||||
@@ -1046,13 +1055,14 @@ Analyze the image and use specific details from both the visual content and the
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate embedding dimensions (should be 768 for nomic-embed-text:v1.5)
|
// Validate embedding dimensions (EMBEDDING_DIM; 768 for nomic-embed-text:v1.5)
|
||||||
for (i, embedding) in embeddings.iter().enumerate() {
|
for (i, embedding) in embeddings.iter().enumerate() {
|
||||||
if embedding.len() != 768 {
|
if embedding.len() != crate::ai::embedding_dim() {
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"Unexpected embedding dimensions for item {}: {} (expected 768)",
|
"Unexpected embedding dimensions for item {}: {} (expected {})",
|
||||||
i,
|
i,
|
||||||
embedding.len()
|
embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,282 @@
|
|||||||
|
// User-configurable pronunciation overrides for TTS. Chatterbox mispronounces
|
||||||
|
// place names ("Worcester"), initialisms ("WSL"), and clipped abbreviations
|
||||||
|
// ("blvd"), so we rewrite them to phonetic spellings before synthesis.
|
||||||
|
//
|
||||||
|
// The map lives in a JSON file on the server — a flat object of
|
||||||
|
// `"written form": "spoken form"` pairs, e.g.:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "Worcester": "Wuster",
|
||||||
|
// "WSL": "W S L",
|
||||||
|
// "blvd": "boulevard",
|
||||||
|
// "Dr.": "Doctor"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Path comes from `TTS_PRONUNCIATIONS_PATH` (default `tts_pronunciations.json`
|
||||||
|
// in the working directory). A missing file simply disables the feature. The
|
||||||
|
// file is re-read whenever its mtime changes, so edits apply to the next
|
||||||
|
// synthesis without a restart; a malformed edit keeps the last good map and
|
||||||
|
// logs the parse error instead of silently dropping all overrides.
|
||||||
|
//
|
||||||
|
// Matching rules:
|
||||||
|
// - Whole words only — `cat` never rewrites `category`. (Boundaries are only
|
||||||
|
// asserted next to word characters, so keys like `Dr.` still work.)
|
||||||
|
// - Smartcase: an all-lowercase key matches case-insensitively; a key with
|
||||||
|
// any uppercase matches exactly. That lets `worcester` catch every casing
|
||||||
|
// while `US` (the country) leaves the pronoun `us` alone.
|
||||||
|
// - Longer keys win over shorter ones (`New York Times` before `New York`).
|
||||||
|
|
||||||
|
use regex::Regex;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::{Arc, LazyLock, Mutex as StdMutex};
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
/// A compiled pronunciation map: one alternation regex over every key plus
|
||||||
|
/// the lookup tables the replacement closure resolves matches against.
|
||||||
|
#[derive(Default)]
|
||||||
|
struct CompiledMap {
|
||||||
|
/// `None` when the map is empty — apply() is then a no-op.
|
||||||
|
regex: Option<Regex>,
|
||||||
|
/// Case-sensitive entries, keyed verbatim.
|
||||||
|
exact: HashMap<String, String>,
|
||||||
|
/// Case-insensitive entries, keyed lowercased.
|
||||||
|
folded: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompiledMap {
|
||||||
|
fn from_entries(entries: &HashMap<String, String>) -> Self {
|
||||||
|
let mut keys: Vec<&str> = entries
|
||||||
|
.keys()
|
||||||
|
.map(|k| k.as_str())
|
||||||
|
.filter(|k| !k.trim().is_empty())
|
||||||
|
.collect();
|
||||||
|
if keys.is_empty() {
|
||||||
|
return Self::default();
|
||||||
|
}
|
||||||
|
// Longest key first so overlapping entries prefer the more specific
|
||||||
|
// one (regex alternation is first-match-wins, not longest-match).
|
||||||
|
keys.sort_by(|a, b| b.len().cmp(&a.len()).then(a.cmp(b)));
|
||||||
|
|
||||||
|
let mut exact = HashMap::new();
|
||||||
|
let mut folded = HashMap::new();
|
||||||
|
let alternatives: Vec<String> = keys
|
||||||
|
.iter()
|
||||||
|
.map(|key| {
|
||||||
|
let escaped = regex::escape(key);
|
||||||
|
// Only assert a word boundary where the key edge is a word
|
||||||
|
// character — `\b` adjacent to punctuation (e.g. the dot in
|
||||||
|
// `Dr.`) would otherwise never match.
|
||||||
|
let lead = if key
|
||||||
|
.chars()
|
||||||
|
.next()
|
||||||
|
.is_some_and(|c| c.is_alphanumeric() || c == '_')
|
||||||
|
{
|
||||||
|
r"\b"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let trail = if key
|
||||||
|
.chars()
|
||||||
|
.last()
|
||||||
|
.is_some_and(|c| c.is_alphanumeric() || c == '_')
|
||||||
|
{
|
||||||
|
r"\b"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let case_sensitive = key.chars().any(|c| c.is_uppercase());
|
||||||
|
if case_sensitive {
|
||||||
|
exact.insert(key.to_string(), entries[*key].clone());
|
||||||
|
format!("{lead}{escaped}{trail}")
|
||||||
|
} else {
|
||||||
|
folded.insert(key.to_lowercase(), entries[*key].clone());
|
||||||
|
format!("{lead}(?i:{escaped}){trail}")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Escaped fixed strings can't produce an invalid pattern; if one ever
|
||||||
|
// does, treat the whole map as empty rather than panicking a handler.
|
||||||
|
let pattern = alternatives.join("|");
|
||||||
|
let regex = match Regex::new(&pattern) {
|
||||||
|
Ok(r) => Some(r),
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("pronunciation map failed to compile: {e}");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
regex,
|
||||||
|
exact,
|
||||||
|
folded,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(&self, text: &str) -> String {
|
||||||
|
let Some(re) = &self.regex else {
|
||||||
|
return text.to_string();
|
||||||
|
};
|
||||||
|
re.replace_all(text, |caps: ®ex::Captures| {
|
||||||
|
let m = &caps[0];
|
||||||
|
self.exact
|
||||||
|
.get(m)
|
||||||
|
.or_else(|| self.folded.get(&m.to_lowercase()))
|
||||||
|
.cloned()
|
||||||
|
// Unreachable in practice — every alternative came from one
|
||||||
|
// of the two maps — but never drop the user's text.
|
||||||
|
.unwrap_or_else(|| m.to_string())
|
||||||
|
})
|
||||||
|
.into_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CacheEntry {
|
||||||
|
mtime: Option<SystemTime>,
|
||||||
|
compiled: Arc<CompiledMap>,
|
||||||
|
}
|
||||||
|
|
||||||
|
static CACHE: LazyLock<StdMutex<Option<CacheEntry>>> = LazyLock::new(|| StdMutex::new(None));
|
||||||
|
|
||||||
|
fn config_path() -> String {
|
||||||
|
std::env::var("TTS_PRONUNCIATIONS_PATH")
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(|| "tts_pronunciations.json".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the compiled map, re-reading the file only when its mtime changed
|
||||||
|
/// since the last call (or it appeared/disappeared). Synthesis is serialized
|
||||||
|
/// on a single GPU permit, so a stat per call is noise.
|
||||||
|
fn current_map() -> Arc<CompiledMap> {
|
||||||
|
let path_s = config_path();
|
||||||
|
let path = Path::new(&path_s);
|
||||||
|
let mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
|
||||||
|
|
||||||
|
let mut cache = CACHE.lock().unwrap();
|
||||||
|
if let Some(entry) = cache.as_ref()
|
||||||
|
&& entry.mtime == mtime
|
||||||
|
{
|
||||||
|
return entry.compiled.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
let compiled = match mtime {
|
||||||
|
None => Arc::new(CompiledMap::default()), // no file → no overrides
|
||||||
|
Some(_) => match std::fs::read_to_string(path)
|
||||||
|
.map_err(anyhow::Error::from)
|
||||||
|
.and_then(|s| Ok(serde_json::from_str::<HashMap<String, String>>(&s)?))
|
||||||
|
{
|
||||||
|
Ok(entries) => {
|
||||||
|
log::info!(
|
||||||
|
"loaded {} pronunciation override(s) from {path_s}",
|
||||||
|
entries.len()
|
||||||
|
);
|
||||||
|
Arc::new(CompiledMap::from_entries(&entries))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("failed to load pronunciation map {path_s}: {e}");
|
||||||
|
// Keep serving the previous map rather than regressing to
|
||||||
|
// none mid-edit; still record the new mtime so the error
|
||||||
|
// logs once per bad save, not once per synthesis.
|
||||||
|
cache
|
||||||
|
.as_ref()
|
||||||
|
.map(|c| c.compiled.clone())
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
*cache = Some(CacheEntry {
|
||||||
|
mtime,
|
||||||
|
compiled: compiled.clone(),
|
||||||
|
});
|
||||||
|
compiled
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rewrite configured words/abbreviations to their phonetic spellings.
|
||||||
|
/// Call on cleaned (post-markdown-strip) text, right before synthesis.
|
||||||
|
pub fn apply_pronunciations(text: &str) -> String {
|
||||||
|
current_map().apply(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn compile(pairs: &[(&str, &str)]) -> CompiledMap {
|
||||||
|
let entries = pairs
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||||
|
.collect();
|
||||||
|
CompiledMap::from_entries(&entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_map_is_a_noop() {
|
||||||
|
let m = compile(&[]);
|
||||||
|
assert_eq!(m.apply("nothing changes"), "nothing changes");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replaces_whole_words_only() {
|
||||||
|
let m = compile(&[("cat", "kitty")]);
|
||||||
|
assert_eq!(m.apply("the cat sat"), "the kitty sat");
|
||||||
|
// No substring rewrites.
|
||||||
|
assert_eq!(m.apply("the category"), "the category");
|
||||||
|
assert_eq!(m.apply("concatenate"), "concatenate");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lowercase_keys_match_any_casing() {
|
||||||
|
let m = compile(&[("worcester", "Wuster")]);
|
||||||
|
assert_eq!(m.apply("Worcester is nice"), "Wuster is nice");
|
||||||
|
assert_eq!(m.apply("in WORCESTER today"), "in Wuster today");
|
||||||
|
assert_eq!(m.apply("worcester sauce"), "Wuster sauce");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn uppercase_keys_match_case_sensitively() {
|
||||||
|
let m = compile(&[("US", "U S")]);
|
||||||
|
assert_eq!(m.apply("the US economy"), "the U S economy");
|
||||||
|
// The pronoun survives.
|
||||||
|
assert_eq!(m.apply("join us today"), "join us today");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn keys_with_punctuation_work() {
|
||||||
|
// `\b` is only asserted next to word characters, so the trailing dot
|
||||||
|
// doesn't break matching.
|
||||||
|
let m = compile(&[("Dr.", "Doctor"), ("blvd", "boulevard")]);
|
||||||
|
assert_eq!(
|
||||||
|
m.apply("Dr. Smith on Sunset blvd"),
|
||||||
|
"Doctor Smith on Sunset boulevard"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn longer_keys_win_over_shorter() {
|
||||||
|
let m = compile(&[("new york", "Noo York"), ("new york times", "the Times")]);
|
||||||
|
assert_eq!(m.apply("read the new york times"), "read the the Times");
|
||||||
|
assert_eq!(m.apply("visit new york soon"), "visit Noo York soon");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multiple_occurrences_all_rewrite() {
|
||||||
|
let m = compile(&[("wsl", "W S L")]);
|
||||||
|
assert_eq!(m.apply("WSL and wsl and Wsl"), "W S L and W S L and W S L");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replacement_text_is_verbatim() {
|
||||||
|
// Replacements aren't re-scanned — a value containing another key
|
||||||
|
// doesn't cascade.
|
||||||
|
let m = compile(&[("a1", "b2"), ("b2", "c3")]);
|
||||||
|
assert_eq!(m.apply("a1"), "b2");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn blank_keys_are_ignored() {
|
||||||
|
let m = compile(&[("", "x"), (" ", "y"), ("ok", "fine")]);
|
||||||
|
assert_eq!(m.apply("ok then"), "fine then");
|
||||||
|
}
|
||||||
|
}
|
||||||
+737
-39
@@ -6,7 +6,7 @@
|
|||||||
// (audio read directly; video has its audio track extracted via ffmpeg).
|
// (audio read directly; video has its audio track extracted via ffmpeg).
|
||||||
|
|
||||||
use actix_multipart::Multipart;
|
use actix_multipart::Multipart;
|
||||||
use actix_web::{HttpRequest, HttpResponse, Responder, get, post, web};
|
use actix_web::{HttpRequest, HttpResponse, Responder, delete, get, post, web};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use bytes::{BufMut, BytesMut};
|
use bytes::{BufMut, BytesMut};
|
||||||
@@ -15,10 +15,13 @@ use opentelemetry::KeyValue;
|
|||||||
use opentelemetry::trace::{Span, Status, Tracer};
|
use opentelemetry::trace::{Span, Status, Tracer};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::{Value, json};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::LazyLock;
|
use std::sync::{LazyLock, Mutex as StdMutex};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::data::Claims;
|
use crate::data::Claims;
|
||||||
use crate::file_types::{is_audio_file, is_video_file};
|
use crate::file_types::{is_audio_file, is_video_file};
|
||||||
@@ -40,6 +43,105 @@ const MAX_VOICE_UPLOAD_BYTES: usize = 25 * 1024 * 1024; // 25 MB
|
|||||||
/// finishes — that's a wrapper limitation; the chunked-queue plan fixes it.)
|
/// finishes — that's a wrapper limitation; the chunked-queue plan fixes it.)
|
||||||
static TTS_PERMIT: LazyLock<Semaphore> = LazyLock::new(|| Semaphore::new(1));
|
static TTS_PERMIT: LazyLock<Semaphore> = LazyLock::new(|| Semaphore::new(1));
|
||||||
|
|
||||||
|
// --- Voice-list cache --------------------------------------------------------
|
||||||
|
|
||||||
|
/// Cached raw voice-library JSON. llama-swap's `/upstream/<model>/voices`
|
||||||
|
/// passthrough spins the TTS model up just to answer a listing — which can
|
||||||
|
/// evict the resident LLM — so we serve a cached copy and only hit upstream on
|
||||||
|
/// a cold cache, an explicit `?refresh=1`, or after a voice create/delete
|
||||||
|
/// invalidates it (the TTS model is already loaded right then anyway).
|
||||||
|
static VOICES_CACHE: LazyLock<StdMutex<Option<Value>>> = LazyLock::new(|| StdMutex::new(None));
|
||||||
|
|
||||||
|
fn cached_voices() -> Option<Value> {
|
||||||
|
VOICES_CACHE.lock().unwrap().clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store_voices_cache(v: &Value) {
|
||||||
|
*VOICES_CACHE.lock().unwrap() = Some(v.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invalidate_voices_cache() {
|
||||||
|
*VOICES_CACHE.lock().unwrap() = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Async speech jobs -------------------------------------------------------
|
||||||
|
//
|
||||||
|
// Synthesizing a long insight can take minutes — too long to hang one HTTP
|
||||||
|
// request from a phone that may background the app or drop the connection.
|
||||||
|
// Durable variant: POST /tts/speech/jobs returns a job id immediately, the
|
||||||
|
// synth runs in a spawned task (queuing on TTS_PERMIT instead of fast-failing
|
||||||
|
// 429), and the client polls GET /tts/speech/jobs/{id} until it collects the
|
||||||
|
// audio. State is in-memory only (deliberately lighter than the chat
|
||||||
|
// TurnRegistry): a restart loses jobs, the client surfaces that and retries.
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum TtsJobStatus {
|
||||||
|
Queued,
|
||||||
|
Running,
|
||||||
|
Done,
|
||||||
|
Error,
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TtsJobStatus {
|
||||||
|
fn is_terminal(self) -> bool {
|
||||||
|
matches!(self, Self::Done | Self::Error | Self::Cancelled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TtsJob {
|
||||||
|
status: TtsJobStatus,
|
||||||
|
format: String,
|
||||||
|
audio_base64: Option<String>,
|
||||||
|
error: Option<String>,
|
||||||
|
created_at: Instant,
|
||||||
|
finished_at: Option<Instant>,
|
||||||
|
abort: Option<tokio::task::AbortHandle>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finished jobs linger so a client that lost connectivity can still collect
|
||||||
|
/// the result on a later poll; anything older than MAX_AGE is dropped outright
|
||||||
|
/// (aborted first if somehow still running). Swept lazily on each dispatch.
|
||||||
|
const TTS_JOB_RESULT_TTL: Duration = Duration::from_secs(10 * 60);
|
||||||
|
const TTS_JOB_MAX_AGE: Duration = Duration::from_secs(30 * 60);
|
||||||
|
|
||||||
|
static TTS_JOBS: LazyLock<StdMutex<HashMap<Uuid, TtsJob>>> =
|
||||||
|
LazyLock::new(|| StdMutex::new(HashMap::new()));
|
||||||
|
|
||||||
|
fn sweep_stale_jobs(jobs: &mut HashMap<Uuid, TtsJob>, now: Instant) {
|
||||||
|
jobs.retain(|_, job| {
|
||||||
|
let result_expired = job
|
||||||
|
.finished_at
|
||||||
|
.is_some_and(|t| now.duration_since(t) >= TTS_JOB_RESULT_TTL);
|
||||||
|
let too_old = now.duration_since(job.created_at) >= TTS_JOB_MAX_AGE;
|
||||||
|
if too_old && let Some(h) = job.abort.take() {
|
||||||
|
h.abort();
|
||||||
|
}
|
||||||
|
!(result_expired || too_old)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run `f` against a job, if it still exists.
|
||||||
|
fn with_job<R>(id: Uuid, f: impl FnOnce(&mut TtsJob) -> R) -> Option<R> {
|
||||||
|
TTS_JOBS.lock().unwrap().get_mut(&id).map(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Move a job to a terminal state (first terminal write wins — a cancel that
|
||||||
|
/// raced a completion keeps the cancel).
|
||||||
|
fn finish_job(id: Uuid, status: TtsJobStatus, audio_base64: Option<String>, error: Option<String>) {
|
||||||
|
with_job(id, |job| {
|
||||||
|
if job.status.is_terminal() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
job.status = status;
|
||||||
|
job.audio_base64 = audio_base64;
|
||||||
|
job.error = error;
|
||||||
|
job.finished_at = Some(Instant::now());
|
||||||
|
job.abort = None;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/// Sanitize a user-supplied voice name. The name is forwarded to Chatterbox
|
/// Sanitize a user-supplied voice name. The name is forwarded to Chatterbox
|
||||||
/// where it becomes a filename in the voice-library directory, so we restrict
|
/// where it becomes a filename in the voice-library directory, so we restrict
|
||||||
/// it to a safe charset (alphanumerics, dash, underscore) — no path
|
/// it to a safe charset (alphanumerics, dash, underscore) — no path
|
||||||
@@ -64,6 +166,66 @@ fn sanitize_voice_name(raw: &str) -> Option<String> {
|
|||||||
Some(cleaned.chars().take(64).collect())
|
Some(cleaned.chars().take(64).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reference-clip cap in seconds for voice cloning. Chatterbox is zero-shot —
|
||||||
|
/// a clean ~10–20s sample is the sweet spot and more rarely helps. Tune via
|
||||||
|
/// `LLAMA_SWAP_TTS_REF_SECONDS` (default 30).
|
||||||
|
fn tts_ref_seconds() -> u32 {
|
||||||
|
std::env::var("LLAMA_SWAP_TTS_REF_SECONDS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.trim().parse::<u32>().ok())
|
||||||
|
.filter(|n| *n > 0)
|
||||||
|
.unwrap_or(30)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tag a (sanitized) voice name with the reference window used to create it:
|
||||||
|
/// `grandma` → `grandma-30s` (from the start), or `grandma-at1m32s-30s` (30s
|
||||||
|
/// window starting at 1:32). The tag makes the window visible in the voice
|
||||||
|
/// list so clones of the same source from different sections can be compared.
|
||||||
|
/// Skips the append when the name already ends in the same tag; keeps the
|
||||||
|
/// 64-char bound by truncating the base name, never the tag.
|
||||||
|
fn append_ref_window(name: &str, start: f64, secs: u32) -> String {
|
||||||
|
let start_whole = start.round().max(0.0) as u64;
|
||||||
|
let suffix = if start_whole > 0 {
|
||||||
|
// ':' isn't in the safe voice-name charset, so 1:32 becomes 1m32s.
|
||||||
|
let at = if start_whole >= 60 {
|
||||||
|
format!("at{}m{:02}s", start_whole / 60, start_whole % 60)
|
||||||
|
} else {
|
||||||
|
format!("at{start_whole}s")
|
||||||
|
};
|
||||||
|
format!("-{at}-{secs}s")
|
||||||
|
} else {
|
||||||
|
format!("-{secs}s")
|
||||||
|
};
|
||||||
|
if name.ends_with(&suffix) {
|
||||||
|
return name.to_string();
|
||||||
|
}
|
||||||
|
let max_base = 64usize.saturating_sub(suffix.len());
|
||||||
|
let base: String = name.chars().take(max_base).collect();
|
||||||
|
let base = base.trim_end_matches('-');
|
||||||
|
format!("{base}{suffix}")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a caller-supplied reference window into concrete `(start, duration)`
|
||||||
|
/// seconds for ffmpeg. Start defaults to 0; duration defaults to the
|
||||||
|
/// `tts_ref_seconds` cap and is clamped to it (the cap is the most audio the
|
||||||
|
/// TTS backend benefits from, so longer requests are quietly bounded rather
|
||||||
|
/// than rejected). Non-finite or negative values are the caller's bug → Err.
|
||||||
|
fn resolve_ref_window(
|
||||||
|
start_seconds: Option<f64>,
|
||||||
|
duration_seconds: Option<f64>,
|
||||||
|
) -> Result<(f64, f64), String> {
|
||||||
|
let cap = f64::from(tts_ref_seconds());
|
||||||
|
let start = start_seconds.unwrap_or(0.0);
|
||||||
|
if !start.is_finite() || start < 0.0 {
|
||||||
|
return Err("start_seconds must be a non-negative number".to_string());
|
||||||
|
}
|
||||||
|
let duration = duration_seconds.unwrap_or(cap);
|
||||||
|
if !duration.is_finite() || duration <= 0.0 {
|
||||||
|
return Err("duration_seconds must be a positive number".to_string());
|
||||||
|
}
|
||||||
|
Ok((start, duration.min(cap)))
|
||||||
|
}
|
||||||
|
|
||||||
/// Optional default voice for synthesis when the request doesn't name one.
|
/// Optional default voice for synthesis when the request doesn't name one.
|
||||||
/// Set `LLAMA_SWAP_TTS_VOICE=m` to read insights in a cloned voice by default.
|
/// Set `LLAMA_SWAP_TTS_VOICE=m` to read insights in a cloned voice by default.
|
||||||
fn default_voice() -> Option<String> {
|
fn default_voice() -> Option<String> {
|
||||||
@@ -125,33 +287,42 @@ fn clean_for_tts(input: &str) -> String {
|
|||||||
s.trim().to_string()
|
s.trim().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Full text-preparation pipeline for synthesis: markdown/emoji cleanup, then
|
||||||
|
/// the user's pronunciation overrides (see [`crate::ai::pronunciation`]) on
|
||||||
|
/// the resulting plain text — after cleanup so word boundaries aren't
|
||||||
|
/// obscured by `**WSL**`-style markup.
|
||||||
|
fn prepare_for_tts(input: &str) -> String {
|
||||||
|
crate::ai::pronunciation::apply_pronunciations(&clean_for_tts(input))
|
||||||
|
}
|
||||||
|
|
||||||
/// Decode an audio/video file to mono 24 kHz WAV via ffmpeg, returning the WAV
|
/// Decode an audio/video file to mono 24 kHz WAV via ffmpeg, returning the WAV
|
||||||
/// bytes. Chatterbox validates the reference clip by file *extension* and
|
/// bytes. Chatterbox validates the reference clip by file *extension* and
|
||||||
/// rejects several formats (e.g. `.aac`, `.opus`), so we always normalize to
|
/// rejects several formats (e.g. `.aac`, `.opus`), so we always normalize to
|
||||||
/// WAV regardless of the source container. Capped at 30s — references only need
|
/// WAV regardless of the source container. Extracts `duration` seconds starting
|
||||||
/// a few seconds of clean speech.
|
/// at `start` (see resolve_ref_window) — references only need a few seconds of
|
||||||
async fn run_ffmpeg_to_wav(input_path: &str) -> anyhow::Result<Vec<u8>> {
|
/// clean speech, which may sit anywhere in a long recording.
|
||||||
|
async fn run_ffmpeg_to_wav(input_path: &str, start: f64, duration: f64) -> anyhow::Result<Vec<u8>> {
|
||||||
let out = tempfile::Builder::new()
|
let out = tempfile::Builder::new()
|
||||||
.suffix(".wav")
|
.suffix(".wav")
|
||||||
.tempfile()
|
.tempfile()
|
||||||
.context("creating temp wav")?;
|
.context("creating temp wav")?;
|
||||||
let out_s = out.path().to_string_lossy().to_string();
|
let out_s = out.path().to_string_lossy().to_string();
|
||||||
|
|
||||||
// Cap the reference clip length. Chatterbox is zero-shot — a clean ~10–20s
|
let start_s = format!("{start}");
|
||||||
// sample is the sweet spot and more rarely helps — so we use the first N
|
let secs = format!("{duration}");
|
||||||
// seconds. Tune via LLAMA_SWAP_TTS_REF_SECONDS (default 30).
|
|
||||||
let secs = std::env::var("LLAMA_SWAP_TTS_REF_SECONDS")
|
// -ss before -i is input seeking: fast, and frame accuracy doesn't matter
|
||||||
.ok()
|
// for picking a speech window.
|
||||||
.and_then(|s| s.trim().parse::<u32>().ok())
|
let mut args: Vec<&str> = vec!["-y"];
|
||||||
.filter(|n| *n > 0)
|
if start > 0.0 {
|
||||||
.unwrap_or(30)
|
args.extend(["-ss", &start_s]);
|
||||||
.to_string();
|
}
|
||||||
|
args.extend([
|
||||||
|
"-i", input_path, "-vn", "-ac", "1", "-ar", "24000", "-t", &secs, "-f", "wav", &out_s,
|
||||||
|
]);
|
||||||
|
|
||||||
let output = tokio::process::Command::new("ffmpeg")
|
let output = tokio::process::Command::new("ffmpeg")
|
||||||
.args([
|
.args(&args)
|
||||||
"-y", "-i", input_path, "-vn", "-ac", "1", "-ar", "24000", "-t", &secs, "-f", "wav",
|
|
||||||
&out_s,
|
|
||||||
])
|
|
||||||
.output()
|
.output()
|
||||||
.await
|
.await
|
||||||
.context("spawning ffmpeg")?;
|
.context("spawning ffmpeg")?;
|
||||||
@@ -164,7 +335,12 @@ async fn run_ffmpeg_to_wav(input_path: &str) -> anyhow::Result<Vec<u8>> {
|
|||||||
|
|
||||||
/// Normalize in-memory upload bytes to WAV: write to a temp file (keeping the
|
/// Normalize in-memory upload bytes to WAV: write to a temp file (keeping the
|
||||||
/// source extension as an ffmpeg probe hint) then transcode.
|
/// source extension as an ffmpeg probe hint) then transcode.
|
||||||
async fn transcode_bytes_to_wav(input: &[u8], src_ext: Option<&str>) -> anyhow::Result<Vec<u8>> {
|
async fn transcode_bytes_to_wav(
|
||||||
|
input: &[u8],
|
||||||
|
src_ext: Option<&str>,
|
||||||
|
start: f64,
|
||||||
|
duration: f64,
|
||||||
|
) -> anyhow::Result<Vec<u8>> {
|
||||||
let suffix = src_ext
|
let suffix = src_ext
|
||||||
.filter(|e| !e.is_empty())
|
.filter(|e| !e.is_empty())
|
||||||
.map(|e| format!(".{e}"))
|
.map(|e| format!(".{e}"))
|
||||||
@@ -174,7 +350,7 @@ async fn transcode_bytes_to_wav(input: &[u8], src_ext: Option<&str>) -> anyhow::
|
|||||||
.tempfile()
|
.tempfile()
|
||||||
.context("creating temp input")?;
|
.context("creating temp input")?;
|
||||||
std::fs::write(in_tmp.path(), input).context("writing temp input")?;
|
std::fs::write(in_tmp.path(), input).context("writing temp input")?;
|
||||||
run_ffmpeg_to_wav(&in_tmp.path().to_string_lossy()).await
|
run_ffmpeg_to_wav(&in_tmp.path().to_string_lossy(), start, duration).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -214,7 +390,7 @@ pub async fn tts_speech_handler(
|
|||||||
let parent_context = extract_context_from_request(&http_request);
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
let mut span = global_tracer().start_with_context("http.tts.speech", &parent_context);
|
let mut span = global_tracer().start_with_context("http.tts.speech", &parent_context);
|
||||||
|
|
||||||
let text = clean_for_tts(&req.text);
|
let text = prepare_for_tts(&req.text);
|
||||||
if text.is_empty() {
|
if text.is_empty() {
|
||||||
span.set_status(Status::error("text is required"));
|
span.set_status(Status::error("text is required"));
|
||||||
return HttpResponse::BadRequest().json(json!({ "error": "text is required" }));
|
return HttpResponse::BadRequest().json(json!({ "error": "text is required" }));
|
||||||
@@ -255,6 +431,10 @@ pub async fn tts_speech_handler(
|
|||||||
}));
|
}));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Wait for the LLM side to release the GPU before sending — the synthesis
|
||||||
|
// timeout starts at send, not here (see ai::gpu).
|
||||||
|
let _gpu = crate::ai::gpu::tts_lease().await;
|
||||||
|
|
||||||
match client
|
match client
|
||||||
.text_to_speech(&text, voice, format, exaggeration, cfg_weight, temperature)
|
.text_to_speech(&text, voice, format, exaggeration, cfg_weight, temperature)
|
||||||
.await
|
.await
|
||||||
@@ -276,16 +456,283 @@ pub async fn tts_speech_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /tts/voices — list the Chatterbox voice library (raw passthrough).
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct TtsJobCreatedResponse {
|
||||||
|
pub job_id: String,
|
||||||
|
pub status: TtsJobStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct TtsJobStatusResponse {
|
||||||
|
pub job_id: String,
|
||||||
|
pub status: TtsJobStatus,
|
||||||
|
pub format: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub audio_base64: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /tts/speech/jobs — durable variant of /tts/speech for long syntheses.
|
||||||
|
/// Returns 202 + a job id immediately; the synth queues on the single GPU
|
||||||
|
/// permit (instead of fast-failing 429) and the client polls the job until
|
||||||
|
/// the audio is ready.
|
||||||
|
#[post("/tts/speech/jobs")]
|
||||||
|
pub async fn create_speech_job_handler(
|
||||||
|
http_request: HttpRequest,
|
||||||
|
_claims: Claims,
|
||||||
|
req: web::Json<TtsSpeechRequest>,
|
||||||
|
app_state: web::Data<AppState>,
|
||||||
|
) -> impl Responder {
|
||||||
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
|
let mut span =
|
||||||
|
global_tracer().start_with_context("http.tts.speech_job.create", &parent_context);
|
||||||
|
|
||||||
|
let text = prepare_for_tts(&req.text);
|
||||||
|
if text.is_empty() {
|
||||||
|
span.set_status(Status::error("text is required"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": "text is required" }));
|
||||||
|
}
|
||||||
|
if app_state.llamacpp.is_none() {
|
||||||
|
span.set_status(Status::error("tts backend not configured"));
|
||||||
|
return HttpResponse::ServiceUnavailable()
|
||||||
|
.json(json!({ "error": "TTS backend not configured (set LLAMA_SWAP_URL)" }));
|
||||||
|
}
|
||||||
|
|
||||||
|
let format = req
|
||||||
|
.format
|
||||||
|
.as_deref()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or("mp3")
|
||||||
|
.to_string();
|
||||||
|
let voice = req
|
||||||
|
.voice
|
||||||
|
.clone()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.or_else(default_voice);
|
||||||
|
// Clamp generation knobs to Chatterbox's documented ranges before forwarding.
|
||||||
|
let exaggeration = req.exaggeration.map(|x| x.clamp(0.25, 2.0));
|
||||||
|
let cfg_weight = req.cfg_weight.map(|x| x.clamp(0.0, 1.0));
|
||||||
|
let temperature = req.temperature.map(|x| x.clamp(0.05, 5.0));
|
||||||
|
|
||||||
|
span.set_attribute(KeyValue::new("tts.format", format.clone()));
|
||||||
|
span.set_attribute(KeyValue::new("tts.has_voice", voice.is_some()));
|
||||||
|
span.set_attribute(KeyValue::new("tts.text_len", text.len() as i64));
|
||||||
|
|
||||||
|
let job_id = Uuid::new_v4();
|
||||||
|
{
|
||||||
|
let mut jobs = TTS_JOBS.lock().unwrap();
|
||||||
|
sweep_stale_jobs(&mut jobs, Instant::now());
|
||||||
|
jobs.insert(
|
||||||
|
job_id,
|
||||||
|
TtsJob {
|
||||||
|
status: TtsJobStatus::Queued,
|
||||||
|
format: format.clone(),
|
||||||
|
audio_base64: None,
|
||||||
|
error: None,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
finished_at: None,
|
||||||
|
abort: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = app_state.clone();
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
// Queue rather than fast-fail: jobs wait their turn for the GPU.
|
||||||
|
let _permit = match TTS_PERMIT.acquire().await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => {
|
||||||
|
finish_job(
|
||||||
|
job_id,
|
||||||
|
TtsJobStatus::Error,
|
||||||
|
None,
|
||||||
|
Some("TTS queue closed".to_string()),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Wait for the LLM side to release the GPU too (see ai::gpu) — only
|
||||||
|
// then does the job count as running. The synthesis timeout starts at
|
||||||
|
// the HTTP send below, so neither wait burns it, and the client can
|
||||||
|
// anchor its own deadline to the queued→running transition.
|
||||||
|
let _gpu = crate::ai::gpu::tts_lease().await;
|
||||||
|
|
||||||
|
// Cancelled while queued — release the permits without synthesizing.
|
||||||
|
let cancelled = with_job(job_id, |job| {
|
||||||
|
if job.status == TtsJobStatus::Queued {
|
||||||
|
job.status = TtsJobStatus::Running;
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(true);
|
||||||
|
if cancelled {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(client) = state.llamacpp.as_ref() else {
|
||||||
|
finish_job(
|
||||||
|
job_id,
|
||||||
|
TtsJobStatus::Error,
|
||||||
|
None,
|
||||||
|
Some("TTS backend not configured".to_string()),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
match client
|
||||||
|
.text_to_speech(
|
||||||
|
&text,
|
||||||
|
voice.as_deref(),
|
||||||
|
&format,
|
||||||
|
exaggeration,
|
||||||
|
cfg_weight,
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(bytes) => {
|
||||||
|
let audio = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||||
|
finish_job(job_id, TtsJobStatus::Done, Some(audio), None);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("TTS job {job_id} failed: {:?}", e);
|
||||||
|
finish_job(
|
||||||
|
job_id,
|
||||||
|
TtsJobStatus::Error,
|
||||||
|
None,
|
||||||
|
Some(format!("TTS failed: {e}")),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Aborting an already-finished task is a no-op, so this late install is
|
||||||
|
// safe even if the job raced to completion.
|
||||||
|
with_job(job_id, |job| {
|
||||||
|
if !job.status.is_terminal() {
|
||||||
|
job.abort = Some(handle.abort_handle());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
span.set_status(Status::Ok);
|
||||||
|
HttpResponse::Accepted().json(TtsJobCreatedResponse {
|
||||||
|
job_id: job_id.to_string(),
|
||||||
|
status: TtsJobStatus::Queued,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /tts/speech/jobs/{id} — poll a speech job; returns the audio once done.
|
||||||
|
/// 404s after the job expires (results are kept ~10 min past completion).
|
||||||
|
#[get("/tts/speech/jobs/{id}")]
|
||||||
|
pub async fn speech_job_status_handler(
|
||||||
|
http_request: HttpRequest,
|
||||||
|
_claims: Claims,
|
||||||
|
path: web::Path<String>,
|
||||||
|
) -> impl Responder {
|
||||||
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
|
let mut span =
|
||||||
|
global_tracer().start_with_context("http.tts.speech_job.status", &parent_context);
|
||||||
|
|
||||||
|
let Ok(id) = Uuid::parse_str(&path.into_inner()) else {
|
||||||
|
span.set_status(Status::error("invalid job id"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": "invalid job id" }));
|
||||||
|
};
|
||||||
|
let resp = {
|
||||||
|
let jobs = TTS_JOBS.lock().unwrap();
|
||||||
|
jobs.get(&id).map(|job| TtsJobStatusResponse {
|
||||||
|
job_id: id.to_string(),
|
||||||
|
status: job.status,
|
||||||
|
format: job.format.clone(),
|
||||||
|
audio_base64: job.audio_base64.clone(),
|
||||||
|
error: job.error.clone(),
|
||||||
|
})
|
||||||
|
};
|
||||||
|
match resp {
|
||||||
|
Some(r) => {
|
||||||
|
span.set_status(Status::Ok);
|
||||||
|
HttpResponse::Ok().json(r)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
span.set_status(Status::error("job not found"));
|
||||||
|
HttpResponse::NotFound()
|
||||||
|
.json(json!({ "error": "TTS job not found (it may have expired)" }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /tts/speech/jobs/{id} — cancel a queued/running speech job. Once the
|
||||||
|
/// upstream GPU job has started it can't be interrupted (same wrapper
|
||||||
|
/// limitation as the sync path); cancelling stops the wait and discards the
|
||||||
|
/// result. Cancelling an already-finished job leaves it terminal.
|
||||||
|
#[delete("/tts/speech/jobs/{id}")]
|
||||||
|
pub async fn cancel_speech_job_handler(
|
||||||
|
http_request: HttpRequest,
|
||||||
|
_claims: Claims,
|
||||||
|
path: web::Path<String>,
|
||||||
|
) -> impl Responder {
|
||||||
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
|
let mut span =
|
||||||
|
global_tracer().start_with_context("http.tts.speech_job.cancel", &parent_context);
|
||||||
|
|
||||||
|
let Ok(id) = Uuid::parse_str(&path.into_inner()) else {
|
||||||
|
span.set_status(Status::error("invalid job id"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": "invalid job id" }));
|
||||||
|
};
|
||||||
|
let status = with_job(id, |job| {
|
||||||
|
if !job.status.is_terminal() {
|
||||||
|
if let Some(h) = job.abort.take() {
|
||||||
|
h.abort();
|
||||||
|
}
|
||||||
|
job.status = TtsJobStatus::Cancelled;
|
||||||
|
job.finished_at = Some(Instant::now());
|
||||||
|
}
|
||||||
|
job.status
|
||||||
|
});
|
||||||
|
match status {
|
||||||
|
Some(s) => {
|
||||||
|
span.set_status(Status::Ok);
|
||||||
|
HttpResponse::Ok().json(json!({ "job_id": id.to_string(), "status": s }))
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
span.set_status(Status::error("job not found"));
|
||||||
|
HttpResponse::NotFound()
|
||||||
|
.json(json!({ "error": "TTS job not found (it may have expired)" }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ListVoicesQuery {
|
||||||
|
/// `?refresh=1` bypasses the voice-list cache and re-queries upstream
|
||||||
|
/// (which may spin up the TTS model).
|
||||||
|
#[serde(default)]
|
||||||
|
pub refresh: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /tts/voices — list the Chatterbox voice library. Served from an
|
||||||
|
/// in-memory cache when possible so browsing settings doesn't make llama-swap
|
||||||
|
/// load the TTS model (and evict the resident LLM); see VOICES_CACHE.
|
||||||
#[get("/tts/voices")]
|
#[get("/tts/voices")]
|
||||||
pub async fn list_voices_handler(
|
pub async fn list_voices_handler(
|
||||||
http_request: HttpRequest,
|
http_request: HttpRequest,
|
||||||
_claims: Claims,
|
_claims: Claims,
|
||||||
|
query: web::Query<ListVoicesQuery>,
|
||||||
app_state: web::Data<AppState>,
|
app_state: web::Data<AppState>,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
let parent_context = extract_context_from_request(&http_request);
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
let mut span = global_tracer().start_with_context("http.tts.voices.list", &parent_context);
|
let mut span = global_tracer().start_with_context("http.tts.voices.list", &parent_context);
|
||||||
|
|
||||||
|
let force = query
|
||||||
|
.refresh
|
||||||
|
.as_deref()
|
||||||
|
.is_some_and(|v| matches!(v, "1" | "true" | "yes"));
|
||||||
|
if !force && let Some(v) = cached_voices() {
|
||||||
|
span.set_attribute(KeyValue::new("tts.voices_cache_hit", true));
|
||||||
|
span.set_status(Status::Ok);
|
||||||
|
return HttpResponse::Ok().json(v);
|
||||||
|
}
|
||||||
|
|
||||||
let Some(client) = app_state.llamacpp.as_ref() else {
|
let Some(client) = app_state.llamacpp.as_ref() else {
|
||||||
span.set_status(Status::error("tts backend not configured"));
|
span.set_status(Status::error("tts backend not configured"));
|
||||||
return HttpResponse::ServiceUnavailable()
|
return HttpResponse::ServiceUnavailable()
|
||||||
@@ -293,6 +740,8 @@ pub async fn list_voices_handler(
|
|||||||
};
|
};
|
||||||
match client.list_voices().await {
|
match client.list_voices().await {
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
|
store_voices_cache(&v);
|
||||||
|
span.set_attribute(KeyValue::new("tts.voices_cache_hit", false));
|
||||||
span.set_status(Status::Ok);
|
span.set_status(Status::Ok);
|
||||||
HttpResponse::Ok().json(v)
|
HttpResponse::Ok().json(v)
|
||||||
}
|
}
|
||||||
@@ -304,8 +753,52 @@ pub async fn list_voices_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// DELETE /tts/voices/{name} — remove a cloned voice from the library.
|
||||||
|
#[delete("/tts/voices/{name}")]
|
||||||
|
pub async fn delete_voice_handler(
|
||||||
|
http_request: HttpRequest,
|
||||||
|
_claims: Claims,
|
||||||
|
path: web::Path<String>,
|
||||||
|
app_state: web::Data<AppState>,
|
||||||
|
) -> impl Responder {
|
||||||
|
let parent_context = extract_context_from_request(&http_request);
|
||||||
|
let mut span = global_tracer().start_with_context("http.tts.voices.delete", &parent_context);
|
||||||
|
|
||||||
|
let Some(client) = app_state.llamacpp.as_ref() else {
|
||||||
|
span.set_status(Status::error("tts backend not configured"));
|
||||||
|
return HttpResponse::ServiceUnavailable()
|
||||||
|
.json(json!({ "error": "TTS backend not configured" }));
|
||||||
|
};
|
||||||
|
// Same charset rule as creation — a name that sanitizes differently was
|
||||||
|
// never a voice we created, and must not reach the upstream URL.
|
||||||
|
let raw = path.into_inner();
|
||||||
|
let name = match sanitize_voice_name(&raw) {
|
||||||
|
Some(n) if n == raw => n,
|
||||||
|
_ => {
|
||||||
|
span.set_status(Status::error("invalid voice name"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": "invalid voice name" }));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
span.set_attribute(KeyValue::new("tts.voice_name", name.clone()));
|
||||||
|
|
||||||
|
match client.delete_voice(&name).await {
|
||||||
|
Ok(v) => {
|
||||||
|
invalidate_voices_cache();
|
||||||
|
span.set_status(Status::Ok);
|
||||||
|
HttpResponse::Ok().json(v)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
span.set_status(Status::error("delete_voice failed"));
|
||||||
|
log::error!("delete_voice failed: {:?}", e);
|
||||||
|
HttpResponse::BadGateway().json(json!({ "error": format!("{e}") }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// POST /tts/voices/upload — register a cloned voice from an uploaded audio
|
/// POST /tts/voices/upload — register a cloned voice from an uploaded audio
|
||||||
/// clip. Multipart fields: `voice_name` (text) + a file part (`voice_file`).
|
/// clip. Multipart fields: `voice_name` (text) + a file part (`voice_file`),
|
||||||
|
/// plus optional `start_seconds` / `duration_seconds` (text) selecting which
|
||||||
|
/// window of a longer recording becomes the reference clip.
|
||||||
#[post("/tts/voices/upload")]
|
#[post("/tts/voices/upload")]
|
||||||
pub async fn create_voice_upload_handler(
|
pub async fn create_voice_upload_handler(
|
||||||
http_request: HttpRequest,
|
http_request: HttpRequest,
|
||||||
@@ -323,6 +816,8 @@ pub async fn create_voice_upload_handler(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut voice_name: Option<String> = None;
|
let mut voice_name: Option<String> = None;
|
||||||
|
let mut start_field: Option<String> = None;
|
||||||
|
let mut duration_field: Option<String> = None;
|
||||||
let mut file_bytes = BytesMut::new();
|
let mut file_bytes = BytesMut::new();
|
||||||
let mut filename = "voice.wav".to_string();
|
let mut filename = "voice.wav".to_string();
|
||||||
|
|
||||||
@@ -347,22 +842,57 @@ pub async fn create_voice_upload_handler(
|
|||||||
}
|
}
|
||||||
file_bytes.put(data);
|
file_bytes.put(data);
|
||||||
}
|
}
|
||||||
} else if name_opt.as_deref() == Some("voice_name") {
|
} else if matches!(
|
||||||
|
name_opt.as_deref(),
|
||||||
|
Some("voice_name" | "start_seconds" | "duration_seconds")
|
||||||
|
) {
|
||||||
|
let field = name_opt.as_deref().unwrap().to_string();
|
||||||
let mut buf = BytesMut::new();
|
let mut buf = BytesMut::new();
|
||||||
while let Some(Ok(data)) = part.next().await {
|
while let Some(Ok(data)) = part.next().await {
|
||||||
buf.put(data);
|
buf.put(data);
|
||||||
}
|
}
|
||||||
voice_name = Some(String::from_utf8_lossy(&buf).trim().to_string());
|
let text = String::from_utf8_lossy(&buf).trim().to_string();
|
||||||
|
match field.as_str() {
|
||||||
|
"voice_name" => voice_name = Some(text),
|
||||||
|
"start_seconds" => start_field = Some(text),
|
||||||
|
_ => duration_field = Some(text),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
while let Some(Ok(_)) = part.next().await {}
|
while let Some(Ok(_)) = part.next().await {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Empty text parts are treated as absent; anything else must parse, so a
|
||||||
|
// client bug ("abc") fails loudly instead of silently cloning from 0s.
|
||||||
|
let parse_secs = |field: Option<&String>, name: &str| -> Result<Option<f64>, String> {
|
||||||
|
match field.map(|s| s.as_str()).filter(|s| !s.is_empty()) {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(s) => s
|
||||||
|
.parse::<f64>()
|
||||||
|
.map(Some)
|
||||||
|
.map_err(|_| format!("{name} must be a number of seconds")),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let window = parse_secs(start_field.as_ref(), "start_seconds").and_then(|start| {
|
||||||
|
parse_secs(duration_field.as_ref(), "duration_seconds")
|
||||||
|
.and_then(|dur| resolve_ref_window(start, dur))
|
||||||
|
});
|
||||||
|
let (ref_start, ref_duration) = match window {
|
||||||
|
Ok(w) => w,
|
||||||
|
Err(msg) => {
|
||||||
|
span.set_status(Status::error("invalid reference window"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": msg }));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let Some(name) = voice_name.as_deref().and_then(sanitize_voice_name) else {
|
let Some(name) = voice_name.as_deref().and_then(sanitize_voice_name) else {
|
||||||
span.set_status(Status::error("voice_name is required"));
|
span.set_status(Status::error("voice_name is required"));
|
||||||
return HttpResponse::BadRequest()
|
return HttpResponse::BadRequest()
|
||||||
.json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
|
.json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
|
||||||
};
|
};
|
||||||
|
// Tag the name with the ref-clip length (e.g. `grandma-30s`) so the
|
||||||
|
// library shows which reference length produced each clone.
|
||||||
|
let name = append_ref_window(&name, ref_start, ref_duration.round().max(1.0) as u32);
|
||||||
if file_bytes.is_empty() {
|
if file_bytes.is_empty() {
|
||||||
span.set_status(Status::error("voice_file is required"));
|
span.set_status(Status::error("voice_file is required"));
|
||||||
return HttpResponse::BadRequest().json(json!({ "error": "voice_file is required" }));
|
return HttpResponse::BadRequest().json(json!({ "error": "voice_file is required" }));
|
||||||
@@ -373,21 +903,23 @@ pub async fn create_voice_upload_handler(
|
|||||||
// Normalize to WAV so any device format (e.g. .aac / .opus, which Chatterbox
|
// Normalize to WAV so any device format (e.g. .aac / .opus, which Chatterbox
|
||||||
// rejects by extension) is accepted.
|
// rejects by extension) is accepted.
|
||||||
let src_ext = Path::new(&filename).extension().and_then(|e| e.to_str());
|
let src_ext = Path::new(&filename).extension().and_then(|e| e.to_str());
|
||||||
let wav = match transcode_bytes_to_wav(file_bytes.as_ref(), src_ext).await {
|
let wav =
|
||||||
Ok(w) => w,
|
match transcode_bytes_to_wav(file_bytes.as_ref(), src_ext, ref_start, ref_duration).await {
|
||||||
Err(e) => {
|
Ok(w) => w,
|
||||||
span.set_status(Status::error("audio decode failed"));
|
Err(e) => {
|
||||||
log::error!("voice upload transcode failed: {:?}", e);
|
span.set_status(Status::error("audio decode failed"));
|
||||||
return HttpResponse::BadRequest()
|
log::error!("voice upload transcode failed: {:?}", e);
|
||||||
.json(json!({ "error": "couldn't decode that audio file" }));
|
return HttpResponse::BadRequest()
|
||||||
}
|
.json(json!({ "error": "couldn't decode that audio file" }));
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
match client
|
match client
|
||||||
.create_voice(&name, wav, "reference.wav", "audio/wav")
|
.create_voice(&name, wav, "reference.wav", "audio/wav")
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
|
invalidate_voices_cache();
|
||||||
span.set_status(Status::Ok);
|
span.set_status(Status::Ok);
|
||||||
HttpResponse::Ok().json(v)
|
HttpResponse::Ok().json(v)
|
||||||
}
|
}
|
||||||
@@ -406,11 +938,19 @@ pub struct CreateVoiceFromLibraryRequest {
|
|||||||
pub path: String,
|
pub path: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub library: Option<String>,
|
pub library: Option<String>,
|
||||||
|
/// Offset into the source where the reference window begins (default 0) —
|
||||||
|
/// lets the client pick the clean-speech section of a long recording.
|
||||||
|
#[serde(default)]
|
||||||
|
pub start_seconds: Option<f64>,
|
||||||
|
/// Reference window length; clamped to LLAMA_SWAP_TTS_REF_SECONDS.
|
||||||
|
#[serde(default)]
|
||||||
|
pub duration_seconds: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /tts/voices/from-library — register a cloned voice from a file already
|
/// POST /tts/voices/from-library — register a cloned voice from a file already
|
||||||
/// in a library. Audio and video alike are ffmpeg-normalized to a mono 24 kHz
|
/// in a library. Audio and video alike are ffmpeg-normalized to a mono 24 kHz
|
||||||
/// WAV reference clip (length capped by LLAMA_SWAP_TTS_REF_SECONDS).
|
/// WAV reference clip (window selected by start/duration_seconds, length
|
||||||
|
/// capped by LLAMA_SWAP_TTS_REF_SECONDS).
|
||||||
#[post("/tts/voices/from-library")]
|
#[post("/tts/voices/from-library")]
|
||||||
pub async fn create_voice_from_library_handler(
|
pub async fn create_voice_from_library_handler(
|
||||||
http_request: HttpRequest,
|
http_request: HttpRequest,
|
||||||
@@ -432,6 +972,18 @@ pub async fn create_voice_from_library_handler(
|
|||||||
return HttpResponse::BadRequest()
|
return HttpResponse::BadRequest()
|
||||||
.json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
|
.json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
|
||||||
};
|
};
|
||||||
|
let (ref_start, ref_duration) =
|
||||||
|
match resolve_ref_window(req.start_seconds, req.duration_seconds) {
|
||||||
|
Ok(w) => w,
|
||||||
|
Err(msg) => {
|
||||||
|
span.set_status(Status::error("invalid reference window"));
|
||||||
|
return HttpResponse::BadRequest().json(json!({ "error": msg }));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Tag the name with the ref-clip length (e.g. `grandma-30s`) so the
|
||||||
|
// library shows which reference length produced each clone.
|
||||||
|
let voice_name =
|
||||||
|
append_ref_window(&voice_name, ref_start, ref_duration.round().max(1.0) as u32);
|
||||||
|
|
||||||
let library = match libraries::resolve_library_param(&app_state, req.library.as_deref()) {
|
let library = match libraries::resolve_library_param(&app_state, req.library.as_deref()) {
|
||||||
Ok(Some(l)) => l,
|
Ok(Some(l)) => l,
|
||||||
@@ -460,7 +1012,7 @@ pub async fn create_voice_from_library_handler(
|
|||||||
}
|
}
|
||||||
span.set_attribute(KeyValue::new("tts.voice_name", voice_name.clone()));
|
span.set_attribute(KeyValue::new("tts.voice_name", voice_name.clone()));
|
||||||
|
|
||||||
let wav = match prepare_reference_audio(&abs).await {
|
let wav = match prepare_reference_audio(&abs, ref_start, ref_duration).await {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
span.set_status(Status::error("audio decode failed"));
|
span.set_status(Status::error("audio decode failed"));
|
||||||
@@ -475,6 +1027,7 @@ pub async fn create_voice_from_library_handler(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
|
invalidate_voices_cache();
|
||||||
span.set_status(Status::Ok);
|
span.set_status(Status::Ok);
|
||||||
HttpResponse::Ok().json(v)
|
HttpResponse::Ok().json(v)
|
||||||
}
|
}
|
||||||
@@ -489,8 +1042,8 @@ pub async fn create_voice_from_library_handler(
|
|||||||
/// Read a library file (audio or video) as a Chatterbox-ready reference: ffmpeg
|
/// Read a library file (audio or video) as a Chatterbox-ready reference: ffmpeg
|
||||||
/// decodes/extracts its audio to mono 24 kHz WAV. Reading straight from the
|
/// decodes/extracts its audio to mono 24 kHz WAV. Reading straight from the
|
||||||
/// library path avoids slurping a (possibly large) video into memory.
|
/// library path avoids slurping a (possibly large) video into memory.
|
||||||
async fn prepare_reference_audio(abs: &Path) -> anyhow::Result<Vec<u8>> {
|
async fn prepare_reference_audio(abs: &Path, start: f64, duration: f64) -> anyhow::Result<Vec<u8>> {
|
||||||
run_ffmpeg_to_wav(&abs.to_string_lossy()).await
|
run_ffmpeg_to_wav(&abs.to_string_lossy(), start, duration).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -534,6 +1087,151 @@ mod tests {
|
|||||||
assert_eq!(sanitize_voice_name(&long).unwrap().len(), 64);
|
assert_eq!(sanitize_voice_name(&long).unwrap().len(), 64);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn append_ref_window_tags_name() {
|
||||||
|
assert_eq!(append_ref_window("grandma", 0.0, 30), "grandma-30s");
|
||||||
|
assert_eq!(append_ref_window("voice_01", 0.0, 15), "voice_01-15s");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn append_ref_window_includes_nonzero_start() {
|
||||||
|
// Sub-minute starts stay in seconds; longer ones read as XmYYs since
|
||||||
|
// ':' isn't allowed in voice names.
|
||||||
|
assert_eq!(append_ref_window("grandma", 45.0, 30), "grandma-at45s-30s");
|
||||||
|
assert_eq!(
|
||||||
|
append_ref_window("grandma", 92.4, 30),
|
||||||
|
"grandma-at1m32s-30s"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
append_ref_window("grandma", 600.0, 12),
|
||||||
|
"grandma-at10m00s-12s"
|
||||||
|
);
|
||||||
|
// A start that rounds to zero is "from the start".
|
||||||
|
assert_eq!(append_ref_window("grandma", 0.3, 30), "grandma-30s");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn append_ref_window_is_idempotent_for_same_window() {
|
||||||
|
assert_eq!(append_ref_window("grandma-30s", 0.0, 30), "grandma-30s");
|
||||||
|
assert_eq!(
|
||||||
|
append_ref_window("grandma-at45s-30s", 45.0, 30),
|
||||||
|
"grandma-at45s-30s"
|
||||||
|
);
|
||||||
|
// A different window still appends — that's the comparison use-case.
|
||||||
|
assert_eq!(append_ref_window("grandma-15s", 0.0, 30), "grandma-15s-30s");
|
||||||
|
assert_eq!(
|
||||||
|
append_ref_window("grandma-30s", 45.0, 30),
|
||||||
|
"grandma-30s-at45s-30s"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn append_ref_window_keeps_64_char_bound() {
|
||||||
|
let long = "a".repeat(64);
|
||||||
|
let tagged = append_ref_window(&long, 0.0, 30);
|
||||||
|
assert_eq!(tagged.len(), 64);
|
||||||
|
assert!(tagged.ends_with("-30s"));
|
||||||
|
|
||||||
|
let tagged = append_ref_window(&long, 92.0, 30);
|
||||||
|
assert_eq!(tagged.len(), 64);
|
||||||
|
assert!(tagged.ends_with("-at1m32s-30s"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_ref_window_defaults_to_start_of_clip_at_cap_length() {
|
||||||
|
// Reads the live cap rather than mutating LLAMA_SWAP_TTS_REF_SECONDS:
|
||||||
|
// env mutation flakes under the parallel suite (see env_dispatch).
|
||||||
|
let cap = f64::from(tts_ref_seconds());
|
||||||
|
assert_eq!(resolve_ref_window(None, None), Ok((0.0, cap)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_ref_window_accepts_offset_and_clamps_duration() {
|
||||||
|
let cap = f64::from(tts_ref_seconds());
|
||||||
|
assert_eq!(resolve_ref_window(Some(92.5), None), Ok((92.5, cap)));
|
||||||
|
assert_eq!(resolve_ref_window(Some(10.0), Some(12.0)), Ok((10.0, 12.0)));
|
||||||
|
// Longer-than-cap windows are bounded, not rejected.
|
||||||
|
assert_eq!(resolve_ref_window(None, Some(cap + 100.0)), Ok((0.0, cap)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_ref_window_rejects_garbage() {
|
||||||
|
assert!(resolve_ref_window(Some(-1.0), None).is_err());
|
||||||
|
assert!(resolve_ref_window(Some(f64::NAN), None).is_err());
|
||||||
|
assert!(resolve_ref_window(Some(f64::INFINITY), None).is_err());
|
||||||
|
assert!(resolve_ref_window(None, Some(0.0)).is_err());
|
||||||
|
assert!(resolve_ref_window(None, Some(-5.0)).is_err());
|
||||||
|
assert!(resolve_ref_window(None, Some(f64::NAN)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sweep_drops_expired_results_and_keeps_live_jobs() {
|
||||||
|
let now = Instant::now();
|
||||||
|
let mk = |status: TtsJobStatus, created: Instant, finished: Option<Instant>| TtsJob {
|
||||||
|
status,
|
||||||
|
format: "mp3".into(),
|
||||||
|
audio_base64: None,
|
||||||
|
error: None,
|
||||||
|
created_at: created,
|
||||||
|
finished_at: finished,
|
||||||
|
abort: None,
|
||||||
|
};
|
||||||
|
let mut jobs = HashMap::new();
|
||||||
|
let live = Uuid::new_v4();
|
||||||
|
let fresh_done = Uuid::new_v4();
|
||||||
|
let stale_done = Uuid::new_v4();
|
||||||
|
jobs.insert(live, mk(TtsJobStatus::Running, now, None));
|
||||||
|
jobs.insert(
|
||||||
|
fresh_done,
|
||||||
|
mk(TtsJobStatus::Done, now, Some(now - Duration::from_secs(60))),
|
||||||
|
);
|
||||||
|
jobs.insert(
|
||||||
|
stale_done,
|
||||||
|
mk(
|
||||||
|
TtsJobStatus::Done,
|
||||||
|
now - TTS_JOB_MAX_AGE / 2,
|
||||||
|
Some(now - TTS_JOB_RESULT_TTL),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
sweep_stale_jobs(&mut jobs, now);
|
||||||
|
assert!(jobs.contains_key(&live));
|
||||||
|
assert!(jobs.contains_key(&fresh_done));
|
||||||
|
assert!(!jobs.contains_key(&stale_done));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sweep_drops_jobs_past_max_age_even_if_unfinished() {
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut jobs = HashMap::new();
|
||||||
|
let ancient = Uuid::new_v4();
|
||||||
|
jobs.insert(
|
||||||
|
ancient,
|
||||||
|
TtsJob {
|
||||||
|
status: TtsJobStatus::Running,
|
||||||
|
format: "mp3".into(),
|
||||||
|
audio_base64: None,
|
||||||
|
error: None,
|
||||||
|
created_at: now - TTS_JOB_MAX_AGE,
|
||||||
|
finished_at: None,
|
||||||
|
abort: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
sweep_stale_jobs(&mut jobs, now);
|
||||||
|
assert!(jobs.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn voices_cache_roundtrip_and_invalidation() {
|
||||||
|
invalidate_voices_cache();
|
||||||
|
assert!(cached_voices().is_none());
|
||||||
|
let v = json!({ "voices": [{ "name": "m-30s" }], "count": 1 });
|
||||||
|
store_voices_cache(&v);
|
||||||
|
assert_eq!(cached_voices(), Some(v));
|
||||||
|
invalidate_voices_cache();
|
||||||
|
assert!(cached_voices().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn clean_for_tts_strips_markdown() {
|
fn clean_for_tts_strips_markdown() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use image_api::ai::ollama::OllamaClient;
|
use image_api::ai::LocalLlm;
|
||||||
use image_api::bin_progress;
|
use image_api::bin_progress;
|
||||||
use image_api::database::calendar_dao::{InsertCalendarEvent, SqliteCalendarEventDao};
|
use image_api::database::calendar_dao::{InsertCalendarEvent, SqliteCalendarEventDao};
|
||||||
use image_api::parsers::ical_parser::parse_ics_file;
|
use image_api::parsers::ical_parser::parse_ics_file;
|
||||||
@@ -44,22 +44,10 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let context = opentelemetry::Context::current();
|
let context = opentelemetry::Context::current();
|
||||||
|
|
||||||
let ollama = if args.generate_embeddings {
|
// LocalLlm dispatches per LLM_BACKEND, so embeddings written here land
|
||||||
let primary_url = dotenv::var("OLLAMA_PRIMARY_URL")
|
// in the same vector space the query side searches.
|
||||||
.or_else(|_| dotenv::var("OLLAMA_URL"))
|
let llm = if args.generate_embeddings {
|
||||||
.unwrap_or_else(|_| "http://localhost:11434".to_string());
|
Some(LocalLlm::from_env())
|
||||||
let fallback_url = dotenv::var("OLLAMA_FALLBACK_URL").ok();
|
|
||||||
let primary_model = dotenv::var("OLLAMA_PRIMARY_MODEL")
|
|
||||||
.or_else(|_| dotenv::var("OLLAMA_MODEL"))
|
|
||||||
.unwrap_or_else(|_| "nomic-embed-text:v1.5".to_string());
|
|
||||||
let fallback_model = dotenv::var("OLLAMA_FALLBACK_MODEL").ok();
|
|
||||||
|
|
||||||
Some(OllamaClient::new(
|
|
||||||
primary_url,
|
|
||||||
fallback_url,
|
|
||||||
primary_model,
|
|
||||||
fallback_model,
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -90,7 +78,7 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate embedding if requested (blocking call)
|
// Generate embedding if requested (blocking call)
|
||||||
let embedding = if let Some(ref ollama_client) = ollama {
|
let embedding = if let Some(ref llm) = llm {
|
||||||
let text = format!(
|
let text = format!(
|
||||||
"{} {} {}",
|
"{} {} {}",
|
||||||
event.summary,
|
event.summary,
|
||||||
@@ -100,7 +88,7 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
match tokio::task::block_in_place(|| {
|
match tokio::task::block_in_place(|| {
|
||||||
tokio::runtime::Handle::current()
|
tokio::runtime::Handle::current()
|
||||||
.block_on(async { ollama_client.generate_embedding(&text).await })
|
.block_on(async { llm.embed_document(&text).await })
|
||||||
}) {
|
}) {
|
||||||
Ok(emb) => Some(emb),
|
Ok(emb) => Some(emb),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use image_api::ai::ollama::OllamaClient;
|
use image_api::ai::LocalLlm;
|
||||||
use image_api::bin_progress;
|
use image_api::bin_progress;
|
||||||
use image_api::database::search_dao::{InsertSearchRecord, SqliteSearchHistoryDao};
|
use image_api::database::search_dao::{InsertSearchRecord, SqliteSearchHistoryDao};
|
||||||
use image_api::parsers::search_html_parser::parse_search_html;
|
use image_api::parsers::search_html_parser::parse_search_html;
|
||||||
@@ -38,16 +38,9 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
info!("Found {} search records", searches.len());
|
info!("Found {} search records", searches.len());
|
||||||
|
|
||||||
let primary_url = dotenv::var("OLLAMA_PRIMARY_URL")
|
// LocalLlm dispatches per LLM_BACKEND, so embeddings written here land
|
||||||
.or_else(|_| dotenv::var("OLLAMA_URL"))
|
// in the same vector space the query side searches.
|
||||||
.unwrap_or_else(|_| "http://localhost:11434".to_string());
|
let llm = LocalLlm::from_env();
|
||||||
let fallback_url = dotenv::var("OLLAMA_FALLBACK_URL").ok();
|
|
||||||
let primary_model = dotenv::var("OLLAMA_PRIMARY_MODEL")
|
|
||||||
.or_else(|_| dotenv::var("OLLAMA_MODEL"))
|
|
||||||
.unwrap_or_else(|_| "nomic-embed-text:v1.5".to_string());
|
|
||||||
let fallback_model = dotenv::var("OLLAMA_FALLBACK_MODEL").ok();
|
|
||||||
|
|
||||||
let ollama = OllamaClient::new(primary_url, fallback_url, primary_model, fallback_model);
|
|
||||||
let context = opentelemetry::Context::current();
|
let context = opentelemetry::Context::current();
|
||||||
|
|
||||||
let mut inserted_count = 0usize;
|
let mut inserted_count = 0usize;
|
||||||
@@ -67,12 +60,11 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let pb_for_warn = pb.clone();
|
let pb_for_warn = pb.clone();
|
||||||
let embeddings_result = tokio::task::spawn({
|
let embeddings_result = tokio::task::spawn({
|
||||||
let ollama_client = ollama.clone();
|
let llm = llm.clone();
|
||||||
async move {
|
async move {
|
||||||
// Generate embeddings in parallel for the batch
|
|
||||||
let mut embeddings = Vec::new();
|
let mut embeddings = Vec::new();
|
||||||
for query in &queries {
|
for query in &queries {
|
||||||
match ollama_client.generate_embedding(query).await {
|
match llm.embed_document(query).await {
|
||||||
Ok(emb) => embeddings.push(Some(emb)),
|
Ok(emb) => embeddings.push(Some(emb)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
pb_for_warn.println(format!("embedding failed for '{}': {}", query, e));
|
pb_for_warn.println(format!("embedding failed for '{}': {}", query, e));
|
||||||
|
|||||||
@@ -0,0 +1,465 @@
|
|||||||
|
//! Re-embed stored corpora through `LocalLlm`, i.e. the same
|
||||||
|
//! `LLM_BACKEND` dispatch the query side uses. The original import /
|
||||||
|
//! backfill tools always embedded via Ollama, so a deploy running
|
||||||
|
//! `LLM_BACKEND=llamacpp` queries vector spaces the corpora may not live
|
||||||
|
//! in. Three tables share the problem and are all covered here:
|
||||||
|
//!
|
||||||
|
//! - `daily_conversation_summaries` — re-embeds
|
||||||
|
//! `strip_summary_boilerplate(summary)` (what the original job fed the
|
||||||
|
//! embedder); also rewrites `model_version`.
|
||||||
|
//! - `calendar_events` — re-embeds "summary description location" exactly
|
||||||
|
//! as `import_calendar` does; rows without an embedding are skipped (the
|
||||||
|
//! import only embeds under `--generate-embeddings`).
|
||||||
|
//! - `search_history` — re-embeds the raw query text.
|
||||||
|
//! - `entities` (knowledge graph) — re-embeds "name description" exactly as
|
||||||
|
//! `tool_store_entity` does; embedding-less rows are skipped (embedding
|
||||||
|
//! is best-effort at store time).
|
||||||
|
//!
|
||||||
|
//! Source text is untouched — only vectors are rewritten. The old↔new
|
||||||
|
//! cosine report doubles as a diagnostic: ~1.0 means both backends already
|
||||||
|
//! shared a space (re-embedding was a no-op); low values confirm the
|
||||||
|
//! mismatch this tool exists to fix.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
use diesel::prelude::*;
|
||||||
|
use diesel::sql_query;
|
||||||
|
use diesel::sqlite::SqliteConnection;
|
||||||
|
use image_api::ai::{LocalLlm, strip_summary_boilerplate};
|
||||||
|
use image_api::bin_progress;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about = "Re-embed stored corpora via the configured LLM_BACKEND", long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Comma-separated tables to process: summaries, calendar, search, entities
|
||||||
|
#[arg(long, default_value = "summaries,calendar,search,entities")]
|
||||||
|
tables: String,
|
||||||
|
|
||||||
|
/// Only process the first N rows per table (smoke test)
|
||||||
|
#[arg(long)]
|
||||||
|
limit: Option<usize>,
|
||||||
|
|
||||||
|
/// Compute embeddings and report old↔new similarity without writing
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
dry_run: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(QueryableByName)]
|
||||||
|
struct SummaryRow {
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||||||
|
id: i32,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
summary: String,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Binary)]
|
||||||
|
embedding: Vec<u8>,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
model_version: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(QueryableByName)]
|
||||||
|
struct CalendarRow {
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||||||
|
id: i32,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
summary: String,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
||||||
|
description: Option<String>,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
||||||
|
location: Option<String>,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Binary)]
|
||||||
|
embedding: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(QueryableByName)]
|
||||||
|
struct SearchRow {
|
||||||
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||||
|
id: i64,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
query: String,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Binary)]
|
||||||
|
embedding: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(QueryableByName)]
|
||||||
|
struct EntityRow {
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||||||
|
id: i32,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
name: String,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||||
|
description: String,
|
||||||
|
#[diesel(sql_type = diesel::sql_types::Binary)]
|
||||||
|
embedding: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One unit of re-embed work, normalized across tables.
|
||||||
|
struct WorkItem {
|
||||||
|
/// Row key, as i64 so both i32 ids and rowids fit.
|
||||||
|
id: i64,
|
||||||
|
/// Text fed to the embedder — must match what the original writer used.
|
||||||
|
text: String,
|
||||||
|
/// Existing vector bytes, for the old↔new similarity report.
|
||||||
|
old_embedding: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_vector(bytes: &[u8]) -> Option<Vec<f32>> {
|
||||||
|
if !bytes.len().is_multiple_of(4) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(
|
||||||
|
bytes
|
||||||
|
.chunks_exact(4)
|
||||||
|
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize_vector(vec: &[f32]) -> Vec<u8> {
|
||||||
|
vec.iter().flat_map(|f| f.to_le_bytes()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||||
|
if a.len() != b.len() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||||
|
let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
if mag_a == 0.0 || mag_b == 0.0 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
dot / (mag_a * mag_b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed `text`, halving it on "input too large" errors until it fits the
|
||||||
|
/// server's physical batch (`--ubatch-size`). Mirrors the silent truncation
|
||||||
|
/// Ollama applied when these corpora were first embedded — llama-server
|
||||||
|
/// returns a 500 instead — except here it's surfaced via the returned flag.
|
||||||
|
/// Returns `(embedding, truncated)`.
|
||||||
|
async fn embed_with_truncation(llm: &LocalLlm, text: &str) -> Result<(Vec<f32>, bool)> {
|
||||||
|
let mut text = text.to_string();
|
||||||
|
let mut truncated = false;
|
||||||
|
loop {
|
||||||
|
match llm.embed_document(&text).await {
|
||||||
|
Ok(emb) => return Ok((emb, truncated)),
|
||||||
|
Err(e)
|
||||||
|
if e.to_string().contains("too large to process") && text.chars().count() > 64 =>
|
||||||
|
{
|
||||||
|
let keep = text.chars().count() / 2;
|
||||||
|
text = text.chars().take(keep).collect();
|
||||||
|
truncated = true;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-embed `items`, writing each new vector via `update`. Returns the
|
||||||
|
/// old↔new cosines for the similarity report.
|
||||||
|
async fn reembed_table(
|
||||||
|
conn: &mut SqliteConnection,
|
||||||
|
llm: &LocalLlm,
|
||||||
|
label: &str,
|
||||||
|
items: Vec<WorkItem>,
|
||||||
|
dry_run: bool,
|
||||||
|
update: impl Fn(&mut SqliteConnection, i64, Vec<u8>) -> Result<()>,
|
||||||
|
) -> Result<Vec<f32>> {
|
||||||
|
println!("\n[{}] re-embedding {} rows...", label, items.len());
|
||||||
|
let pb = bin_progress::determinate(items.len() as u64, format!("re-embedding {}", label));
|
||||||
|
|
||||||
|
let mut sims: Vec<f32> = Vec::with_capacity(items.len());
|
||||||
|
let mut updated = 0usize;
|
||||||
|
let mut failed = 0usize;
|
||||||
|
let mut truncated_count = 0usize;
|
||||||
|
|
||||||
|
for item in &items {
|
||||||
|
let new_emb = match embed_with_truncation(llm, &item.text).await {
|
||||||
|
Ok((e, truncated)) => {
|
||||||
|
if truncated {
|
||||||
|
truncated_count += 1;
|
||||||
|
pb.println(format!(
|
||||||
|
"⚠ {} id={}: input exceeded the embed server's batch size, \
|
||||||
|
truncated before embedding",
|
||||||
|
label, item.id
|
||||||
|
));
|
||||||
|
}
|
||||||
|
e
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
pb.inc(1);
|
||||||
|
failed += 1;
|
||||||
|
eprintln!("✗ {} id={}: {}", label, item.id, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// The whole pipeline (DAO checks, stored corpora) assumes
|
||||||
|
// EMBEDDING_DIM dims. A mismatch means the active embed slot is not
|
||||||
|
// serving the configured model — stop rather than corrupt the table.
|
||||||
|
anyhow::ensure!(
|
||||||
|
new_emb.len() == image_api::ai::embedding_dim(),
|
||||||
|
"backend returned {}-dim embedding (expected {}) — '{}' does not \
|
||||||
|
match the configured EMBEDDING_DIM",
|
||||||
|
new_emb.len(),
|
||||||
|
image_api::ai::embedding_dim(),
|
||||||
|
llm.embedding_model_version()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(old_emb) = deserialize_vector(&item.old_embedding) {
|
||||||
|
sims.push(cosine_similarity(&old_emb, &new_emb));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dry_run {
|
||||||
|
update(conn, item.id, serialize_vector(&new_emb))
|
||||||
|
.with_context(|| format!("updating {} id={}", label, item.id))?;
|
||||||
|
}
|
||||||
|
updated += 1;
|
||||||
|
pb.inc(1);
|
||||||
|
}
|
||||||
|
pb.finish_and_clear();
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"[{}] {} re-embedded ({} truncated), {} failed",
|
||||||
|
label, updated, truncated_count, failed
|
||||||
|
);
|
||||||
|
Ok(sims)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn report_similarity(label: &str, mut sims: Vec<f32>) {
|
||||||
|
if sims.is_empty() {
|
||||||
|
println!("[{}] no old↔new pairs to compare", label);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sims.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32;
|
||||||
|
let median = sims[sims.len() / 2];
|
||||||
|
println!(
|
||||||
|
"[{}] old↔new cosine over identical text: min={:.3} median={:.3} mean={:.3} max={:.3}",
|
||||||
|
label,
|
||||||
|
sims.first().unwrap(),
|
||||||
|
median,
|
||||||
|
mean,
|
||||||
|
sims.last().unwrap()
|
||||||
|
);
|
||||||
|
if median > 0.98 {
|
||||||
|
println!(
|
||||||
|
"[{}] → old and new backends agree (~same vector space); poor search \
|
||||||
|
results are coming from something else (prefixes, thresholds, corpus).",
|
||||||
|
label
|
||||||
|
);
|
||||||
|
} else if median > 0.9 {
|
||||||
|
println!(
|
||||||
|
"[{}] → same model family but measurably different vectors \
|
||||||
|
(quantization / runtime drift); re-embedding was worthwhile.",
|
||||||
|
label
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!(
|
||||||
|
"[{}] → vector-space mismatch confirmed — queries were searching a \
|
||||||
|
different space than the corpus. This re-embed should fix it.",
|
||||||
|
label
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
env_logger::init();
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let tables: Vec<&str> = args.tables.split(',').map(|t| t.trim()).collect();
|
||||||
|
for t in &tables {
|
||||||
|
anyhow::ensure!(
|
||||||
|
matches!(*t, "summaries" | "calendar" | "search" | "entities"),
|
||||||
|
"unknown table '{}' — expected summaries, calendar, search, entities",
|
||||||
|
t
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "auth.db".to_string());
|
||||||
|
println!("Database: {}", database_url);
|
||||||
|
|
||||||
|
let mut conn = SqliteConnection::establish(&database_url)
|
||||||
|
.with_context(|| format!("connecting to {}", database_url))?;
|
||||||
|
|
||||||
|
let llm = LocalLlm::from_env();
|
||||||
|
let model_version = llm.embedding_model_version();
|
||||||
|
println!("Embedding via '{}'", model_version);
|
||||||
|
if args.dry_run {
|
||||||
|
println!("DRY RUN — no rows will be written");
|
||||||
|
}
|
||||||
|
|
||||||
|
if tables.contains(&"summaries") {
|
||||||
|
let mut rows: Vec<SummaryRow> = sql_query(
|
||||||
|
"SELECT id, summary, embedding, model_version
|
||||||
|
FROM daily_conversation_summaries ORDER BY date",
|
||||||
|
)
|
||||||
|
.load(&mut conn)
|
||||||
|
.context("loading daily summaries")?;
|
||||||
|
if let Some(limit) = args.limit {
|
||||||
|
rows.truncate(limit);
|
||||||
|
}
|
||||||
|
if let Some(first) = rows.first() {
|
||||||
|
println!(
|
||||||
|
"\n[summaries] previous model_version '{}' → '{}'",
|
||||||
|
first.model_version, model_version
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let items = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| WorkItem {
|
||||||
|
id: r.id as i64,
|
||||||
|
text: strip_summary_boilerplate(&r.summary),
|
||||||
|
old_embedding: r.embedding,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let mv = model_version.clone();
|
||||||
|
let sims = reembed_table(
|
||||||
|
&mut conn,
|
||||||
|
&llm,
|
||||||
|
"summaries",
|
||||||
|
items,
|
||||||
|
args.dry_run,
|
||||||
|
move |conn, id, emb| {
|
||||||
|
sql_query(
|
||||||
|
"UPDATE daily_conversation_summaries
|
||||||
|
SET embedding = ?1, model_version = ?2 WHERE id = ?3",
|
||||||
|
)
|
||||||
|
.bind::<diesel::sql_types::Binary, _>(emb)
|
||||||
|
.bind::<diesel::sql_types::Text, _>(&mv)
|
||||||
|
.bind::<diesel::sql_types::Integer, _>(id as i32)
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
report_similarity("summaries", sims);
|
||||||
|
}
|
||||||
|
|
||||||
|
if tables.contains(&"calendar") {
|
||||||
|
let mut rows: Vec<CalendarRow> = sql_query(
|
||||||
|
"SELECT id, summary, description, location, embedding
|
||||||
|
FROM calendar_events WHERE embedding IS NOT NULL ORDER BY id",
|
||||||
|
)
|
||||||
|
.load(&mut conn)
|
||||||
|
.context("loading calendar events")?;
|
||||||
|
if let Some(limit) = args.limit {
|
||||||
|
rows.truncate(limit);
|
||||||
|
}
|
||||||
|
let items = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| WorkItem {
|
||||||
|
id: r.id as i64,
|
||||||
|
// Same text construction as import_calendar.
|
||||||
|
text: format!(
|
||||||
|
"{} {} {}",
|
||||||
|
r.summary,
|
||||||
|
r.description.as_deref().unwrap_or(""),
|
||||||
|
r.location.as_deref().unwrap_or("")
|
||||||
|
),
|
||||||
|
old_embedding: r.embedding,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let sims = reembed_table(
|
||||||
|
&mut conn,
|
||||||
|
&llm,
|
||||||
|
"calendar",
|
||||||
|
items,
|
||||||
|
args.dry_run,
|
||||||
|
|conn, id, emb| {
|
||||||
|
sql_query("UPDATE calendar_events SET embedding = ?1 WHERE id = ?2")
|
||||||
|
.bind::<diesel::sql_types::Binary, _>(emb)
|
||||||
|
.bind::<diesel::sql_types::Integer, _>(id as i32)
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
report_similarity("calendar", sims);
|
||||||
|
}
|
||||||
|
|
||||||
|
if tables.contains(&"search") {
|
||||||
|
let mut rows: Vec<SearchRow> = sql_query(
|
||||||
|
"SELECT rowid AS id, query, embedding
|
||||||
|
FROM search_history ORDER BY rowid",
|
||||||
|
)
|
||||||
|
.load(&mut conn)
|
||||||
|
.context("loading search history")?;
|
||||||
|
if let Some(limit) = args.limit {
|
||||||
|
rows.truncate(limit);
|
||||||
|
}
|
||||||
|
let items = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| WorkItem {
|
||||||
|
id: r.id,
|
||||||
|
text: r.query,
|
||||||
|
old_embedding: r.embedding,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let sims = reembed_table(
|
||||||
|
&mut conn,
|
||||||
|
&llm,
|
||||||
|
"search",
|
||||||
|
items,
|
||||||
|
args.dry_run,
|
||||||
|
|conn, id, emb| {
|
||||||
|
sql_query("UPDATE search_history SET embedding = ?1 WHERE rowid = ?2")
|
||||||
|
.bind::<diesel::sql_types::Binary, _>(emb)
|
||||||
|
.bind::<diesel::sql_types::BigInt, _>(id)
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
report_similarity("search", sims);
|
||||||
|
}
|
||||||
|
|
||||||
|
if tables.contains(&"entities") {
|
||||||
|
let mut rows: Vec<EntityRow> = sql_query(
|
||||||
|
"SELECT id, name, description, embedding
|
||||||
|
FROM entities WHERE embedding IS NOT NULL ORDER BY id",
|
||||||
|
)
|
||||||
|
.load(&mut conn)
|
||||||
|
.context("loading knowledge entities")?;
|
||||||
|
if let Some(limit) = args.limit {
|
||||||
|
rows.truncate(limit);
|
||||||
|
}
|
||||||
|
let items = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| WorkItem {
|
||||||
|
id: r.id as i64,
|
||||||
|
// Same text construction as tool_store_entity.
|
||||||
|
text: format!("{} {}", r.name, r.description),
|
||||||
|
old_embedding: r.embedding,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let sims = reembed_table(
|
||||||
|
&mut conn,
|
||||||
|
&llm,
|
||||||
|
"entities",
|
||||||
|
items,
|
||||||
|
args.dry_run,
|
||||||
|
|conn, id, emb| {
|
||||||
|
sql_query("UPDATE entities SET embedding = ?1 WHERE id = ?2")
|
||||||
|
.bind::<diesel::sql_types::Binary, _>(emb)
|
||||||
|
.bind::<diesel::sql_types::Integer, _>(id as i32)
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
report_similarity("entities", sims);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"\n{}",
|
||||||
|
if args.dry_run {
|
||||||
|
"Dry run complete"
|
||||||
|
} else {
|
||||||
|
"Done"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -222,11 +222,12 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
|||||||
|
|
||||||
// Validate embedding dimensions if provided
|
// Validate embedding dimensions if provided
|
||||||
if let Some(ref emb) = event.embedding
|
if let Some(ref emb) = event.embedding
|
||||||
&& emb.len() != 768
|
&& emb.len() != crate::ai::embedding_dim()
|
||||||
{
|
{
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid embedding dimensions: {} (expected 768)",
|
"Invalid embedding dimensions: {} (expected {})",
|
||||||
emb.len()
|
emb.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,7 +294,7 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
|||||||
for event in events {
|
for event in events {
|
||||||
// Validate embedding if provided
|
// Validate embedding if provided
|
||||||
if let Some(ref emb) = event.embedding
|
if let Some(ref emb) = event.embedding
|
||||||
&& emb.len() != 768
|
&& emb.len() != crate::ai::embedding_dim()
|
||||||
{
|
{
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"Skipping event with invalid embedding dimensions: {}",
|
"Skipping event with invalid embedding dimensions: {}",
|
||||||
@@ -385,10 +386,11 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
|||||||
trace_db_call(context, "query", "find_similar_events", |_span| {
|
trace_db_call(context, "query", "find_similar_events", |_span| {
|
||||||
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
let mut conn = self.connection.lock().expect("Unable to get CalendarEventDao");
|
||||||
|
|
||||||
if query_embedding.len() != 768 {
|
if query_embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_embedding.len()
|
query_embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,10 +463,11 @@ impl CalendarEventDao for SqliteCalendarEventDao {
|
|||||||
|
|
||||||
// Step 2: If query embedding provided, rank by semantic similarity
|
// Step 2: If query embedding provided, rank by semantic similarity
|
||||||
if let Some(query_emb) = query_embedding {
|
if let Some(query_emb) = query_embedding {
|
||||||
if query_emb.len() != 768 {
|
if query_emb.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_emb.len()
|
query_emb.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -150,10 +150,11 @@ impl DailySummaryDao for SqliteDailySummaryDao {
|
|||||||
.expect("Unable to get DailySummaryDao");
|
.expect("Unable to get DailySummaryDao");
|
||||||
|
|
||||||
// Validate embedding dimensions
|
// Validate embedding dimensions
|
||||||
if summary.embedding.len() != 768 {
|
if summary.embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid embedding dimensions: {} (expected 768)",
|
"Invalid embedding dimensions: {} (expected {})",
|
||||||
summary.embedding.len()
|
summary.embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,10 +203,11 @@ impl DailySummaryDao for SqliteDailySummaryDao {
|
|||||||
trace_db_call(context, "query", "find_similar_summaries", |_span| {
|
trace_db_call(context, "query", "find_similar_summaries", |_span| {
|
||||||
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
||||||
|
|
||||||
if query_embedding.len() != 768 {
|
if query_embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_embedding.len()
|
query_embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,10 +301,11 @@ impl DailySummaryDao for SqliteDailySummaryDao {
|
|||||||
trace_db_call(context, "query", "find_similar_summaries_with_time_weight", |_span| {
|
trace_db_call(context, "query", "find_similar_summaries_with_time_weight", |_span| {
|
||||||
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
let mut conn = self.connection.lock().expect("Unable to get DailySummaryDao");
|
||||||
|
|
||||||
if query_embedding.len() != 768 {
|
if query_embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_embedding.len()
|
query_embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -216,11 +216,12 @@ impl LocationHistoryDao for SqliteLocationHistoryDao {
|
|||||||
|
|
||||||
// Validate embedding dimensions if provided (rare for location data)
|
// Validate embedding dimensions if provided (rare for location data)
|
||||||
if let Some(ref emb) = location.embedding
|
if let Some(ref emb) = location.embedding
|
||||||
&& emb.len() != 768
|
&& emb.len() != crate::ai::embedding_dim()
|
||||||
{
|
{
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid embedding dimensions: {} (expected 768)",
|
"Invalid embedding dimensions: {} (expected {})",
|
||||||
emb.len()
|
emb.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,7 +293,7 @@ impl LocationHistoryDao for SqliteLocationHistoryDao {
|
|||||||
for location in locations {
|
for location in locations {
|
||||||
// Validate embedding if provided (rare)
|
// Validate embedding if provided (rare)
|
||||||
if let Some(ref emb) = location.embedding
|
if let Some(ref emb) = location.embedding
|
||||||
&& emb.len() != 768
|
&& emb.len() != crate::ai::embedding_dim()
|
||||||
{
|
{
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"Skipping location with invalid embedding dimensions: {}",
|
"Skipping location with invalid embedding dimensions: {}",
|
||||||
|
|||||||
+13
-10
@@ -189,10 +189,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao {
|
|||||||
.expect("Unable to get SearchHistoryDao");
|
.expect("Unable to get SearchHistoryDao");
|
||||||
|
|
||||||
// Validate embedding dimensions (REQUIRED for searches)
|
// Validate embedding dimensions (REQUIRED for searches)
|
||||||
if search.embedding.len() != 768 {
|
if search.embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid embedding dimensions: {} (expected 768)",
|
"Invalid embedding dimensions: {} (expected {})",
|
||||||
search.embedding.len()
|
search.embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,7 +246,7 @@ impl SearchHistoryDao for SqliteSearchHistoryDao {
|
|||||||
conn.transaction::<_, anyhow::Error, _>(|conn| {
|
conn.transaction::<_, anyhow::Error, _>(|conn| {
|
||||||
for search in searches {
|
for search in searches {
|
||||||
// Validate embedding (REQUIRED)
|
// Validate embedding (REQUIRED)
|
||||||
if search.embedding.len() != 768 {
|
if search.embedding.len() != crate::ai::embedding_dim() {
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"Skipping search with invalid embedding dimensions: {}",
|
"Skipping search with invalid embedding dimensions: {}",
|
||||||
search.embedding.len()
|
search.embedding.len()
|
||||||
@@ -325,10 +326,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao {
|
|||||||
.lock()
|
.lock()
|
||||||
.expect("Unable to get SearchHistoryDao");
|
.expect("Unable to get SearchHistoryDao");
|
||||||
|
|
||||||
if query_embedding.len() != 768 {
|
if query_embedding.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_embedding.len()
|
query_embedding.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,10 +408,11 @@ impl SearchHistoryDao for SqliteSearchHistoryDao {
|
|||||||
|
|
||||||
// Step 2: If query embedding provided, rank by semantic similarity
|
// Step 2: If query embedding provided, rank by semantic similarity
|
||||||
if let Some(query_emb) = query_embedding {
|
if let Some(query_emb) = query_embedding {
|
||||||
if query_emb.len() != 768 {
|
if query_emb.len() != crate::ai::embedding_dim() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"Invalid query embedding dimensions: {} (expected 768)",
|
"Invalid query embedding dimensions: {} (expected {})",
|
||||||
query_emb.len()
|
query_emb.len(),
|
||||||
|
crate::ai::embedding_dim()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -364,9 +364,13 @@ fn main() -> std::io::Result<()> {
|
|||||||
.service(ai::rate_insight_handler)
|
.service(ai::rate_insight_handler)
|
||||||
.service(ai::export_training_data_handler)
|
.service(ai::export_training_data_handler)
|
||||||
.service(ai::tts_speech_handler)
|
.service(ai::tts_speech_handler)
|
||||||
|
.service(ai::create_speech_job_handler)
|
||||||
|
.service(ai::speech_job_status_handler)
|
||||||
|
.service(ai::cancel_speech_job_handler)
|
||||||
.service(ai::list_voices_handler)
|
.service(ai::list_voices_handler)
|
||||||
.service(ai::create_voice_upload_handler)
|
.service(ai::create_voice_upload_handler)
|
||||||
.service(ai::create_voice_from_library_handler)
|
.service(ai::create_voice_from_library_handler)
|
||||||
|
.service(ai::delete_voice_handler)
|
||||||
.service(libraries::list_libraries)
|
.service(libraries::list_libraries)
|
||||||
.service(libraries::patch_library)
|
.service(libraries::patch_library)
|
||||||
.add_feature(add_tag_services::<_, SqliteTagDao>)
|
.add_feature(add_tag_services::<_, SqliteTagDao>)
|
||||||
|
|||||||
+18
-16
@@ -186,21 +186,7 @@ impl AppState {
|
|||||||
impl Default for AppState {
|
impl Default for AppState {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
// Initialize AI clients
|
// Initialize AI clients
|
||||||
let ollama_primary_url = env::var("OLLAMA_PRIMARY_URL").unwrap_or_else(|_| {
|
let ollama = build_ollama_from_env();
|
||||||
env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".to_string())
|
|
||||||
});
|
|
||||||
let ollama_fallback_url = env::var("OLLAMA_FALLBACK_URL").ok();
|
|
||||||
let ollama_primary_model = env::var("OLLAMA_PRIMARY_MODEL")
|
|
||||||
.or_else(|_| env::var("OLLAMA_MODEL"))
|
|
||||||
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string());
|
|
||||||
let ollama_fallback_model = env::var("OLLAMA_FALLBACK_MODEL").ok();
|
|
||||||
|
|
||||||
let ollama = OllamaClient::new(
|
|
||||||
ollama_primary_url,
|
|
||||||
ollama_fallback_url,
|
|
||||||
ollama_primary_model,
|
|
||||||
ollama_fallback_model,
|
|
||||||
);
|
|
||||||
|
|
||||||
let openrouter = build_openrouter_from_env();
|
let openrouter = build_openrouter_from_env();
|
||||||
let openrouter_allowed_models = parse_openrouter_allowed_models();
|
let openrouter_allowed_models = parse_openrouter_allowed_models();
|
||||||
@@ -375,13 +361,29 @@ fn parse_openrouter_allowed_models() -> Vec<String> {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the `OllamaClient` from environment variables — the canonical
|
||||||
|
/// `OLLAMA_*` wiring shared by the server (`AppState::default`) and the
|
||||||
|
/// standalone binaries (which predate this helper and used to copy it).
|
||||||
|
pub fn build_ollama_from_env() -> OllamaClient {
|
||||||
|
let primary_url = env::var("OLLAMA_PRIMARY_URL").unwrap_or_else(|_| {
|
||||||
|
env::var("OLLAMA_URL").unwrap_or_else(|_| "http://localhost:11434".to_string())
|
||||||
|
});
|
||||||
|
let fallback_url = env::var("OLLAMA_FALLBACK_URL").ok();
|
||||||
|
let primary_model = env::var("OLLAMA_PRIMARY_MODEL")
|
||||||
|
.or_else(|_| env::var("OLLAMA_MODEL"))
|
||||||
|
.unwrap_or_else(|_| "nemotron-3-nano:30b".to_string());
|
||||||
|
let fallback_model = env::var("OLLAMA_FALLBACK_MODEL").ok();
|
||||||
|
|
||||||
|
OllamaClient::new(primary_url, fallback_url, primary_model, fallback_model)
|
||||||
|
}
|
||||||
|
|
||||||
/// Build a `LlamaCppClient` from environment variables. Returns `None` when
|
/// Build a `LlamaCppClient` from environment variables. Returns `None` when
|
||||||
/// `LLAMA_SWAP_URL` is unset. The client is constructed unconditionally
|
/// `LLAMA_SWAP_URL` is unset. The client is constructed unconditionally
|
||||||
/// when the URL is set (so it's available even under `LLM_BACKEND=ollama`
|
/// when the URL is set (so it's available even under `LLM_BACKEND=ollama`
|
||||||
/// for ad-hoc tooling), but the agentic / chat paths only route through it
|
/// for ad-hoc tooling), but the agentic / chat paths only route through it
|
||||||
/// when `LLM_BACKEND=llamacpp`. Slot ids default to the names the bundled
|
/// when `LLM_BACKEND=llamacpp`. Slot ids default to the names the bundled
|
||||||
/// `llama-swap/config.yaml` uses — `chat` / `vision` / `embed`.
|
/// `llama-swap/config.yaml` uses — `chat` / `vision` / `embed`.
|
||||||
fn build_llamacpp_from_env() -> Option<Arc<LlamaCppClient>> {
|
pub fn build_llamacpp_from_env() -> Option<Arc<LlamaCppClient>> {
|
||||||
let base_url = env::var("LLAMA_SWAP_URL").ok()?;
|
let base_url = env::var("LLAMA_SWAP_URL").ok()?;
|
||||||
let primary_model = env::var("LLAMA_SWAP_PRIMARY_MODEL").ok();
|
let primary_model = env::var("LLAMA_SWAP_PRIMARY_MODEL").ok();
|
||||||
let mut client = LlamaCppClient::new(Some(base_url), primary_model);
|
let mut client = LlamaCppClient::new(Some(base_url), primary_model);
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"Worcester": "Wuster",
|
||||||
|
"Spokane": "Spo can",
|
||||||
|
"wsl": "W S L",
|
||||||
|
"sql": "sequel",
|
||||||
|
"api": "A P I",
|
||||||
|
"US": "U S",
|
||||||
|
"Dr.": "Doctor",
|
||||||
|
"St.": "Saint",
|
||||||
|
"blvd": "boulevard",
|
||||||
|
"vs.": "versus",
|
||||||
|
"etc.": "et cetera"
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user