diff --git a/src/ai/mod.rs b/src/ai/mod.rs index f60e553..5414e69 100644 --- a/src/ai/mod.rs +++ b/src/ai/mod.rs @@ -3,6 +3,7 @@ pub mod handlers; pub mod insight_generator; pub mod llm_client; pub mod ollama; +pub mod openrouter; pub mod sms_client; // strip_summary_boilerplate is used by binaries (test_daily_summary), not the library diff --git a/src/ai/openrouter.rs b/src/ai/openrouter.rs new file mode 100644 index 0000000..2c46852 --- /dev/null +++ b/src/ai/openrouter.rs @@ -0,0 +1,727 @@ +// First consumer lands in a later PR (hybrid backend routing). Tests exercise +// the translation helpers directly. +#![allow(dead_code)] + +use anyhow::{Context, Result, anyhow, bail}; +use async_trait::async_trait; +use reqwest::Client; +use serde::Deserialize; +use serde_json::{Value, json}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use crate::ai::llm_client::{ + ChatMessage, LlmClient, ModelCapabilities, Tool, ToolCall, ToolCallFunction, +}; + +const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1"; +const DEFAULT_EMBEDDING_MODEL: &str = "openai/text-embedding-3-small"; +const CACHE_DURATION_SECS: u64 = 15 * 60; + +#[derive(Clone)] +struct CachedEntry { + data: T, + cached_at: Instant, +} + +impl CachedEntry { + fn new(data: T) -> Self { + Self { + data, + cached_at: Instant::now(), + } + } + + fn is_expired(&self) -> bool { + self.cached_at.elapsed().as_secs() > CACHE_DURATION_SECS + } +} + +lazy_static::lazy_static! { + static ref MODEL_CAPABILITIES_CACHE: Arc>>>> = + Arc::new(Mutex::new(HashMap::new())); +} + +/// OpenAI-compatible client for OpenRouter (https://openrouter.ai). +/// +/// Translates canonical `ChatMessage` / `Tool` shapes to OpenAI wire format: +/// - Tool-call `arguments` serialized as JSON-encoded strings (vs Ollama's +/// native JSON). +/// - Image content rewritten into content-parts array with `image_url` entries. +/// - `role=tool` messages attach a `tool_call_id` inferred from the preceding +/// assistant turn's tool call. +#[derive(Clone)] +pub struct OpenRouterClient { + client: Client, + pub api_key: String, + pub base_url: String, + pub primary_model: String, + pub embedding_model: String, + num_ctx: Option, + temperature: Option, + top_p: Option, + top_k: Option, + min_p: Option, + /// Optional `HTTP-Referer` header OpenRouter uses for attribution. + pub referer: Option, + /// Optional `X-Title` header OpenRouter uses for attribution. + pub app_title: Option, +} + +impl OpenRouterClient { + pub fn new(api_key: String, base_url: Option, primary_model: String) -> Self { + Self { + client: Client::builder() + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(180)) + .build() + .unwrap_or_else(|_| Client::new()), + api_key, + base_url: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string()), + primary_model, + embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(), + num_ctx: None, + temperature: None, + top_p: None, + top_k: None, + min_p: None, + referer: None, + app_title: None, + } + } + + pub fn set_embedding_model(&mut self, model: String) { + self.embedding_model = model; + } + + #[allow(dead_code)] + pub fn set_num_ctx(&mut self, num_ctx: Option) { + self.num_ctx = num_ctx; + } + + #[allow(dead_code)] + pub fn set_sampling_params( + &mut self, + temperature: Option, + top_p: Option, + top_k: Option, + min_p: Option, + ) { + self.temperature = temperature; + self.top_p = top_p; + self.top_k = top_k; + self.min_p = min_p; + } + + pub fn set_attribution(&mut self, referer: Option, app_title: Option) { + self.referer = referer; + self.app_title = app_title; + } + + fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + let mut b = builder.bearer_auth(&self.api_key); + if let Some(r) = &self.referer { + b = b.header("HTTP-Referer", r); + } + if let Some(t) = &self.app_title { + b = b.header("X-Title", t); + } + b + } + + /// Translate canonical messages to the OpenAI-compatible wire shape. + /// + /// Walks in order so it can attach `tool_call_id` to `role=tool` messages + /// based on the most recent assistant turn's tool call. + fn messages_to_openai(messages: &[ChatMessage]) -> Vec { + let mut out = Vec::with_capacity(messages.len()); + let mut last_tool_call_ids: Vec = Vec::new(); + let mut next_tool_result_idx: usize = 0; + + for msg in messages { + let mut obj = serde_json::Map::new(); + obj.insert("role".into(), Value::String(msg.role.clone())); + + // Content: string OR content-parts array (when images present). + match &msg.images { + Some(images) if !images.is_empty() => { + let mut parts: Vec = Vec::new(); + if !msg.content.is_empty() { + parts.push(json!({"type": "text", "text": msg.content})); + } + for img in images { + let url = image_to_data_url(img); + parts.push(json!({ + "type": "image_url", + "image_url": { "url": url } + })); + } + obj.insert("content".into(), Value::Array(parts)); + } + _ => { + obj.insert("content".into(), Value::String(msg.content.clone())); + } + } + + // Assistant message with tool_calls: stringify arguments, remember + // the ids so the subsequent tool messages can reference them. + if let Some(tcs) = &msg.tool_calls + && msg.role == "assistant" + { + let converted: Vec = tcs + .iter() + .enumerate() + .map(|(i, call)| { + let id = call.id.clone().unwrap_or_else(|| format!("call_{}", i)); + let args_str = serde_json::to_string(&call.function.arguments) + .unwrap_or_else(|_| "{}".to_string()); + json!({ + "id": id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": args_str, + } + }) + }) + .collect(); + last_tool_call_ids = converted + .iter() + .filter_map(|v| v.get("id").and_then(|x| x.as_str()).map(String::from)) + .collect(); + next_tool_result_idx = 0; + obj.insert("tool_calls".into(), Value::Array(converted)); + } + + // Tool result messages: attach tool_call_id from the last assistant turn. + if msg.role == "tool" { + let id = last_tool_call_ids + .get(next_tool_result_idx) + .cloned() + .unwrap_or_else(|| "call_0".to_string()); + obj.insert("tool_call_id".into(), Value::String(id)); + next_tool_result_idx += 1; + } + + out.push(Value::Object(obj)); + } + + out + } + + /// Parse an OpenAI-compatible assistant message back into canonical shape. + fn openai_message_to_chat(msg: &Value) -> Result { + let obj = msg + .as_object() + .ok_or_else(|| anyhow!("response message is not an object"))?; + let role = obj + .get("role") + .and_then(|v| v.as_str()) + .unwrap_or("assistant") + .to_string(); + let content = obj + .get("content") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let tool_calls = if let Some(tcs) = obj.get("tool_calls").and_then(|v| v.as_array()) { + let mut parsed = Vec::with_capacity(tcs.len()); + for tc in tcs { + let id = tc.get("id").and_then(|v| v.as_str()).map(String::from); + let function = tc + .get("function") + .ok_or_else(|| anyhow!("tool_call missing function field"))?; + let name = function + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let args_value = match function.get("arguments") { + // OpenAI-compat: stringified JSON. + Some(Value::String(s)) => { + serde_json::from_str::(s).unwrap_or_else(|_| json!({})) + } + // Some providers emit arguments as an object directly — accept both. + Some(v @ Value::Object(_)) => v.clone(), + _ => json!({}), + }; + parsed.push(ToolCall { + id, + function: ToolCallFunction { + name, + arguments: args_value, + }, + }); + } + Some(parsed) + } else { + None + }; + + Ok(ChatMessage { + role, + content, + tool_calls, + images: None, + }) + } + + fn build_options(&self) -> Vec<(&'static str, Value)> { + let mut v = Vec::new(); + if let Some(t) = self.temperature { + v.push(("temperature", json!(t))); + } + if let Some(p) = self.top_p { + v.push(("top_p", json!(p))); + } + if let Some(k) = self.top_k { + v.push(("top_k", json!(k))); + } + if let Some(m) = self.min_p { + v.push(("min_p", json!(m))); + } + if let Some(c) = self.num_ctx { + // OpenAI uses max_tokens for generation bound; num_ctx isn't + // directly transferable. Skip rather than silently mis-map. + let _ = c; + } + v + } +} + +#[async_trait] +impl LlmClient for OpenRouterClient { + async fn generate( + &self, + prompt: &str, + system: Option<&str>, + images: Option>, + ) -> Result { + let mut messages: Vec = Vec::new(); + if let Some(sys) = system { + messages.push(ChatMessage::system(sys)); + } + let mut user = ChatMessage::user(prompt); + user.images = images; + messages.push(user); + + let (reply, _, _) = self.chat_with_tools(messages, Vec::new()).await?; + Ok(reply.content) + } + + async fn chat_with_tools( + &self, + messages: Vec, + tools: Vec, + ) -> Result<(ChatMessage, Option, Option)> { + let url = format!("{}/chat/completions", self.base_url); + let mut body = serde_json::Map::new(); + body.insert("model".into(), Value::String(self.primary_model.clone())); + body.insert( + "messages".into(), + Value::Array(Self::messages_to_openai(&messages)), + ); + body.insert("stream".into(), Value::Bool(false)); + if !tools.is_empty() { + body.insert( + "tools".into(), + serde_json::to_value(&tools).context("serializing tools")?, + ); + } + for (k, v) in self.build_options() { + body.insert(k.into(), v); + } + + log::info!( + "OpenRouter chat_with_tools: model={} messages={} tools={}", + self.primary_model, + messages.len(), + tools.len() + ); + + let resp = self + .authed(self.client.post(&url)) + .json(&Value::Object(body)) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter chat request failed: {} — {}", status, body); + } + + let parsed: Value = resp.json().await.context("parsing chat response")?; + let choice = parsed + .get("choices") + .and_then(|v| v.as_array()) + .and_then(|a| a.first()) + .ok_or_else(|| anyhow!("response missing choices[0]"))?; + let msg = choice + .get("message") + .ok_or_else(|| anyhow!("choices[0] missing message"))?; + let chat_msg = Self::openai_message_to_chat(msg)?; + + let usage = parsed.get("usage"); + let prompt_tokens = usage + .and_then(|u| u.get("prompt_tokens")) + .and_then(|v| v.as_i64()) + .map(|n| n as i32); + let completion_tokens = usage + .and_then(|u| u.get("completion_tokens")) + .and_then(|v| v.as_i64()) + .map(|n| n as i32); + + Ok((chat_msg, prompt_tokens, completion_tokens)) + } + + async fn generate_embeddings(&self, texts: &[&str]) -> Result>> { + let url = format!("{}/embeddings", self.base_url); + let body = json!({ + "model": self.embedding_model, + "input": texts, + }); + + let resp = self + .authed(self.client.post(&url)) + .json(&body) + .send() + .await + .with_context(|| format!("POST {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter embedding request failed: {} — {}", status, body); + } + + #[derive(Deserialize)] + struct EmbedResponse { + data: Vec, + } + #[derive(Deserialize)] + struct EmbedItem { + embedding: Vec, + } + + let parsed: EmbedResponse = resp.json().await.context("parsing embed response")?; + Ok(parsed.data.into_iter().map(|i| i.embedding).collect()) + } + + async fn describe_image(&self, image_base64: &str) -> Result { + let prompt = "Briefly describe what you see in this image in 1-2 sentences. \ + Focus on the people, location, and activity."; + self.generate( + prompt, + Some("You are a scene description assistant. Be concise and factual."), + Some(vec![image_base64.to_string()]), + ) + .await + } + + async fn list_models(&self) -> Result> { + { + let cache = MODEL_CAPABILITIES_CACHE.lock().unwrap(); + if let Some(entry) = cache.get(&self.base_url) + && !entry.is_expired() + { + return Ok(entry.data.clone()); + } + } + + let url = format!("{}/models", self.base_url); + let resp = self + .authed(self.client.get(&url)) + .send() + .await + .with_context(|| format!("GET {} failed", url))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("OpenRouter list_models failed: {} — {}", status, body); + } + + let parsed: Value = resp.json().await.context("parsing models response")?; + let data = parsed + .get("data") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow!("models response missing data[]"))?; + + let caps: Vec = data.iter().map(parse_model_capabilities).collect(); + + { + let mut cache = MODEL_CAPABILITIES_CACHE.lock().unwrap(); + cache.insert(self.base_url.clone(), CachedEntry::new(caps.clone())); + } + + Ok(caps) + } + + async fn model_capabilities(&self, model: &str) -> Result { + let all = self.list_models().await?; + all.into_iter() + .find(|m| m.name == model) + .ok_or_else(|| anyhow!("model '{}' not found on OpenRouter", model)) + } + + fn primary_model(&self) -> &str { + &self.primary_model + } +} + +/// Build a `data:` URL if the provided string is raw base64, otherwise pass it through. +fn image_to_data_url(img: &str) -> String { + if img.starts_with("data:") { + img.to_string() + } else { + format!("data:image/jpeg;base64,{}", img) + } +} + +fn parse_model_capabilities(m: &Value) -> ModelCapabilities { + let name = m + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let has_tool_calling = m + .get("supported_parameters") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().any(|x| x.as_str() == Some("tools"))) + .unwrap_or(false); + let has_vision = m + .get("architecture") + .and_then(|v| v.get("input_modalities")) + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().any(|x| x.as_str() == Some("image"))) + .unwrap_or(false); + ModelCapabilities { + name, + has_vision, + has_tool_calling, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tool_call_arguments_stringified_on_send() { + let mut msg = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: Some("call_abc".into()), + function: ToolCallFunction { + name: "search_sms".into(), + arguments: json!({"query": "hello", "limit": 5}), + }, + }]), + images: None, + }; + msg.tool_calls.as_mut().unwrap()[0].function.arguments = + json!({"query": "hello", "limit": 5}); + + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let tcs = wire[0] + .get("tool_calls") + .and_then(|v| v.as_array()) + .expect("tool_calls present"); + let args = tcs[0] + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()) + .expect("arguments stringified"); + let parsed: Value = serde_json::from_str(args).unwrap(); + assert_eq!(parsed["query"], "hello"); + assert_eq!(parsed["limit"], 5); + } + + #[test] + fn tool_call_arguments_parsed_on_receive() { + let response_msg = json!({ + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_xyz", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"city\":\"Boston\",\"units\":\"celsius\"}" + } + }] + }); + + let parsed = OpenRouterClient::openai_message_to_chat(&response_msg).unwrap(); + let tcs = parsed.tool_calls.unwrap(); + assert_eq!(tcs.len(), 1); + assert_eq!(tcs[0].function.name, "get_weather"); + assert_eq!(tcs[0].function.arguments["city"], "Boston"); + assert_eq!(tcs[0].function.arguments["units"], "celsius"); + assert_eq!(tcs[0].id.as_deref(), Some("call_xyz")); + } + + #[test] + fn tool_call_arguments_accept_native_json_on_receive() { + // Some providers return arguments as an object directly; accept both. + let response_msg = json!({ + "role": "assistant", + "content": "", + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": { + "name": "foo", + "arguments": {"nested": {"k": 1}} + } + }] + }); + let parsed = OpenRouterClient::openai_message_to_chat(&response_msg).unwrap(); + let tc = &parsed.tool_calls.unwrap()[0]; + assert_eq!(tc.function.arguments["nested"]["k"], 1); + } + + #[test] + fn images_become_content_parts() { + let mut msg = ChatMessage::user("What is in this photo?"); + msg.images = Some(vec!["BASE64DATA".into()]); + + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let content = wire[0].get("content").and_then(|v| v.as_array()).unwrap(); + assert_eq!(content.len(), 2); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "What is in this photo?"); + assert_eq!(content[1]["type"], "image_url"); + assert_eq!( + content[1]["image_url"]["url"], + "data:image/jpeg;base64,BASE64DATA" + ); + } + + #[test] + fn data_url_images_pass_through_unchanged() { + let mut msg = ChatMessage::user(""); + msg.images = Some(vec!["data:image/png;base64,ABCDEF".into()]); + let wire = OpenRouterClient::messages_to_openai(&[msg]); + let content = wire[0].get("content").and_then(|v| v.as_array()).unwrap(); + // No text part when content is empty. + assert_eq!(content.len(), 1); + assert_eq!( + content[0]["image_url"]["url"], + "data:image/png;base64,ABCDEF" + ); + } + + #[test] + fn text_only_message_stays_string() { + let msg = ChatMessage::user("hello"); + let wire = OpenRouterClient::messages_to_openai(&[msg]); + assert_eq!(wire[0]["content"], "hello"); + assert!(wire[0]["content"].as_str().is_some()); + } + + #[test] + fn tool_result_inherits_tool_call_id_from_prior_assistant() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: Some("call_42".into()), + function: ToolCallFunction { + name: "lookup".into(), + arguments: json!({}), + }, + }]), + images: None, + }; + let tool_result = ChatMessage::tool_result("found it"); + + let wire = OpenRouterClient::messages_to_openai(&[assistant, tool_result]); + assert_eq!(wire[1]["role"], "tool"); + assert_eq!(wire[1]["tool_call_id"], "call_42"); + } + + #[test] + fn multiple_tool_results_map_to_sequential_call_ids() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ + ToolCall { + id: Some("call_A".into()), + function: ToolCallFunction { + name: "a".into(), + arguments: json!({}), + }, + }, + ToolCall { + id: Some("call_B".into()), + function: ToolCallFunction { + name: "b".into(), + arguments: json!({}), + }, + }, + ]), + images: None, + }; + let r1 = ChatMessage::tool_result("a result"); + let r2 = ChatMessage::tool_result("b result"); + + let wire = OpenRouterClient::messages_to_openai(&[assistant, r1, r2]); + assert_eq!(wire[1]["tool_call_id"], "call_A"); + assert_eq!(wire[2]["tool_call_id"], "call_B"); + } + + #[test] + fn missing_tool_call_id_gets_synthetic_fallback() { + let assistant = ChatMessage { + role: "assistant".into(), + content: String::new(), + tool_calls: Some(vec![ToolCall { + id: None, + function: ToolCallFunction { + name: "noid".into(), + arguments: json!({}), + }, + }]), + images: None, + }; + let wire = OpenRouterClient::messages_to_openai(&[assistant]); + let tcs = wire[0] + .get("tool_calls") + .and_then(|v| v.as_array()) + .unwrap(); + assert_eq!(tcs[0]["id"], "call_0"); + } + + #[test] + fn parse_model_capabilities_extracts_tools_and_vision() { + let m = json!({ + "id": "anthropic/claude-sonnet-4", + "supported_parameters": ["temperature", "top_p", "tools", "max_tokens"], + "architecture": { + "input_modalities": ["text", "image"] + } + }); + let caps = parse_model_capabilities(&m); + assert_eq!(caps.name, "anthropic/claude-sonnet-4"); + assert!(caps.has_tool_calling); + assert!(caps.has_vision); + } + + #[test] + fn parse_model_capabilities_handles_missing_fields() { + let m = json!({ + "id": "some/text-only-model" + }); + let caps = parse_model_capabilities(&m); + assert_eq!(caps.name, "some/text-only-model"); + assert!(!caps.has_tool_calling); + assert!(!caps.has_vision); + } +}