diff --git a/src/ai/insight_generator.rs b/src/ai/insight_generator.rs index 09fd585..a1edd0c 100644 --- a/src/ai/insight_generator.rs +++ b/src/ai/insight_generator.rs @@ -1421,7 +1421,8 @@ Return ONLY the summary, nothing else."#, cx: &opentelemetry::Context, ) -> String { let result = match tool_name { - "search_rag" => self.tool_search_rag(arguments, cx).await, + "search_rag" => self.tool_search_rag(arguments, ollama, cx).await, + "search_messages" => self.tool_search_messages(arguments).await, "get_sms_messages" => self.tool_get_sms_messages(arguments, cx).await, "get_calendar_events" => self.tool_get_calendar_events(arguments, cx).await, "get_location_history" => self.tool_get_location_history(arguments, cx).await, @@ -1447,6 +1448,7 @@ Return ONLY the summary, nothing else."#, async fn tool_search_rag( &self, args: &serde_json::Value, + ollama: &OllamaClient, _cx: &opentelemetry::Context, ) -> String { let query = match args.get("query").and_then(|v| v.as_str()) { @@ -1479,13 +1481,204 @@ Return ONLY the summary, nothing else."#, limit ); - match self - .find_relevant_messages_rag(date, None, contact.as_deref(), None, limit, Some(&query)) + // Pull a wider candidate pool than the final limit so the LLM + // reranker has room to promote less-obvious hits. Candidates_factor + // is capped so a big `limit` doesn't blow past what the reranker + // can sensibly judge in one prompt. 
+ let rerank_enabled = std::env::var("SEARCH_RAG_RERANK") + .ok() + .map(|v| v.to_lowercase() != "off" && v != "0") + .unwrap_or(true); + let candidate_limit = if rerank_enabled { + (limit * 3).min(40) + } else { + limit + }; + + let results = match self + .find_relevant_messages_rag( + date, + None, + contact.as_deref(), + None, + candidate_limit, + Some(&query), + ) .await { - Ok(results) if !results.is_empty() => results.join("\n\n"), - Ok(_) => "No relevant messages found.".to_string(), - Err(e) => format!("Error searching RAG: {}", e), + Ok(results) if !results.is_empty() => results, + Ok(_) => return "No relevant messages found.".to_string(), + Err(e) => return format!("Error searching RAG: {}", e), + }; + + let final_results = if rerank_enabled && results.len() > limit { + match self.rerank_with_llm(&query, &results, limit, ollama).await { + Ok(reordered) => reordered, + Err(e) => { + log::warn!("rerank failed, using vector order: {}", e); + results.into_iter().take(limit).collect() + } + } + } else { + results.into_iter().take(limit).collect::<Vec<_>>() + }; + + final_results.join("\n\n") + } + + /// LLM-based reranker: ask the local model to pick the top-`limit` + /// passages from `candidates` that are most relevant to `query`. + /// Returns the reordered subset. + /// + /// Cheap-ish because the reranker prompt and output live outside the + /// agent's visible context — only the final selection lands in the + /// tool_result. On parse failure we fall back to the input order. + async fn rerank_with_llm( + &self, + query: &str, + candidates: &[String], + limit: usize, + ollama: &OllamaClient, + ) -> Result<Vec<String>> { + // Build numbered list (1-based for readability). Cap each passage + // at ~1000 chars so very long summaries don't eat the prompt. 
+ let numbered: String = candidates + .iter() + .enumerate() + .map(|(i, c)| { + let trimmed = if c.chars().count() > 1000 { + format!("{}…", c.chars().take(1000).collect::<String>()) + } else { + c.clone() + }; + format!("[{}] {}", i + 1, trimmed) + }) + .collect::<Vec<_>>() + .join("\n\n"); + + let prompt = format!( + "You are ranking search results. From the numbered passages below, \ + select the {} most relevant to the query. Respond with ONLY a \ + comma-separated list of passage numbers in order from most to \ + least relevant. No explanation, no other text.\n\n\ + Query: {}\n\n\ + Passages:\n{}\n\n\ + Top {} passage numbers:", + limit, query, numbered, limit + ); + + let response = ollama + .generate( + &prompt, + Some( + "You are a terse relevance ranker. You output only numbers separated by commas.", + ), + ) + .await?; + + // Extract indices from the response. Accept "3, 1, 7" and also + // tolerate "[3, 1, 7]" or "3,1,7,..." with trailing junk. + let picks: Vec<usize> = response + .split(|c: char| !c.is_ascii_digit()) + .filter_map(|s| s.parse::<usize>().ok()) + .filter(|&n| n >= 1 && n <= candidates.len()) + .collect(); + + if picks.is_empty() { + return Err(anyhow::anyhow!( + "reranker returned no usable indices (raw: {})", + response.chars().take(120).collect::<String>() + )); + } + + let mut seen = std::collections::HashSet::new(); + let mut reordered: Vec<String> = Vec::with_capacity(limit); + for n in picks { + if seen.insert(n) { + reordered.push(candidates[n - 1].clone()); + if reordered.len() >= limit { + break; + } + } + } + // Top-up from original order if the reranker returned fewer than + // `limit` distinct entries. + if reordered.len() < limit { + for (i, c) in candidates.iter().enumerate() { + if !seen.contains(&(i + 1)) { + reordered.push(c.clone()); + if reordered.len() >= limit { + break; + } + } + } + } + Ok(reordered) + } + + /// Tool: search_messages — keyword / semantic / hybrid search over all + /// SMS message bodies via the Django FTS5 + embeddings pipeline. 
Unlike + /// `search_rag` (daily summaries, date-weighted) this hits raw message + /// text across time and is the right choice for exact phrases, proper + /// nouns, URLs, or anything where specific wording matters. + async fn tool_search_messages(&self, args: &serde_json::Value) -> String { + let query = match args.get("query").and_then(|v| v.as_str()) { + Some(q) if !q.trim().is_empty() => q.trim(), + _ => return "Error: missing required parameter 'query'".to_string(), + }; + if query.len() < 3 { + return "Error: query must be at least 3 characters".to_string(); + } + let mode = args + .get("mode") + .and_then(|v| v.as_str()) + .map(|s| s.to_lowercase()) + .unwrap_or_else(|| "hybrid".to_string()); + if !matches!(mode.as_str(), "fts5" | "semantic" | "hybrid") { + return format!( + "Error: unknown mode '{}'; expected one of: fts5, semantic, hybrid", + mode + ); + } + let limit = args + .get("limit") + .and_then(|v| v.as_i64()) + .unwrap_or(20) + .clamp(1, 50) as usize; + + log::info!( + "tool_search_messages: query='{}', mode={}, limit={}", + query, + mode, + limit + ); + + match self.sms_client.search_messages(query, &mode, limit).await { + Ok(hits) if hits.is_empty() => "No messages matched.".to_string(), + Ok(hits) => { + let mut out = String::new(); + out.push_str(&format!( + "Found {} messages (mode: {}):\n\n", + hits.len(), + mode + )); + for h in hits { + let date = chrono::DateTime::from_timestamp(h.date, 0) + .map(|dt| dt.format("%Y-%m-%d").to_string()) + .unwrap_or_else(|| h.date.to_string()); + let direction = if h.type_ == 2 { "Me" } else { &h.contact_name }; + let score = h + .similarity_score + .map(|s| format!(" [score {:.2}]", s)) + .unwrap_or_default(); + out.push_str(&format!( + "[{}]{} {} — {}\n\n", + date, score, direction, h.body + )); + } + out + } + Err(e) => format!("Error searching messages: {}", e), } } @@ -2164,6 +2357,29 @@ Return ONLY the summary, nothing else."#, } }), ), + Tool::function( + "search_messages", + 
"Keyword/semantic/hybrid search over ALL SMS message bodies (not just summaries) across all time. Prefer this for specific phrases, proper nouns, URLs, or when you don't know the date. Modes: 'fts5' (keyword, supports \"phrase\" / prefix* / AND / NEAR(w1 w2, 5)), 'semantic' (embedding similarity), 'hybrid' (recommended — merges both via reciprocal rank fusion).", + serde_json::json!({ + "type": "object", + "required": ["query"], + "properties": { + "query": { + "type": "string", + "description": "Search query. Min 3 chars. For fts5 mode, supports phrase (\"\"), prefix (*), AND/OR/NOT, and NEAR proximity." + }, + "mode": { + "type": "string", + "enum": ["fts5", "semantic", "hybrid"], + "description": "Search strategy. Default: hybrid." + }, + "limit": { + "type": "integer", + "description": "Maximum number of results (default: 20, max: 50)" + } + } + }), + ), Tool::function( "get_sms_messages", "Fetch SMS/text messages near a specific date. Returns the actual message conversation. Omit contact to search across all conversations.", diff --git a/src/ai/sms_client.rs b/src/ai/sms_client.rs index 1b6b605..57d28a1 100644 --- a/src/ai/sms_client.rs +++ b/src/ai/sms_client.rs @@ -250,6 +250,45 @@ impl SmsApiClient { .collect()) } + /// Search message bodies via the Django side's FTS5 / semantic / hybrid + /// endpoint. 
`mode` selects the ranking strategy: + /// - "fts5" keyword-only, supports phrase / prefix / boolean / NEAR + /// - "semantic" embedding similarity + /// - "hybrid" both merged via reciprocal rank fusion (recommended) + pub async fn search_messages( + &self, + query: &str, + mode: &str, + limit: usize, + ) -> Result<Vec<SmsSearchHit>> { + let url = format!( + "{}/api/messages/search/?q={}&mode={}&limit={}", + self.base_url, + urlencoding::encode(query), + urlencoding::encode(mode), + limit + ); + + let mut request = self.client.get(&url); + if let Some(token) = &self.token { + request = request.header("Authorization", format!("Bearer {}", token)); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(anyhow::anyhow!( + "SMS search request failed: {} - {}", + status, + body + )); + } + + let data: SmsSearchResponse = response.json().await?; + Ok(data.results) + } + pub async fn summarize_context( + &self, + messages: &[SmsMessage], @@ -314,3 +353,28 @@ struct SmsApiMessage { #[serde(rename = "type")] type_: i32, } + +#[derive(Debug, Clone, Deserialize)] +pub struct SmsSearchHit { + #[allow(dead_code)] + pub message_id: i64, + pub contact_name: String, + #[allow(dead_code)] + pub contact_address: String, + pub body: String, + pub date: i64, + /// Message direction code: 1 = received, 2 = sent. + #[serde(rename = "type")] + pub type_: i32, + /// Present for semantic / hybrid modes; absent for fts5. + #[serde(default)] + pub similarity_score: Option<f64>, +} + +#[derive(Deserialize)] +struct SmsSearchResponse { + results: Vec<SmsSearchHit>, + #[allow(dead_code)] + #[serde(default)] + search_method: String, +}