diff --git a/src/ai/llamacpp.rs b/src/ai/llamacpp.rs index e2ba00d..afd7f1b 100644 --- a/src/ai/llamacpp.rs +++ b/src/ai/llamacpp.rs @@ -36,6 +36,7 @@ const DEFAULT_BASE_URL: &str = "http://localhost:9292/v1"; const DEFAULT_PRIMARY_MODEL: &str = "chat"; const DEFAULT_VISION_MODEL: &str = "vision"; const DEFAULT_EMBEDDING_MODEL: &str = "embed"; +const DEFAULT_TTS_MODEL: &str = "chatterbox"; const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 180; /// OpenAI-compatible client targeting a llama-swap proxy in front of one or @@ -54,6 +55,10 @@ pub struct LlamaCppClient { /// to `primary_model` so describe_image works out of the box; override /// via `LLAMA_SWAP_VISION_MODEL` for a dedicated vision slot. pub vision_model: String, + /// TTS model slot id (e.g. `"chatterbox"`). Routes `text_to_speech` and + /// is the `/upstream/{tts_model}/voices` path segment for the voice library. + /// Override via `LLAMA_SWAP_TTS_MODEL`. + pub tts_model: String, num_ctx: Option, temperature: Option, top_p: Option, @@ -78,6 +83,7 @@ impl LlamaCppClient { primary_model: pm.clone(), embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(), vision_model: pm, + tts_model: DEFAULT_TTS_MODEL.to_string(), num_ctx: None, temperature: None, top_p: None, @@ -111,6 +117,116 @@ impl LlamaCppClient { self.min_p = min_p; } + pub fn set_tts_model(&mut self, model: String) { + self.tts_model = model; + } + + // --- TTS (Chatterbox behind llama-swap) --------------------------------- + // + // Speech synthesis uses the OpenAI-compatible `{base_url}/audio/speech` + // endpoint (llama-swap routes by the `model` field). The voice *library* + // (list / create cloned voices) is NOT an OpenAI endpoint — it lives on the + // upstream server directly, reached via llama-swap's passthrough at + // `{swap_root}/upstream/{tts_model}/voices`. + + /// Root of the llama-swap proxy: `base_url` with a trailing `/v1` removed. + /// The `/upstream/...` passthrough lives here, not under `/v1`.
+ fn swap_root(&self) -> &str { + let b = self.base_url.trim_end_matches('/'); + b.strip_suffix("/v1").unwrap_or(b) + } + + /// Synthesize speech for `input` in an optional named `voice`, returning + /// the raw audio bytes (format per `response_format`, e.g. `"mp3"`/`"wav"`). + pub async fn text_to_speech( + &self, + input: &str, + voice: Option<&str>, + response_format: &str, + ) -> Result> { + let url = format!("{}/audio/speech", self.base_url); + let mut body = json!({ + "model": self.tts_model, + "input": input, + "response_format": response_format, + }); + if let Some(v) = voice { + body["voice"] = Value::String(v.to_string()); + } + + let resp = self + .client + .post(&url) + .json(&body) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("llama-swap TTS request failed: {} — {}", status, text); + } + + Ok(resp + .bytes() + .await + .context("reading TTS audio bytes")? + .to_vec()) + } + + /// List voices in the Chatterbox voice library (raw JSON passthrough). + pub async fn list_voices(&self) -> Result { + let url = format!("{}/upstream/{}/voices", self.swap_root(), self.tts_model); + let resp = self + .client + .get(&url) + .send() + .await + .with_context(|| format!("GET {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("llama-swap list_voices failed: {} — {}", status, text); + } + resp.json().await.context("parsing voices response") + } + + /// Register a cloned voice from raw audio bytes (multipart `voice_name` + + /// `voice_file`). Returns the upstream JSON response. 
+ pub async fn create_voice( + &self, + voice_name: &str, + audio_bytes: Vec, + filename: &str, + mime: &str, + ) -> Result { + let url = format!("{}/upstream/{}/voices", self.swap_root(), self.tts_model); + let part = reqwest::multipart::Part::bytes(audio_bytes) + .file_name(filename.to_string()) + .mime_str(mime) + .context("invalid audio mime type")?; + let form = reqwest::multipart::Form::new() + .text("voice_name", voice_name.to_string()) + .part("voice_file", part); + + let resp = self + .client + .post(&url) + .multipart(form) + .send() + .await + .with_context(|| format!("POST {} (multipart) failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("llama-swap create_voice failed: {} — {}", status, text); + } + resp.json().await.context("parsing create_voice response") + } + /// Translate canonical messages to the OpenAI-compatible wire shape. /// Behaviorally identical to `OpenRouterClient::messages_to_openai` — /// stringify tool-call arguments, rewrite images into content-parts, attach @@ -1140,4 +1256,24 @@ mod tests { let wire = LlamaCppClient::messages_to_openai(&[msg]); assert_eq!(wire[0]["content"], ""); } + + #[test] + fn swap_root_strips_v1_suffix() { + let c = LlamaCppClient::new(Some("http://localhost:9292/v1".to_string()), None); + assert_eq!(c.swap_root(), "http://localhost:9292"); + + // Tolerates a trailing slash on the base URL. + let c2 = LlamaCppClient::new(Some("http://localhost:9292/v1/".to_string()), None); + assert_eq!(c2.swap_root(), "http://localhost:9292"); + + // No /v1 suffix → returned unchanged. 
+ let c3 = LlamaCppClient::new(Some("http://host:1234".to_string()), None); + assert_eq!(c3.swap_root(), "http://host:1234"); + } + + #[test] + fn tts_model_defaults_to_chatterbox() { + let c = LlamaCppClient::new(None, None); + assert_eq!(c.tts_model, "chatterbox"); + } } diff --git a/src/ai/mod.rs b/src/ai/mod.rs index e9bec09..e61eace 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -11,6 +11,7 @@ pub mod llm_client; pub mod ollama; pub mod openrouter; pub mod sms_client; +pub mod tts; pub mod turn_registry; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library @@ -34,6 +35,10 @@ pub use llm_client::{ }; pub use ollama::{EMBEDDING_MODEL, OllamaClient}; pub use sms_client::{SmsApiClient, SmsMessage}; +pub use tts::{ + create_voice_from_library_handler, create_voice_upload_handler, list_voices_handler, + tts_speech_handler, +}; /// Display name used for the user in message transcripts and first-person /// prompt text. Reads the `USER_NAME` env var; defaults to `"Me"`. Models diff --git a/src/ai/tts.rs b/src/ai/tts.rs new file mode 100644 index 0000000..b2bd675 --- /dev/null +++ b/src/ai/tts.rs @@ -0,0 +1,393 @@ +// TTS endpoints: proxy text-to-speech + voice-library management to the +// Chatterbox server that sits behind llama-swap (via LlamaCppClient). Speech +// synthesis returns audio as base64-in-JSON so the mobile app can play it as a +// `data:` URI without a binary-fetch path. Voice cloning registers a named +// voice from either an uploaded clip (device) or an existing library file +// (audio read directly; video has its audio track extracted via ffmpeg). 
+ +use actix_multipart::Multipart; +use actix_web::{HttpResponse, Responder, get, post, web}; +use anyhow::Context; +use base64::Engine; +use bytes::{BufMut, BytesMut}; +use futures::StreamExt; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::path::Path; + +use crate::data::Claims; +use crate::file_types::{is_audio_file, is_video_file}; +use crate::files::is_valid_full_path; +use crate::libraries; +use crate::state::AppState; + +/// Hard cap on an uploaded voice-reference clip. Chatterbox itself caps the +/// payload (~60s clip); this is a defensive ceiling so a hostile/oversized +/// upload can't balloon ImageApi memory before we ever forward it. +const MAX_VOICE_UPLOAD_BYTES: usize = 25 * 1024 * 1024; // 25 MB + +/// Sanitize a user-supplied voice name. The name is forwarded to Chatterbox +/// where it becomes a filename in the voice-library directory, so we restrict +/// it to a safe charset (alphanumerics, dash, underscore) — no path +/// separators, dots, or whitespace — and bound its length. Returns `None` +/// when nothing usable remains. +fn sanitize_voice_name(raw: &str) -> Option { + let cleaned: String = raw + .trim() + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' { + c + } else { + '-' + } + }) + .collect(); + let cleaned = cleaned.trim_matches('-').to_string(); + if cleaned.is_empty() { + return None; + } + Some(cleaned.chars().take(64).collect()) +} + +/// Optional default voice for synthesis when the request doesn't name one. +/// Set `LLAMA_SWAP_TTS_VOICE=m` to read insights in a cloned voice by default. 
+fn default_voice() -> Option { + std::env::var("LLAMA_SWAP_TTS_VOICE") + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +fn guess_audio_mime(path: &Path) -> String { + match path + .extension() + .and_then(|e| e.to_str()) + .map(|e| e.to_lowercase()) + .as_deref() + { + Some("wav") => "audio/wav", + Some("mp3") => "audio/mpeg", + Some("m4a") | Some("mp4") | Some("aac") => "audio/mp4", + Some("flac") => "audio/flac", + Some("ogg") | Some("oga") => "audio/ogg", + _ => "application/octet-stream", + } + .to_string() +} + +#[derive(Debug, Deserialize)] +pub struct TtsSpeechRequest { + pub text: String, + #[serde(default)] + pub voice: Option, + /// Audio container, e.g. `"mp3"` (default) or `"wav"`. + #[serde(default)] + pub format: Option, +} + +#[derive(Debug, Serialize)] +pub struct TtsSpeechResponse { + pub audio_base64: String, + pub format: String, +} + +/// POST /tts/speech — synthesize `text` (optionally in a named `voice`) and +/// return base64-encoded audio for `data:` URI playback on the client. 
+#[post("/tts/speech")] +pub async fn tts_speech_handler( + _claims: Claims, + req: web::Json, + app_state: web::Data, +) -> impl Responder { + let text = req.text.trim(); + if text.is_empty() { + return HttpResponse::BadRequest().json(json!({ "error": "text is required" })); + } + let Some(client) = app_state.llamacpp.as_ref() else { + return HttpResponse::ServiceUnavailable() + .json(json!({ "error": "TTS backend not configured (set LLAMA_SWAP_URL)" })); + }; + + let format = req + .format + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or("mp3"); + let dv = default_voice(); + let voice = req + .voice + .as_deref() + .filter(|s| !s.is_empty()) + .or(dv.as_deref()); + + match client.text_to_speech(text, voice, format).await { + Ok(bytes) => { + let audio_base64 = base64::engine::general_purpose::STANDARD.encode(&bytes); + HttpResponse::Ok().json(TtsSpeechResponse { + audio_base64, + format: format.to_string(), + }) + } + Err(e) => { + log::error!("TTS synth failed: {:?}", e); + HttpResponse::BadGateway().json(json!({ "error": format!("TTS failed: {e}") })) + } + } +} + +/// GET /tts/voices — list the Chatterbox voice library (raw passthrough). +#[get("/tts/voices")] +pub async fn list_voices_handler( + _claims: Claims, + app_state: web::Data, +) -> impl Responder { + let Some(client) = app_state.llamacpp.as_ref() else { + return HttpResponse::ServiceUnavailable() + .json(json!({ "error": "TTS backend not configured" })); + }; + match client.list_voices().await { + Ok(v) => HttpResponse::Ok().json(v), + Err(e) => { + log::error!("list_voices failed: {:?}", e); + HttpResponse::BadGateway().json(json!({ "error": format!("{e}") })) + } + } +} + +/// POST /tts/voices/upload — register a cloned voice from an uploaded audio +/// clip. Multipart fields: `voice_name` (text) + a file part (`voice_file`). 
/// Any part that carries a filename is treated as the clip; if several are
/// sent, the last one wins. A multipart stream that fails mid-read is
/// rejected with 400 instead of being forwarded as a truncated clip.
#[post("/tts/voices/upload")]
pub async fn create_voice_upload_handler(
    _claims: Claims,
    mut payload: Multipart,
    app_state: web::Data<AppState>,
) -> impl Responder {
    // The sanitized name is cut to 64 chars anyway; this cap only stops an
    // oversized text field from being buffered in full.
    const MAX_VOICE_NAME_BYTES: usize = 1024;

    let Some(client) = app_state.llamacpp.as_ref() else {
        return HttpResponse::ServiceUnavailable()
            .json(json!({ "error": "TTS backend not configured" }));
    };

    let mut voice_name: Option<String> = None;
    let mut file_bytes = BytesMut::new();
    let mut filename = "voice.wav".to_string();
    let mut mime = "application/octet-stream".to_string();

    while let Some(item) = payload.next().await {
        let mut part = match item {
            Ok(part) => part,
            Err(e) => return multipart_read_error(e),
        };

        // Capture disposition fields up front so the immutable borrow ends
        // before we mutably stream the part body (mirrors handlers/image.rs).
        let (fname_opt, name_opt) = {
            let cd = part.content_disposition();
            (
                cd.and_then(|c| c.get_filename()).map(|s| s.to_string()),
                cd.and_then(|c| c.get_name()).map(|s| s.to_string()),
            )
        };

        if let Some(fname) = fname_opt {
            // Last file part wins for name, type *and* bytes; appending a
            // second clip onto the first would produce a corrupt reference.
            filename = fname;
            mime = part
                .content_type()
                .map(|ct| ct.to_string())
                .unwrap_or_else(|| "application/octet-stream".to_string());
            file_bytes.clear();

            while let Some(chunk) = part.next().await {
                let data = match chunk {
                    Ok(data) => data,
                    Err(e) => return multipart_read_error(e),
                };
                if file_bytes.len() + data.len() > MAX_VOICE_UPLOAD_BYTES {
                    return HttpResponse::PayloadTooLarge()
                        .json(json!({ "error": "voice clip exceeds 25 MB" }));
                }
                file_bytes.put(data);
            }
        } else if name_opt.as_deref() == Some("voice_name") {
            let mut buf = BytesMut::new();
            while let Some(chunk) = part.next().await {
                let data = match chunk {
                    Ok(data) => data,
                    Err(e) => return multipart_read_error(e),
                };
                if buf.len() + data.len() > MAX_VOICE_NAME_BYTES {
                    return HttpResponse::BadRequest()
                        .json(json!({ "error": "voice_name is too long" }));
                }
                buf.put(data);
            }
            voice_name = Some(String::from_utf8_lossy(&buf).trim().to_string());
        } else {
            // Unknown field: drain it so the stream can advance.
            while let Some(chunk) = part.next().await {
                if let Err(e) = chunk {
                    return multipart_read_error(e);
                }
            }
        }
    }

    let Some(name) = voice_name.as_deref().and_then(sanitize_voice_name) else {
        return HttpResponse::BadRequest()
            .json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
    };
    if file_bytes.is_empty() {
        return HttpResponse::BadRequest().json(json!({ "error": "voice_file is required" }));
    }
    // Clients often send a generic content type; fall back to the extension.
    if !mime.starts_with("audio") {
        mime = guess_audio_mime(Path::new(&filename));
    }

    match client
        .create_voice(&name, file_bytes.to_vec(), &filename, &mime)
        .await
    {
        Ok(v) => HttpResponse::Ok().json(v),
        Err(e) => {
            log::error!("create_voice (upload) failed: {:?}", e);
            HttpResponse::BadGateway().json(json!({ "error": format!("{e}") }))
        }
    }
}

/// 400 response for a multipart stream that failed while being read.
fn multipart_read_error(err: impl std::fmt::Debug) -> HttpResponse {
    log::warn!("voice upload: multipart read failed: {:?}", err);
    HttpResponse::BadRequest().json(json!({ "error": "malformed or interrupted multipart upload" }))
}

/// JSON body of `POST /tts/voices/from-library`.
#[derive(Debug, Deserialize)]
pub struct CreateVoiceFromLibraryRequest {
    /// Name to register the voice under (sanitized before use).
    pub voice_name: String,
    /// Library-relative path to an audio or video file.
    pub path: String,
    /// Library selector; the primary library is used when absent.
    #[serde(default)]
    pub library: Option<String>,
}

/// POST /tts/voices/from-library — register a cloned voice from a file already
/// in a library. Audio files are forwarded as-is; video files have up to 30s
/// of their audio track extracted (mono, 24 kHz) via ffmpeg.
#[post("/tts/voices/from-library")]
pub async fn create_voice_from_library_handler(
    _claims: Claims,
    req: web::Json<CreateVoiceFromLibraryRequest>,
    app_state: web::Data<AppState>,
) -> impl Responder {
    let Some(client) = app_state.llamacpp.as_ref() else {
        return HttpResponse::ServiceUnavailable()
            .json(json!({ "error": "TTS backend not configured" }));
    };
    let Some(voice_name) = sanitize_voice_name(&req.voice_name) else {
        return HttpResponse::BadRequest()
            .json(json!({ "error": "voice_name is required (alphanumerics, - and _ only)" }));
    };

    let library = match libraries::resolve_library_param(&app_state, req.library.as_deref()) {
        Ok(Some(l)) => l,
        Ok(None) => app_state.primary_library(),
        Err(msg) => return HttpResponse::BadRequest().json(json!({ "error": msg })),
    };

    // is_valid_full_path confines the path to the library root (no traversal).
    let abs = match is_valid_full_path(&library.root_path, &req.path, false) {
        Some(p) if p.exists() => p,
        _ => {
            return HttpResponse::NotFound().json(json!({ "error": "file not found in library" }));
        }
    };

    // Only real audio/video sources are valid voice references — refuse to
    // slurp arbitrary library files into memory / ffmpeg.
    if !is_audio_file(&abs) && !is_video_file(&abs) {
        return HttpResponse::BadRequest()
            .json(json!({ "error": "file is not an audio or video file" }));
    }

    let (bytes, filename, mime) = match prepare_reference_audio(&abs).await {
        Ok(t) => t,
        Err(e) => {
            log::error!("voice reference prep failed for {:?}: {:?}", abs, e);
            return HttpResponse::BadRequest().json(json!({ "error": format!("{e}") }));
        }
    };

    match client
        .create_voice(&voice_name, bytes, &filename, &mime)
        .await
    {
        Ok(v) => HttpResponse::Ok().json(v),
        Err(e) => {
            log::error!("create_voice (from-library) failed: {:?}", e);
            HttpResponse::BadGateway().json(json!({ "error": format!("{e}") }))
        }
    }
}

/// Read a library file as reference audio. Audio is returned verbatim; video
/// has up to 30s of audio extracted to mono 24 kHz WAV via ffmpeg.
///
/// Returns `(bytes, filename, mime)` for the multipart upload.
///
/// # Errors
///
/// Fails when the temp file cannot be created, ffmpeg cannot be spawned or
/// exits unsuccessfully (its stderr is included), or a file cannot be read.
async fn prepare_reference_audio(abs: &Path) -> anyhow::Result<(Vec<u8>, String, String)> {
    if is_video_file(abs) {
        // `tmp` must stay alive until the extracted audio has been read back;
        // dropping it removes the file.
        let tmp = tempfile::Builder::new()
            .suffix(".wav")
            .tempfile()
            .context("creating temp wav")?;
        let out = tmp.path().to_path_buf();

        // Paths are passed as OS strings so non-UTF-8 file names reach ffmpeg
        // intact instead of being mangled by a lossy conversion.
        let output = tokio::process::Command::new("ffmpeg")
            .arg("-y")
            .arg("-i")
            .arg(abs)
            .args(["-vn", "-ac", "1", "-ar", "24000", "-t", "30", "-f", "wav"])
            .arg(&out)
            .output()
            .await
            .context("spawning ffmpeg")?;

        if !output.status.success() {
            anyhow::bail!(
                "ffmpeg audio extraction failed: {}",
                String::from_utf8_lossy(&output.stderr)
            );
        }
        let bytes = read_file_off_runtime(out)
            .await
            .context("reading extracted audio")?;
        Ok((bytes, "reference.wav".to_string(), "audio/wav".to_string()))
    } else {
        let bytes = read_file_off_runtime(abs.to_path_buf())
            .await
            .context("reading audio file")?;
        let filename = abs
            .file_name()
            .and_then(|f| f.to_str())
            .unwrap_or("reference")
            .to_string();
        Ok((bytes, filename, guess_audio_mime(abs)))
    }
}

/// Read a whole file on the blocking thread pool so the async worker thread
/// is not stalled by disk I/O.
async fn read_file_off_runtime(path: std::path::PathBuf) -> anyhow::Result<Vec<u8>> {
    let bytes = tokio::task::spawn_blocking(move || std::fs::read(path))
        .await
        .context("blocking file-read task failed")??;
    Ok(bytes)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn
sanitize_voice_name_keeps_safe_chars() { + assert_eq!(sanitize_voice_name("m").as_deref(), Some("m")); + assert_eq!( + sanitize_voice_name(" Cameron ").as_deref(), + Some("Cameron") + ); + assert_eq!( + sanitize_voice_name("voice_01-a").as_deref(), + Some("voice_01-a") + ); + } + + #[test] + fn sanitize_voice_name_strips_unsafe_chars() { + // Path separators / dots / spaces become '-' and are trimmed at edges. + assert_eq!(sanitize_voice_name("a b.c").as_deref(), Some("a-b-c")); + assert_eq!( + sanitize_voice_name("../etc/passwd").as_deref(), + Some("etc-passwd") + ); + } + + #[test] + fn sanitize_voice_name_rejects_empty_or_all_unsafe() { + assert_eq!(sanitize_voice_name(""), None); + assert_eq!(sanitize_voice_name(" "), None); + assert_eq!(sanitize_voice_name("../../"), None); + assert_eq!(sanitize_voice_name("...."), None); + } + + #[test] + fn sanitize_voice_name_bounds_length() { + let long = "a".repeat(200); + assert_eq!(sanitize_voice_name(&long).unwrap().len(), 64); + } + + #[test] + fn guess_audio_mime_maps_known_extensions() { + assert_eq!(guess_audio_mime(Path::new("clip.wav")), "audio/wav"); + assert_eq!(guess_audio_mime(Path::new("clip.MP3")), "audio/mpeg"); + assert_eq!(guess_audio_mime(Path::new("clip.m4a")), "audio/mp4"); + assert_eq!(guess_audio_mime(Path::new("clip.flac")), "audio/flac"); + assert_eq!( + guess_audio_mime(Path::new("clip.xyz")), + "application/octet-stream" + ); + } +} diff --git a/src/file_types.rs b/src/file_types.rs index 33f71dd..b834cba 100644 --- a/src/file_types.rs +++ b/src/file_types.rs @@ -22,6 +22,10 @@ pub fn needs_ffmpeg_thumbnail(path: &Path) -> bool { /// Supported video file extensions pub const VIDEO_EXTENSIONS: &[&str] = &["mp4", "mov", "avi", "mkv"]; +/// Audio file extensions accepted as voice-clone references (TTS). Mirrors +/// the formats Chatterbox can decode (wav/mp3/flac/m4a/aac/ogg). 
+pub const AUDIO_EXTENSIONS: &[&str] = &["wav", "mp3", "flac", "m4a", "aac", "ogg", "oga", "opus"]; + /// Filenames that are filesystem metadata, not real media — exact /// basename match. Extend if a new platform sidecar appears (Windows /// Thumbs.db / desktop.ini live here too if those libraries land). @@ -75,6 +79,19 @@ pub fn is_video_file(path: &Path) -> bool { } } +/// Check if a path has an audio extension (voice-clone references) +pub fn is_audio_file(path: &Path) -> bool { + if is_filesystem_metadata(path) { + return false; + } + if let Some(ext) = path.extension().and_then(|e| e.to_str()) { + let ext_lower = ext.to_lowercase(); + AUDIO_EXTENSIONS.contains(&ext_lower.as_str()) + } else { + false + } +} + /// Check if a path has a supported media extension (image or video) pub fn is_media_file(path: &Path) -> bool { is_image_file(path) || is_video_file(path) diff --git a/src/main.rs b/src/main.rs index 4099a5d..8b06228 100644 --- a/src/main.rs +++ b/src/main.rs @@ -362,6 +362,10 @@ fn main() -> std::io::Result<()> { .service(ai::cancel_turn_handler) .service(ai::rate_insight_handler) .service(ai::export_training_data_handler) + .service(ai::tts_speech_handler) + .service(ai::list_voices_handler) + .service(ai::create_voice_upload_handler) + .service(ai::create_voice_from_library_handler) .service(libraries::list_libraries) .service(libraries::patch_library) .add_feature(add_tag_services::<_, SqliteTagDao>) diff --git a/src/state.rs b/src/state.rs index f9adda7..ef071a8 100644 --- a/src/state.rs +++ b/src/state.rs @@ -391,6 +391,9 @@ fn build_llamacpp_from_env() -> Option> { if let Ok(model) = env::var("LLAMA_SWAP_VISION_MODEL") { client.set_vision_model(model); } + if let Ok(model) = env::var("LLAMA_SWAP_TTS_MODEL") { + client.set_tts_model(model); + } Some(Arc::new(client)) }