//! Memory Retriever - Retrieves relevant memories from OpenViking //! //! This module provides the `MemoryRetriever` which performs semantic search //! over stored memories to find contextually relevant information. //! Uses multiple retrieval strategies and intelligent reranking. use crate::retrieval::{MemoryCache, QueryAnalyzer, SemanticScorer}; use crate::types::{MemoryEntry, MemoryType, RetrievalConfig, RetrievalResult}; use crate::viking_adapter::{FindOptions, VikingAdapter}; use std::sync::Arc; use tokio::sync::RwLock; use zclaw_types::{AgentId, Result}; /// Memory Retriever - retrieves relevant memories from OpenViking pub struct MemoryRetriever { /// OpenViking adapter viking: Arc, /// Retrieval configuration config: RetrievalConfig, /// Semantic scorer for similarity computation scorer: RwLock, /// Query analyzer analyzer: QueryAnalyzer, /// Memory cache cache: MemoryCache, } impl MemoryRetriever { /// Create a new memory retriever pub fn new(viking: Arc) -> Self { Self { viking, config: RetrievalConfig::default(), scorer: RwLock::new(SemanticScorer::new()), analyzer: QueryAnalyzer::new(), cache: MemoryCache::default_config(), } } /// Create with custom configuration pub fn with_config(mut self, config: RetrievalConfig) -> Self { self.config = config; self } /// Retrieve relevant memories for a query /// /// This method: /// 1. Analyzes the query to determine intent and keywords /// 2. Searches for preferences matching the query /// 3. Searches for relevant knowledge /// 4. Searches for applicable experience /// 5. Reranks results using semantic similarity /// 6. Applies token budget constraints pub async fn retrieve( &self, agent_id: &AgentId, query: &str, ) -> Result { tracing::debug!("[MemoryRetriever] Retrieving memories for query: {}", query); // Analyze query let analyzed = self.analyzer.analyze(query); tracing::debug!( "[MemoryRetriever] Query analysis: intent={:?}, keywords={:?}", analyzed.intent, analyzed.keywords ); // Retrieve each type with budget constraints and reranking let preferences = self .retrieve_and_rerank( &agent_id.to_string(), MemoryType::Preference, query, &analyzed.keywords, self.config.max_results_per_type, self.config.preference_budget, ) .await?; let knowledge = self .retrieve_and_rerank( &agent_id.to_string(), MemoryType::Knowledge, query, &analyzed.keywords, self.config.max_results_per_type, self.config.knowledge_budget, ) .await?; let experience = self .retrieve_and_rerank( &agent_id.to_string(), MemoryType::Experience, query, &analyzed.keywords, self.config.max_results_per_type / 2, self.config.experience_budget, ) .await?; let total_tokens = preferences.iter() .chain(knowledge.iter()) .chain(experience.iter()) .map(|m| m.estimated_tokens()) .sum(); // Update cache with retrieved entries for entry in preferences.iter().chain(knowledge.iter()).chain(experience.iter()) { self.cache.put(entry.clone()).await; } tracing::info!( "[MemoryRetriever] Retrieved {} preferences, {} knowledge, {} experience ({} tokens)", preferences.len(), knowledge.len(), experience.len(), total_tokens ); Ok(RetrievalResult { preferences, knowledge, experience, total_tokens, }) } /// Retrieve and rerank memories by type async fn retrieve_and_rerank( &self, agent_id: &str, memory_type: MemoryType, query: &str, keywords: &[String], max_results: usize, token_budget: usize, ) -> Result> { // Build scope for OpenViking search let scope = format!("agent://{}/{}", agent_id, memory_type); // Generate search queries (original + expanded) let analyzed_for_search = crate::retrieval::query::AnalyzedQuery { original: query.to_string(), keywords: keywords.to_vec(), intent: crate::retrieval::query::QueryIntent::General, target_types: vec![], expansions: vec![], }; let search_queries = self.analyzer.generate_search_queries(&analyzed_for_search); // Search with multiple queries and deduplicate let mut all_results = Vec::new(); let mut seen_uris = std::collections::HashSet::new(); for search_query in search_queries { let options = FindOptions { scope: Some(scope.clone()), limit: Some(max_results * 2), min_similarity: Some(self.config.min_similarity), }; let results = self.viking.find(&search_query, options).await?; for entry in results { if seen_uris.insert(entry.uri.clone()) { all_results.push(entry); } } } // Rerank using semantic similarity let scored = self.rerank_entries(query, all_results).await; // Apply token budget let mut filtered = Vec::new(); let mut used_tokens = 0; for entry in scored { let tokens = entry.estimated_tokens(); if used_tokens + tokens <= token_budget { used_tokens += tokens; filtered.push(entry); } if filtered.len() >= max_results { break; } } Ok(filtered) } /// Rerank entries using semantic similarity async fn rerank_entries( &self, query: &str, entries: Vec, ) -> Vec { if entries.is_empty() { return entries; } let mut scorer = self.scorer.write().await; // Index entries for semantic search for entry in &entries { scorer.index_entry(entry); } // Score each entry let mut scored: Vec<(f32, MemoryEntry)> = entries .into_iter() .map(|entry| { let score = scorer.score_similarity(query, &entry); (score, entry) }) .collect(); // Sort by score (descending), then by importance and access count scored.sort_by(|a, b| { b.0.partial_cmp(&a.0) .unwrap_or(std::cmp::Ordering::Equal) .then_with(|| b.1.importance.cmp(&a.1.importance)) .then_with(|| b.1.access_count.cmp(&a.1.access_count)) }); scored.into_iter().map(|(_, entry)| entry).collect() } /// Retrieve a specific memory by URI (with cache) pub async fn get_by_uri(&self, uri: &str) -> Result> { // Check cache first if let Some(cached) = self.cache.get(uri).await { return Ok(Some(cached)); } // Fall back to storage let result = self.viking.get(uri).await?; // Update cache if let Some(ref entry) = result { self.cache.put(entry.clone()).await; } Ok(result) } /// Get all memories for an agent (for debugging/admin) pub async fn get_all_memories(&self, agent_id: &AgentId) -> Result> { let scope = format!("agent://{}", agent_id); let options = FindOptions { scope: Some(scope), limit: None, min_similarity: None, }; self.viking.find("", options).await } /// Get memory statistics for an agent pub async fn get_stats(&self, agent_id: &AgentId) -> Result { let all = self.get_all_memories(agent_id).await?; let preference_count = all.iter().filter(|m| m.memory_type == MemoryType::Preference).count(); let knowledge_count = all.iter().filter(|m| m.memory_type == MemoryType::Knowledge).count(); let experience_count = all.iter().filter(|m| m.memory_type == MemoryType::Experience).count(); Ok(MemoryStats { total_count: all.len(), preference_count, knowledge_count, experience_count, cache_hit_rate: self.cache.hit_rate().await, }) } /// Clear the semantic index pub async fn clear_index(&self) { let mut scorer = self.scorer.write().await; scorer.clear(); } /// Get cache statistics pub async fn cache_stats(&self) -> (usize, f32) { let size = self.cache.size().await; let hit_rate = self.cache.hit_rate().await; (size, hit_rate) } /// Warm up cache with hot entries pub async fn warmup_cache(&self, agent_id: &AgentId) -> Result { let all = self.get_all_memories(agent_id).await?; // Sort by access count to get hot entries let mut sorted = all; sorted.sort_by(|a, b| b.access_count.cmp(&a.access_count)); // Take top 50 hot entries let hot: Vec<_> = sorted.into_iter().take(50).collect(); let count = hot.len(); self.cache.warmup(hot).await; Ok(count) } } /// Memory statistics #[derive(Debug, Clone)] pub struct MemoryStats { pub total_count: usize, pub preference_count: usize, pub knowledge_count: usize, pub experience_count: usize, pub cache_hit_rate: f32, } #[cfg(test)] mod tests { use super::*; #[test] fn test_retrieval_config_default() { let config = RetrievalConfig::default(); assert_eq!(config.max_tokens, 500); assert_eq!(config.preference_budget, 200); assert_eq!(config.knowledge_budget, 200); } #[test] fn test_memory_type_scope() { let scope = format!("agent://test-agent/{}", MemoryType::Preference); assert!(scope.contains("test-agent")); assert!(scope.contains("preferences")); } #[tokio::test] async fn test_retriever_creation() { let viking = Arc::new(VikingAdapter::in_memory()); let retriever = MemoryRetriever::new(viking); let stats = retriever.cache_stats().await; assert_eq!(stats.0, 0); // Cache size should be 0 } }