feat(ai): search_messages tool + RAG reranker

Adds a search_messages tool that hits the Django FTS5/semantic/hybrid
endpoint for keyword-quality text search over message bodies, and an
LLM-based reranker inside tool_search_rag (gated by SEARCH_RAG_RERANK,
default on). Reranker pulls ~3x candidates from the vector index, asks
the chat model to rank by relevance, and falls back to vector order on
parse failure.

The reranker shares the active chat turn's OllamaClient so num_ctx and
sampling match — otherwise Ollama unloads/reloads the model on every
rerank call. (Unverified end-to-end; caught by inspection, awaiting
e2e confirmation.)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Cameron
2026-04-22 10:56:03 -04:00
parent e51cd564a3
commit e4a3536f87
2 changed files with 286 additions and 6 deletions

View File

@@ -1421,7 +1421,8 @@ Return ONLY the summary, nothing else."#,
cx: &opentelemetry::Context,
) -> String {
let result = match tool_name {
"search_rag" => self.tool_search_rag(arguments, cx).await,
"search_rag" => self.tool_search_rag(arguments, ollama, cx).await,
"search_messages" => self.tool_search_messages(arguments).await,
"get_sms_messages" => self.tool_get_sms_messages(arguments, cx).await,
"get_calendar_events" => self.tool_get_calendar_events(arguments, cx).await,
"get_location_history" => self.tool_get_location_history(arguments, cx).await,
@@ -1447,6 +1448,7 @@ Return ONLY the summary, nothing else."#,
async fn tool_search_rag(
&self,
args: &serde_json::Value,
ollama: &OllamaClient,
_cx: &opentelemetry::Context,
) -> String {
let query = match args.get("query").and_then(|v| v.as_str()) {
@@ -1479,13 +1481,204 @@ Return ONLY the summary, nothing else."#,
limit
);
match self
.find_relevant_messages_rag(date, None, contact.as_deref(), None, limit, Some(&query))
// Pull a wider candidate pool than the final limit so the LLM
// reranker has room to promote less-obvious hits. Candidates_factor
// is capped so a big `limit` doesn't blow past what the reranker
// can sensibly judge in one prompt.
let rerank_enabled = std::env::var("SEARCH_RAG_RERANK")
.ok()
.map(|v| v.to_lowercase() != "off" && v != "0")
.unwrap_or(true);
let candidate_limit = if rerank_enabled {
(limit * 3).min(40)
} else {
limit
};
let results = match self
.find_relevant_messages_rag(
date,
None,
contact.as_deref(),
None,
candidate_limit,
Some(&query),
)
.await
{
Ok(results) if !results.is_empty() => results.join("\n\n"),
Ok(_) => "No relevant messages found.".to_string(),
Err(e) => format!("Error searching RAG: {}", e),
Ok(results) if !results.is_empty() => results,
Ok(_) => return "No relevant messages found.".to_string(),
Err(e) => return format!("Error searching RAG: {}", e),
};
let final_results = if rerank_enabled && results.len() > limit {
match self.rerank_with_llm(&query, &results, limit, ollama).await {
Ok(reordered) => reordered,
Err(e) => {
log::warn!("rerank failed, using vector order: {}", e);
results.into_iter().take(limit).collect()
}
}
} else {
results.into_iter().take(limit).collect::<Vec<_>>()
};
final_results.join("\n\n")
}
/// LLM-based reranker: ask the local model to pick the top-`limit`
/// passages from `candidates` that are most relevant to `query`.
/// Returns the reordered subset.
///
/// Cheap-ish because the reranker prompt and output live outside the
/// agent's visible context — only the final selection lands in the
/// tool_result. On parse failure we return `Err` and the caller falls
/// back to the input (vector) order.
async fn rerank_with_llm(
    &self,
    query: &str,
    candidates: &[String],
    limit: usize,
    ollama: &OllamaClient,
) -> Result<Vec<String>> {
    // Build numbered list (1-based for readability). Cap each passage
    // at ~1000 bytes so very long summaries don't eat the prompt.
    // Bug fix: the previous `&c[..1000]` panics when byte 1000 falls
    // inside a multi-byte UTF-8 character (emoji / accented text is
    // common in message bodies) — walk back to the nearest char
    // boundary before slicing. Index 0 is always a boundary, so the
    // loop terminates.
    let numbered: String = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| {
            let mut end = c.len().min(1000);
            while !c.is_char_boundary(end) {
                end -= 1;
            }
            format!("[{}] {}", i + 1, &c[..end])
        })
        .collect::<Vec<_>>()
        .join("\n\n");
    let prompt = format!(
        "You are ranking search results. From the numbered passages below, \
         select the {} most relevant to the query. Respond with ONLY a \
         comma-separated list of passage numbers in order from most to \
         least relevant. No explanation, no other text.\n\n\
         Query: {}\n\n\
         Passages:\n{}\n\n\
         Top {} passage numbers:",
        limit, query, numbered, limit
    );
    let response = ollama
        .generate(
            &prompt,
            Some(
                "You are a terse relevance ranker. You output only numbers separated by commas.",
            ),
        )
        .await?;
    // Extract indices from the response. Accept "3, 1, 7" and also
    // tolerate "[3, 1, 7]" or "3,1,7,..." with trailing junk: split on
    // every non-digit and parse what's left, keeping only in-range picks.
    let picks: Vec<usize> = response
        .split(|c: char| !c.is_ascii_digit())
        .filter_map(|s| s.parse::<usize>().ok())
        .filter(|&n| n >= 1 && n <= candidates.len())
        .collect();
    if picks.is_empty() {
        return Err(anyhow::anyhow!(
            "reranker returned no usable indices (raw: {})",
            response.chars().take(120).collect::<String>()
        ));
    }
    // Dedup while preserving the model's ranking order.
    let mut seen = std::collections::HashSet::new();
    let mut reordered: Vec<String> = Vec::with_capacity(limit);
    for n in picks {
        if seen.insert(n) {
            reordered.push(candidates[n - 1].clone());
            if reordered.len() >= limit {
                break;
            }
        }
    }
    // Top-up from original (vector) order if the reranker returned fewer
    // than `limit` distinct entries.
    if reordered.len() < limit {
        for (i, c) in candidates.iter().enumerate() {
            if !seen.contains(&(i + 1)) {
                reordered.push(c.clone());
                if reordered.len() >= limit {
                    break;
                }
            }
        }
    }
    Ok(reordered)
}
/// Tool: search_messages — keyword / semantic / hybrid search over all
/// SMS message bodies via the Django FTS5 + embeddings pipeline. Unlike
/// `search_rag` (daily summaries, date-weighted) this hits raw message
/// text across time and is the right choice for exact phrases, proper
/// nouns, URLs, or anything where specific wording matters.
///
/// Returns a human-readable result block (or an "Error: ..." string on
/// bad arguments / transport failure) for the tool_result.
async fn tool_search_messages(&self, args: &serde_json::Value) -> String {
    // Required: non-empty query. Queries under 3 bytes are rejected —
    // too short to rank meaningfully in any mode.
    let query = match args.get("query").and_then(|v| v.as_str()) {
        Some(q) if !q.trim().is_empty() => q.trim(),
        _ => return "Error: missing required parameter 'query'".to_string(),
    };
    if query.len() < 3 {
        return "Error: query must be at least 3 characters".to_string();
    }
    // Optional: ranking strategy, defaulting to the recommended hybrid.
    let mode = args
        .get("mode")
        .and_then(|v| v.as_str())
        .map(|s| s.to_lowercase())
        .unwrap_or_else(|| "hybrid".to_string());
    if !matches!(mode.as_str(), "fts5" | "semantic" | "hybrid") {
        return format!(
            "Error: unknown mode '{}'; expected one of: fts5, semantic, hybrid",
            mode
        );
    }
    // Optional: result cap, clamped to the documented max of 50.
    let limit = args
        .get("limit")
        .and_then(|v| v.as_i64())
        .unwrap_or(20)
        .clamp(1, 50) as usize;
    log::info!(
        "tool_search_messages: query='{}', mode={}, limit={}",
        query,
        mode,
        limit
    );
    match self.sms_client.search_messages(query, &mode, limit).await {
        Ok(hits) if hits.is_empty() => "No messages matched.".to_string(),
        Ok(hits) => {
            let mut out = String::new();
            out.push_str(&format!(
                "Found {} messages (mode: {}):\n\n",
                hits.len(),
                mode
            ));
            for h in hits {
                // NOTE(review): assumes `h.date` is Unix *seconds* —
                // Android SMS stores dates in milliseconds; confirm the
                // Django side converts before serving.
                let date = chrono::DateTime::from_timestamp(h.date, 0)
                    .map(|dt| dt.format("%Y-%m-%d").to_string())
                    .unwrap_or_else(|| h.date.to_string());
                // type_ == 2 means a sent message; otherwise attribute
                // the line to the contact.
                let direction = if h.type_ == 2 { "Me" } else { &h.contact_name };
                let score = h
                    .similarity_score
                    .map(|s| format!(" [score {:.2}]", s))
                    .unwrap_or_default();
                // Bug fix: the original format string "{}{}"  glued the
                // sender name straight onto the body ("MeHello..."); add
                // a ": " separator between sender and body.
                out.push_str(&format!(
                    "[{}]{} {}: {}\n\n",
                    date, score, direction, h.body
                ));
            }
            out
        }
        Err(e) => format!("Error searching messages: {}", e),
    }
}
@@ -2164,6 +2357,29 @@ Return ONLY the summary, nothing else."#,
}
}),
),
Tool::function(
"search_messages",
"Keyword/semantic/hybrid search over ALL SMS message bodies (not just summaries) across all time. Prefer this for specific phrases, proper nouns, URLs, or when you don't know the date. Modes: 'fts5' (keyword, supports \"phrase\" / prefix* / AND / NEAR(w1 w2, 5)), 'semantic' (embedding similarity), 'hybrid' (recommended — merges both via reciprocal rank fusion).",
serde_json::json!({
"type": "object",
"required": ["query"],
"properties": {
"query": {
"type": "string",
"description": "Search query. Min 3 chars. For fts5 mode, supports phrase (\"\"), prefix (*), AND/OR/NOT, and NEAR proximity."
},
"mode": {
"type": "string",
"enum": ["fts5", "semantic", "hybrid"],
"description": "Search strategy. Default: hybrid."
},
"limit": {
"type": "integer",
"description": "Maximum number of results (default: 20, max: 50)"
}
}
}),
),
Tool::function(
"get_sms_messages",
"Fetch SMS/text messages near a specific date. Returns the actual message conversation. Omit contact to search across all conversations.",

View File

@@ -250,6 +250,45 @@ impl SmsApiClient {
.collect())
}
/// Search message bodies via the Django side's FTS5 / semantic / hybrid
/// endpoint. `mode` selects the ranking strategy:
/// - "fts5"     keyword-only, supports phrase / prefix / boolean / NEAR
/// - "semantic" embedding similarity
/// - "hybrid"   both merged via reciprocal rank fusion (recommended)
///
/// Errors on transport failure or any non-2xx response, including the
/// response body in the error for easier debugging.
pub async fn search_messages(
    &self,
    query: &str,
    mode: &str,
    limit: usize,
) -> Result<Vec<SmsSearchHit>> {
    // Percent-encode the free-text parameters; `limit` is numeric and
    // safe to interpolate directly.
    let url = format!(
        "{}/api/messages/search/?q={}&mode={}&limit={}",
        self.base_url,
        urlencoding::encode(query),
        urlencoding::encode(mode),
        limit
    );
    // Bearer auth is optional — only attached when a token is configured.
    let mut req = self.client.get(&url);
    if let Some(token) = &self.token {
        req = req.header("Authorization", format!("Bearer {}", token));
    }
    let resp = req.send().await?;
    let status = resp.status();
    if !status.is_success() {
        let body = resp.text().await.unwrap_or_default();
        return Err(anyhow::anyhow!(
            "SMS search request failed: {} - {}",
            status,
            body
        ));
    }
    let parsed: SmsSearchResponse = resp.json().await?;
    Ok(parsed.results)
}
pub async fn summarize_context(
&self,
messages: &[SmsMessage],
@@ -314,3 +353,28 @@ struct SmsApiMessage {
#[serde(rename = "type")]
type_: i32,
}
/// One search hit from the Django `/api/messages/search/` endpoint.
/// Field names (after serde renames) mirror the JSON payload.
#[derive(Debug, Clone, Deserialize)]
pub struct SmsSearchHit {
    /// Server-side message id; deserialized but not read by callers yet.
    #[allow(dead_code)]
    pub message_id: i64,
    /// Display name of the other party in the conversation.
    pub contact_name: String,
    /// Raw address (phone number / handle); deserialized but unused.
    #[allow(dead_code)]
    pub contact_address: String,
    /// Full message body text.
    pub body: String,
    /// Message timestamp as a Unix epoch integer.
    // NOTE(review): consumers feed this to chrono's seconds-based
    // from_timestamp — confirm the server sends seconds, not ms.
    pub date: i64,
    /// Message direction code: 1 = received, 2 = sent.
    #[serde(rename = "type")]
    pub type_: i32,
    /// Present for semantic / hybrid modes; absent for fts5.
    #[serde(default)]
    pub similarity_score: Option<f32>,
}
/// Envelope for the search endpoint's JSON response.
#[derive(Deserialize)]
struct SmsSearchResponse {
    // Matched messages, in the order the server returned them.
    results: Vec<SmsSearchHit>,
    // Strategy the server reports having used; parsed for completeness
    // but not read anywhere yet (defaults to "" if the field is absent).
    #[allow(dead_code)]
    #[serde(default)]
    search_method: String,
}