feat(ai): search_messages tool + RAG reranker
Adds a search_messages tool that hits the Django FTS5/semantic/hybrid endpoint for keyword-quality text search over message bodies, and an LLM-based reranker inside tool_search_rag (gated by SEARCH_RAG_RERANK, default on). Reranker pulls ~3x candidates from the vector index, asks the chat model to rank by relevance, and falls back to vector order on parse failure. The reranker shares the active chat turn's OllamaClient so num_ctx and sampling match — otherwise Ollama unloads/reloads the model on every rerank call. (Unverified end-to-end; caught by inspection, awaiting e2e confirmation.) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1421,7 +1421,8 @@ Return ONLY the summary, nothing else."#,
|
|||||||
cx: &opentelemetry::Context,
|
cx: &opentelemetry::Context,
|
||||||
) -> String {
|
) -> String {
|
||||||
let result = match tool_name {
|
let result = match tool_name {
|
||||||
"search_rag" => self.tool_search_rag(arguments, cx).await,
|
"search_rag" => self.tool_search_rag(arguments, ollama, cx).await,
|
||||||
|
"search_messages" => self.tool_search_messages(arguments).await,
|
||||||
"get_sms_messages" => self.tool_get_sms_messages(arguments, cx).await,
|
"get_sms_messages" => self.tool_get_sms_messages(arguments, cx).await,
|
||||||
"get_calendar_events" => self.tool_get_calendar_events(arguments, cx).await,
|
"get_calendar_events" => self.tool_get_calendar_events(arguments, cx).await,
|
||||||
"get_location_history" => self.tool_get_location_history(arguments, cx).await,
|
"get_location_history" => self.tool_get_location_history(arguments, cx).await,
|
||||||
@@ -1447,6 +1448,7 @@ Return ONLY the summary, nothing else."#,
|
|||||||
async fn tool_search_rag(
|
async fn tool_search_rag(
|
||||||
&self,
|
&self,
|
||||||
args: &serde_json::Value,
|
args: &serde_json::Value,
|
||||||
|
ollama: &OllamaClient,
|
||||||
_cx: &opentelemetry::Context,
|
_cx: &opentelemetry::Context,
|
||||||
) -> String {
|
) -> String {
|
||||||
let query = match args.get("query").and_then(|v| v.as_str()) {
|
let query = match args.get("query").and_then(|v| v.as_str()) {
|
||||||
@@ -1479,13 +1481,204 @@ Return ONLY the summary, nothing else."#,
|
|||||||
limit
|
limit
|
||||||
);
|
);
|
||||||
|
|
||||||
match self
|
// Pull a wider candidate pool than the final limit so the LLM
|
||||||
.find_relevant_messages_rag(date, None, contact.as_deref(), None, limit, Some(&query))
|
// reranker has room to promote less-obvious hits. Candidates_factor
|
||||||
|
// is capped so a big `limit` doesn't blow past what the reranker
|
||||||
|
// can sensibly judge in one prompt.
|
||||||
|
let rerank_enabled = std::env::var("SEARCH_RAG_RERANK")
|
||||||
|
.ok()
|
||||||
|
.map(|v| v.to_lowercase() != "off" && v != "0")
|
||||||
|
.unwrap_or(true);
|
||||||
|
let candidate_limit = if rerank_enabled {
|
||||||
|
(limit * 3).min(40)
|
||||||
|
} else {
|
||||||
|
limit
|
||||||
|
};
|
||||||
|
|
||||||
|
let results = match self
|
||||||
|
.find_relevant_messages_rag(
|
||||||
|
date,
|
||||||
|
None,
|
||||||
|
contact.as_deref(),
|
||||||
|
None,
|
||||||
|
candidate_limit,
|
||||||
|
Some(&query),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(results) if !results.is_empty() => results.join("\n\n"),
|
Ok(results) if !results.is_empty() => results,
|
||||||
Ok(_) => "No relevant messages found.".to_string(),
|
Ok(_) => return "No relevant messages found.".to_string(),
|
||||||
Err(e) => format!("Error searching RAG: {}", e),
|
Err(e) => return format!("Error searching RAG: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
let final_results = if rerank_enabled && results.len() > limit {
|
||||||
|
match self.rerank_with_llm(&query, &results, limit, ollama).await {
|
||||||
|
Ok(reordered) => reordered,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("rerank failed, using vector order: {}", e);
|
||||||
|
results.into_iter().take(limit).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
results.into_iter().take(limit).collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
final_results.join("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLM-based reranker: ask the local model to pick the top-`limit`
|
||||||
|
/// passages from `candidates` that are most relevant to `query`.
|
||||||
|
/// Returns the reordered subset.
|
||||||
|
///
|
||||||
|
/// Cheap-ish because the reranker prompt and output live outside the
|
||||||
|
/// agent's visible context — only the final selection lands in the
|
||||||
|
/// tool_result. On parse failure we fall back to the input order.
|
||||||
|
async fn rerank_with_llm(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
candidates: &[String],
|
||||||
|
limit: usize,
|
||||||
|
ollama: &OllamaClient,
|
||||||
|
) -> Result<Vec<String>> {
|
||||||
|
// Build numbered list (1-based for readability). Cap each passage
|
||||||
|
// at ~1000 chars so very long summaries don't eat the prompt.
|
||||||
|
let numbered: String = candidates
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, c)| {
|
||||||
|
let trimmed = if c.len() > 1000 {
|
||||||
|
format!("{}…", &c[..1000])
|
||||||
|
} else {
|
||||||
|
c.clone()
|
||||||
|
};
|
||||||
|
format!("[{}] {}", i + 1, trimmed)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n\n");
|
||||||
|
|
||||||
|
let prompt = format!(
|
||||||
|
"You are ranking search results. From the numbered passages below, \
|
||||||
|
select the {} most relevant to the query. Respond with ONLY a \
|
||||||
|
comma-separated list of passage numbers in order from most to \
|
||||||
|
least relevant. No explanation, no other text.\n\n\
|
||||||
|
Query: {}\n\n\
|
||||||
|
Passages:\n{}\n\n\
|
||||||
|
Top {} passage numbers:",
|
||||||
|
limit, query, numbered, limit
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = ollama
|
||||||
|
.generate(
|
||||||
|
&prompt,
|
||||||
|
Some(
|
||||||
|
"You are a terse relevance ranker. You output only numbers separated by commas.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Extract indices from the response. Accept "3, 1, 7" and also
|
||||||
|
// tolerate "[3, 1, 7]" or "3,1,7,..." with trailing junk.
|
||||||
|
let picks: Vec<usize> = response
|
||||||
|
.split(|c: char| !c.is_ascii_digit())
|
||||||
|
.filter_map(|s| s.parse::<usize>().ok())
|
||||||
|
.filter(|&n| n >= 1 && n <= candidates.len())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if picks.is_empty() {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"reranker returned no usable indices (raw: {})",
|
||||||
|
response.chars().take(120).collect::<String>()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut seen = std::collections::HashSet::new();
|
||||||
|
let mut reordered: Vec<String> = Vec::with_capacity(limit);
|
||||||
|
for n in picks {
|
||||||
|
if seen.insert(n) {
|
||||||
|
reordered.push(candidates[n - 1].clone());
|
||||||
|
if reordered.len() >= limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Top-up from original order if the reranker returned fewer than
|
||||||
|
// `limit` distinct entries.
|
||||||
|
if reordered.len() < limit {
|
||||||
|
for (i, c) in candidates.iter().enumerate() {
|
||||||
|
if !seen.contains(&(i + 1)) {
|
||||||
|
reordered.push(c.clone());
|
||||||
|
if reordered.len() >= limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(reordered)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool: search_messages — keyword / semantic / hybrid search over all
|
||||||
|
/// SMS message bodies via the Django FTS5 + embeddings pipeline. Unlike
|
||||||
|
/// `search_rag` (daily summaries, date-weighted) this hits raw message
|
||||||
|
/// text across time and is the right choice for exact phrases, proper
|
||||||
|
/// nouns, URLs, or anything where specific wording matters.
|
||||||
|
async fn tool_search_messages(&self, args: &serde_json::Value) -> String {
|
||||||
|
let query = match args.get("query").and_then(|v| v.as_str()) {
|
||||||
|
Some(q) if !q.trim().is_empty() => q.trim(),
|
||||||
|
_ => return "Error: missing required parameter 'query'".to_string(),
|
||||||
|
};
|
||||||
|
if query.len() < 3 {
|
||||||
|
return "Error: query must be at least 3 characters".to_string();
|
||||||
|
}
|
||||||
|
let mode = args
|
||||||
|
.get("mode")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_lowercase())
|
||||||
|
.unwrap_or_else(|| "hybrid".to_string());
|
||||||
|
if !matches!(mode.as_str(), "fts5" | "semantic" | "hybrid") {
|
||||||
|
return format!(
|
||||||
|
"Error: unknown mode '{}'; expected one of: fts5, semantic, hybrid",
|
||||||
|
mode
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let limit = args
|
||||||
|
.get("limit")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.unwrap_or(20)
|
||||||
|
.clamp(1, 50) as usize;
|
||||||
|
|
||||||
|
log::info!(
|
||||||
|
"tool_search_messages: query='{}', mode={}, limit={}",
|
||||||
|
query,
|
||||||
|
mode,
|
||||||
|
limit
|
||||||
|
);
|
||||||
|
|
||||||
|
match self.sms_client.search_messages(query, &mode, limit).await {
|
||||||
|
Ok(hits) if hits.is_empty() => "No messages matched.".to_string(),
|
||||||
|
Ok(hits) => {
|
||||||
|
let mut out = String::new();
|
||||||
|
out.push_str(&format!(
|
||||||
|
"Found {} messages (mode: {}):\n\n",
|
||||||
|
hits.len(),
|
||||||
|
mode
|
||||||
|
));
|
||||||
|
for h in hits {
|
||||||
|
let date = chrono::DateTime::from_timestamp(h.date, 0)
|
||||||
|
.map(|dt| dt.format("%Y-%m-%d").to_string())
|
||||||
|
.unwrap_or_else(|| h.date.to_string());
|
||||||
|
let direction = if h.type_ == 2 { "Me" } else { &h.contact_name };
|
||||||
|
let score = h
|
||||||
|
.similarity_score
|
||||||
|
.map(|s| format!(" [score {:.2}]", s))
|
||||||
|
.unwrap_or_default();
|
||||||
|
out.push_str(&format!(
|
||||||
|
"[{}]{} {} — {}\n\n",
|
||||||
|
date, score, direction, h.body
|
||||||
|
));
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
Err(e) => format!("Error searching messages: {}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2164,6 +2357,29 @@ Return ONLY the summary, nothing else."#,
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
Tool::function(
|
||||||
|
"search_messages",
|
||||||
|
"Keyword/semantic/hybrid search over ALL SMS message bodies (not just summaries) across all time. Prefer this for specific phrases, proper nouns, URLs, or when you don't know the date. Modes: 'fts5' (keyword, supports \"phrase\" / prefix* / AND / NEAR(w1 w2, 5)), 'semantic' (embedding similarity), 'hybrid' (recommended — merges both via reciprocal rank fusion).",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"required": ["query"],
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query. Min 3 chars. For fts5 mode, supports phrase (\"\"), prefix (*), AND/OR/NOT, and NEAR proximity."
|
||||||
|
},
|
||||||
|
"mode": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["fts5", "semantic", "hybrid"],
|
||||||
|
"description": "Search strategy. Default: hybrid."
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results (default: 20, max: 50)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
),
|
||||||
Tool::function(
|
Tool::function(
|
||||||
"get_sms_messages",
|
"get_sms_messages",
|
||||||
"Fetch SMS/text messages near a specific date. Returns the actual message conversation. Omit contact to search across all conversations.",
|
"Fetch SMS/text messages near a specific date. Returns the actual message conversation. Omit contact to search across all conversations.",
|
||||||
|
|||||||
@@ -250,6 +250,45 @@ impl SmsApiClient {
|
|||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Search message bodies via the Django side's FTS5 / semantic / hybrid
|
||||||
|
/// endpoint. `mode` selects the ranking strategy:
|
||||||
|
/// - "fts5" keyword-only, supports phrase / prefix / boolean / NEAR
|
||||||
|
/// - "semantic" embedding similarity
|
||||||
|
/// - "hybrid" both merged via reciprocal rank fusion (recommended)
|
||||||
|
pub async fn search_messages(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
mode: &str,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<SmsSearchHit>> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/api/messages/search/?q={}&mode={}&limit={}",
|
||||||
|
self.base_url,
|
||||||
|
urlencoding::encode(query),
|
||||||
|
urlencoding::encode(mode),
|
||||||
|
limit
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut request = self.client.get(&url);
|
||||||
|
if let Some(token) = &self.token {
|
||||||
|
request = request.header("Authorization", format!("Bearer {}", token));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = request.send().await?;
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"SMS search request failed: {} - {}",
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data: SmsSearchResponse = response.json().await?;
|
||||||
|
Ok(data.results)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn summarize_context(
|
pub async fn summarize_context(
|
||||||
&self,
|
&self,
|
||||||
messages: &[SmsMessage],
|
messages: &[SmsMessage],
|
||||||
@@ -314,3 +353,28 @@ struct SmsApiMessage {
|
|||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
type_: i32,
|
type_: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct SmsSearchHit {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub message_id: i64,
|
||||||
|
pub contact_name: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub contact_address: String,
|
||||||
|
pub body: String,
|
||||||
|
pub date: i64,
|
||||||
|
/// Message direction code: 1 = received, 2 = sent.
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub type_: i32,
|
||||||
|
/// Present for semantic / hybrid modes; absent for fts5.
|
||||||
|
#[serde(default)]
|
||||||
|
pub similarity_score: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct SmsSearchResponse {
|
||||||
|
results: Vec<SmsSearchHit>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[serde(default)]
|
||||||
|
search_method: String,
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user