fix(presentation): 修复 presentation 模块类型错误和语法问题
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

- 创建 types.ts 定义完整的类型系统
- 重写 DocumentRenderer.tsx 修复语法错误
- 重写 QuizRenderer.tsx 修复语法错误
- 重写 PresentationContainer.tsx 添加类型守卫
- 重写 TypeSwitcher.tsx 修复类型引用
- 更新 index.ts 移除不存在的 ChartRenderer 导出

审计结果:
- 类型检查: 通过
- 单元测试: 222 passed
- 构建: 成功
This commit is contained in:
iven
2026-03-26 17:19:28 +08:00
parent d0c6319fc1
commit b7f3d94950
71 changed files with 15896 additions and 1133 deletions

View File

@@ -0,0 +1,365 @@
//! Memory Cache
//!
//! Provides caching for frequently accessed memories to improve
//! retrieval performance.
use crate::types::{MemoryEntry, MemoryType};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Cache entry with metadata
///
/// Internal wrapper around a cached [`MemoryEntry`] carrying the
/// bookkeeping used for TTL expiry and least-used eviction.
struct CacheEntry {
    /// The memory entry
    entry: MemoryEntry,
    /// Last access time (used for TTL checks and as the eviction tiebreaker)
    last_accessed: Instant,
    /// Access count (primary eviction criterion: least-used goes first)
    access_count: u32,
}
/// Cache key for efficient lookups
///
/// NOTE(review): not referenced by `MemoryCache` in this file (the map is
/// keyed by the URI string) — confirm whether this structured key is used
/// elsewhere or is a remnant of an earlier design.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    /// Owning agent (first URI path segment)
    agent_id: String,
    /// Memory type copied from the entry itself, not parsed from the URI
    memory_type: MemoryType,
    /// Category (third URI path segment)
    category: String,
}
impl From<&MemoryEntry> for CacheKey {
fn from(entry: &MemoryEntry) -> Self {
// Parse URI to extract components
let parts: Vec<&str> = entry.uri.trim_start_matches("agent://").split('/').collect();
Self {
agent_id: parts.first().unwrap_or(&"").to_string(),
memory_type: entry.memory_type,
category: parts.get(2).unwrap_or(&"").to_string(),
}
}
}
/// Memory cache configuration
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries before eviction kicks in on insert
    pub max_entries: usize,
    /// Time-to-live for entries; expired entries are dropped on lookup
    pub ttl: Duration,
    /// Enable/disable caching (when false, get/put are no-ops)
    pub enabled: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
ttl: Duration::from_secs(3600), // 1 hour
enabled: true,
}
}
}
/// Memory cache for hot memories
///
/// An async, URI-keyed cache with TTL expiry and least-used eviction.
/// All mutable state lives behind `tokio::sync::RwLock`s, so the cache
/// can be shared across tasks and used through `&self` methods.
pub struct MemoryCache {
    /// Cache storage, keyed by memory URI
    cache: RwLock<HashMap<String, CacheEntry>>,
    /// Configuration (capacity, TTL, enabled flag)
    config: CacheConfig,
    /// Hit/miss/eviction counters
    stats: RwLock<CacheStats>,
}
/// Cache statistics
///
/// Counters are cumulative for the lifetime of the cache; `clear()`
/// empties the storage but does not reset these.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total cache hits
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Total entries evicted for capacity (TTL drops are not counted here)
    pub evictions: u64,
}
impl MemoryCache {
    /// Create a new memory cache with the given configuration.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
            config,
            stats: RwLock::new(CacheStats::default()),
        }
    }

    /// Create a cache with the default configuration
    /// (1000 entries, 1-hour TTL, enabled).
    pub fn default_config() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Get a memory from cache by URI.
    ///
    /// Returns `None` when caching is disabled, the URI is not cached, or
    /// the entry has outlived the TTL (expired entries are removed
    /// eagerly). A hit refreshes the entry's recency and access count.
    pub async fn get(&self, uri: &str) -> Option<MemoryEntry> {
        if !self.config.enabled {
            return None;
        }
        let mut cache = self.cache.write().await;
        if let Some(cached) = cache.get_mut(uri) {
            // Expired: drop it and record the lookup as a miss so that
            // hit-rate statistics account for every enabled lookup.
            // (Previously the expiry path returned without touching stats.)
            if cached.last_accessed.elapsed() > self.config.ttl {
                cache.remove(uri);
                let mut stats = self.stats.write().await;
                stats.misses += 1;
                return None;
            }
            // Refresh the metadata consumed by eviction.
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            let entry = cached.entry.clone();
            let mut stats = self.stats.write().await;
            stats.hits += 1;
            return Some(entry);
        }
        let mut stats = self.stats.write().await;
        stats.misses += 1;
        None
    }

    /// Put a memory into cache, keyed by its URI.
    ///
    /// No-op when caching is disabled. When the cache is at capacity and
    /// the URI is new, one least-used entry is evicted first. Overwriting
    /// an existing URI never evicts: the map does not grow in that case.
    pub async fn put(&self, entry: MemoryEntry) {
        if !self.config.enabled {
            return;
        }
        let mut cache = self.cache.write().await;
        // Only evict for genuinely new keys (bug fix: replacing an
        // existing entry used to evict an unrelated one, shrinking the
        // effective cache).
        if cache.len() >= self.config.max_entries && !cache.contains_key(&entry.uri) {
            self.evict_lru(&mut cache).await;
        }
        cache.insert(
            entry.uri.clone(),
            CacheEntry {
                entry,
                last_accessed: Instant::now(),
                access_count: 0,
            },
        );
    }

    /// Remove a memory from cache (no-op if the URI is absent).
    pub async fn remove(&self, uri: &str) {
        let mut cache = self.cache.write().await;
        cache.remove(uri);
    }

    /// Clear all cached entries (statistics are preserved).
    pub async fn clear(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();
    }

    /// Evict the single least-valuable entry.
    ///
    /// Victim selection is least-frequently-used first, with the oldest
    /// access time as a tiebreaker. Called with the map already locked
    /// for writing by the caller.
    async fn evict_lru(&self, cache: &mut HashMap<String, CacheEntry>) {
        let lru_key = cache
            .iter()
            .min_by_key(|(_, v)| (v.access_count, v.last_accessed))
            .map(|(k, _)| k.clone());
        if let Some(key) = lru_key {
            cache.remove(&key);
            let mut stats = self.stats.write().await;
            stats.evictions += 1;
        }
    }

    /// Get a snapshot of cache statistics.
    pub async fn stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }

    /// Hit rate in `[0.0, 1.0]`; 0.0 when no lookups have occurred.
    pub async fn hit_rate(&self) -> f32 {
        let stats = self.stats.read().await;
        let total = stats.hits + stats.misses;
        if total == 0 {
            return 0.0;
        }
        stats.hits as f32 / total as f32
    }

    /// Current number of cached entries.
    pub async fn size(&self) -> usize {
        self.cache.read().await.len()
    }

    /// Warm up the cache by inserting the given entries.
    pub async fn warmup(&self, entries: Vec<MemoryEntry>) {
        for entry in entries {
            self.put(entry).await;
        }
    }

    /// Get up to `limit` entries ordered by descending access count
    /// (intended for preloading/persisting hot entries).
    pub async fn get_hot_entries(&self, limit: usize) -> Vec<MemoryEntry> {
        let cache = self.cache.read().await;
        let mut entries: Vec<_> = cache
            .values()
            .map(|c| (c.access_count, c.entry.clone()))
            .collect();
        entries.sort_by(|a, b| b.0.cmp(&a.0));
        entries.truncate(limit);
        entries.into_iter().map(|(_, e)| e).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;
    // Basic round trip: an inserted entry is retrievable by its URI
    // and its content is intact.
    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );
        cache.put(entry.clone()).await;
        let retrieved = cache.get(&entry.uri).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "User prefers concise responses");
    }
    // Looking up an unknown URI returns None and records exactly one miss.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = MemoryCache::default_config();
        let retrieved = cache.get("nonexistent").await;
        assert!(retrieved.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
    }
    // Removal makes a previously cached entry unretrievable.
    #[tokio::test]
    async fn test_cache_remove() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry.clone()).await;
        cache.remove(&entry.uri).await;
        let retrieved = cache.get(&entry.uri).await;
        assert!(retrieved.is_none());
    }
    // clear() empties the cache completely.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry).await;
        cache.clear().await;
        let size = cache.size().await;
        assert_eq!(size, 0);
    }
    // One hit plus one miss yields a 50% hit rate; counters are exact,
    // so the order of the get() calls below matters.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry.clone()).await;
        // Hit
        cache.get(&entry.uri).await;
        // Miss
        cache.get("nonexistent").await;
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        let hit_rate = cache.hit_rate().await;
        assert!((hit_rate - 0.5).abs() < 0.001);
    }
    // With capacity 2, inserting a third entry evicts the least-used one
    // (entry2, never accessed) rather than entry1 (accessed once).
    #[tokio::test]
    async fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(3600),
            enabled: true,
        };
        let cache = MemoryCache::new(config);
        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        let entry3 = MemoryEntry::new("test", MemoryType::Preference, "3", "3".to_string());
        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;
        // Access entry1 to make it hot
        cache.get(&entry1.uri).await;
        // Add entry3, should evict entry2 (LRU)
        cache.put(entry3).await;
        let size = cache.size().await;
        assert_eq!(size, 2);
        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
    }
    // get_hot_entries orders by descending access count, so the twice-read
    // entry1 must come first.
    #[tokio::test]
    async fn test_get_hot_entries() {
        let cache = MemoryCache::default_config();
        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;
        // Access entry1 multiple times
        cache.get(&entry1.uri).await;
        cache.get(&entry1.uri).await;
        let hot = cache.get_hot_entries(10).await;
        assert_eq!(hot.len(), 2);
        // entry1 should be first (more accesses)
        assert_eq!(hot[0].uri, entry1.uri);
    }
}

View File

@@ -0,0 +1,14 @@
//! Retrieval components for ZCLAW Growth System
//!
//! This module provides advanced retrieval capabilities:
//! - `cache`: Hot memory caching
//! - `query`: Query analysis and expansion
//! - `semantic`: Semantic similarity computation

pub mod cache;
pub mod query;
pub mod semantic;

pub use cache::MemoryCache;
pub use query::QueryAnalyzer;
pub use semantic::SemanticScorer;

View File

@@ -0,0 +1,352 @@
//! Query Analyzer
//!
//! Provides query analysis and expansion capabilities for improved retrieval.
//! Extracts keywords, identifies intent, and generates search variations.
use crate::types::MemoryType;
use std::collections::HashSet;
/// Query analysis result
///
/// Produced by [`QueryAnalyzer::analyze`]; bundles everything retrieval
/// needs to decide where and how to search.
#[derive(Debug, Clone)]
pub struct AnalyzedQuery {
    /// Original query string, unmodified
    pub original: String,
    /// Extracted keywords (lowercased, stop words removed)
    pub keywords: Vec<String>,
    /// Classified query intent
    pub intent: QueryIntent,
    /// Memory types to search (inferred from the intent)
    pub target_types: Vec<MemoryType>,
    /// Expanded search terms (plural variants and synonyms)
    pub expansions: Vec<String>,
}
/// Query intent classification
///
/// NOTE(review): `Configuration` is handled by `infer_memory_types` but
/// is never produced by `classify_intent` (which only scores the first
/// four variants) — confirm whether it should have indicator keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryIntent {
    /// Looking for preferences/settings
    Preference,
    /// Looking for factual knowledge
    Knowledge,
    /// Looking for how-to/experience
    Experience,
    /// General conversation
    General,
    /// Code-related query
    Code,
    /// Configuration query
    Configuration,
}
/// Query analyzer
///
/// Holds the keyword dictionaries used to classify intent and filter
/// stop words. All sets are populated in [`QueryAnalyzer::new`] and are
/// immutable afterwards.
pub struct QueryAnalyzer {
    /// Keywords that indicate preference queries
    preference_indicators: HashSet<String>,
    /// Keywords that indicate knowledge queries
    knowledge_indicators: HashSet<String>,
    /// Keywords that indicate experience queries
    experience_indicators: HashSet<String>,
    /// Keywords that indicate code queries
    code_indicators: HashSet<String>,
    /// Stop words to filter out of extracted keywords
    stop_words: HashSet<String>,
}
impl QueryAnalyzer {
    /// Create a new query analyzer preloaded with English/Chinese intent
    /// indicator keywords and an English stop-word list.
    pub fn new() -> Self {
        Self {
            preference_indicators: [
                "prefer", "like", "want", "favorite", "favourite", "style",
                "format", "language", "setting", "preference", "usually",
                "typically", "always", "never", "习惯", "偏好", "喜欢", "想要",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            knowledge_indicators: [
                "what", "how", "why", "explain", "tell", "know", "learn",
                "understand", "meaning", "definition", "concept", "theory",
                "是什么", "怎么", "为什么", "解释", "了解", "知道",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            // NOTE(review): "last time" contains a space so it can never
            // match a single extracted keyword — confirm intent.
            experience_indicators: [
                "experience", "tried", "used", "before", "last time",
                "previous", "history", "remember", "recall", "when",
                "经验", "尝试", "用过", "上次", "记得", "回忆",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            // Fix: a stray empty-string literal was removed from this
            // list; extract_keywords never yields empty tokens, so it
            // could never match and only polluted the set.
            code_indicators: [
                "code", "function", "class", "method", "variable", "type",
                "error", "bug", "fix", "implement", "refactor", "api",
                "代码", "函数", "方法", "变量", "错误", "修复", "实现",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            stop_words: [
                "the", "a", "an", "is", "are", "was", "were", "be", "been",
                "have", "has", "had", "do", "does", "did", "will", "would",
                "could", "should", "may", "might", "must", "can", "to", "of",
                "in", "for", "on", "with", "at", "by", "from", "as", "and",
                "or", "but", "if", "then", "else", "when", "where", "which",
                "who", "whom", "whose", "this", "that", "these", "those",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
        }
    }

    /// Analyze a query string: extract keywords, classify intent, infer
    /// which memory types to search, and generate expanded search terms.
    pub fn analyze(&self, query: &str) -> AnalyzedQuery {
        let keywords = self.extract_keywords(query);
        let intent = self.classify_intent(&keywords);
        let target_types = self.infer_memory_types(intent, &keywords);
        let expansions = self.expand_query(&keywords);
        AnalyzedQuery {
            original: query.to_string(),
            keywords,
            intent,
            target_types,
            expansions,
        }
    }

    /// Extract lowercase keywords, splitting on any character that is
    /// neither alphanumeric nor CJK, and dropping stop words.
    ///
    /// The `len() > 1` check is in bytes, so single CJK characters
    /// (3+ UTF-8 bytes) survive while single ASCII letters are dropped.
    fn extract_keywords(&self, query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split(|c: char| !c.is_alphanumeric() && !is_cjk(c))
            .filter(|s| s.len() > 1) // also excludes empty tokens
            .filter(|s| !self.stop_words.contains(*s))
            .map(|s| s.to_string())
            .collect()
    }

    /// Classify query intent by counting indicator-keyword matches.
    ///
    /// Ties are broken by the fixed priority Preference > Knowledge >
    /// Experience > Code, because `sort_by` is stable and that is the
    /// array order. Returns `General` when nothing matched.
    fn classify_intent(&self, keywords: &[String]) -> QueryIntent {
        let mut scores = [
            (QueryIntent::Preference, 0),
            (QueryIntent::Knowledge, 0),
            (QueryIntent::Experience, 0),
            (QueryIntent::Code, 0),
        ];
        for keyword in keywords {
            if self.preference_indicators.contains(keyword) {
                scores[0].1 += 2;
            }
            if self.knowledge_indicators.contains(keyword) {
                scores[1].1 += 2;
            }
            if self.experience_indicators.contains(keyword) {
                scores[2].1 += 2;
            }
            if self.code_indicators.contains(keyword) {
                scores[3].1 += 2;
            }
        }
        // Highest-scoring intent wins.
        scores.sort_by(|a, b| b.1.cmp(&a.1));
        if scores[0].1 > 0 {
            scores[0].0
        } else {
            QueryIntent::General
        }
    }

    /// Map an intent to the memory types worth searching, in priority
    /// order. `General` searches everything.
    fn infer_memory_types(&self, intent: QueryIntent, _keywords: &[String]) -> Vec<MemoryType> {
        let mut types = Vec::new();
        match intent {
            QueryIntent::Preference => {
                types.push(MemoryType::Preference);
            }
            QueryIntent::Knowledge | QueryIntent::Code => {
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Experience => {
                types.push(MemoryType::Experience);
                types.push(MemoryType::Knowledge);
            }
            QueryIntent::General => {
                // Search all types
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Configuration => {
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
            }
        }
        types
    }

    /// Expand keywords with naive plural variants and simple synonyms.
    fn expand_query(&self, keywords: &[String]) -> Vec<String> {
        let mut expansions = Vec::new();
        for keyword in keywords {
            // Naive English stemming: toggle a trailing 's'. Fix: only
            // applied to ASCII keywords — pluralizing CJK keywords
            // produced junk expansions like "喜欢s".
            if keyword.is_ascii() {
                if keyword.ends_with('s') && keyword.len() > 3 {
                    expansions.push(keyword[..keyword.len() - 1].to_string());
                } else {
                    expansions.push(format!("{}s", keyword));
                }
            }
            // Add common synonyms (simplified static table).
            if let Some(synonyms) = self.get_synonyms(keyword) {
                expansions.extend(synonyms);
            }
        }
        expansions
    }

    /// Look up hard-coded synonyms for a keyword; `None` for unknown words.
    fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
        let synonyms: &[&str] = match keyword {
            "code" => &["program", "script", "source"],
            "error" => &["bug", "issue", "problem", "exception"],
            "fix" => &["solve", "resolve", "repair", "patch"],
            "fast" => &["quick", "speed", "performance", "efficient"],
            "slow" => &["performance", "optimize", "speed"],
            "help" => &["assist", "support", "guide", "aid"],
            "learn" => &["study", "understand", "know", "grasp"],
            _ => return None,
        };
        Some(synonyms.iter().map(|s| s.to_string()).collect())
    }

    /// Generate deduplicated search query strings from an analyzed query:
    /// the original, a keyword-joined form, and each expansion.
    ///
    /// Note: the final `sort` means the original query is not guaranteed
    /// to be the first element.
    pub fn generate_search_queries(&self, analyzed: &AnalyzedQuery) -> Vec<String> {
        let mut queries = vec![analyzed.original.clone()];
        if !analyzed.keywords.is_empty() {
            queries.push(analyzed.keywords.join(" "));
        }
        for expansion in &analyzed.expansions {
            if !expansion.is_empty() {
                queries.push(expansion.clone());
            }
        }
        // Sort so dedup can remove all duplicates, not just adjacent ones.
        queries.sort();
        queries.dedup();
        queries
    }
}
impl Default for QueryAnalyzer {
fn default() -> Self {
Self::new()
}
}
/// Check whether `c` lies in one of the CJK ideograph Unicode blocks
/// (Unified Ideographs, Extensions A–E, and the Compatibility blocks).
fn is_cjk(c: char) -> bool {
    const CJK_RANGES: [(u32, u32); 8] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0x3400, 0x4DBF),   // Extension A
        (0x20000, 0x2A6DF), // Extension B
        (0x2A700, 0x2B73F), // Extension C
        (0x2B740, 0x2B81F), // Extension D
        (0x2B820, 0x2CEAF), // Extension E
        (0xF900, 0xFAFF),   // CJK Compatibility Ideographs
        (0x2F800, 0x2FA1F), // CJK Compatibility Ideographs Supplement
    ];
    let cp = c as u32;
    CJK_RANGES.iter().any(|&(lo, hi)| (lo..=hi).contains(&cp))
}
#[cfg(test)]
mod tests {
    use super::*;
    // Keywords are lowercased and stop words ("the", "is", "what") dropped.
    #[test]
    fn test_extract_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("What is the Rust programming language?");
        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        assert!(!keywords.contains(&"the".to_string())); // stop word
    }
    // "prefer" is a preference indicator, so intent and target types follow.
    #[test]
    fn test_classify_intent_preference() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("I prefer concise responses");
        assert_eq!(analyzed.intent, QueryIntent::Preference);
        assert!(analyzed.target_types.contains(&MemoryType::Preference));
    }
    // "explain"/"how" are knowledge indicators.
    #[test]
    fn test_classify_intent_knowledge() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Explain how async/await works in Rust");
        assert_eq!(analyzed.intent, QueryIntent::Knowledge);
    }
    // "fix", "error", and "function" are all code indicators.
    #[test]
    fn test_classify_intent_code() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Fix this error in my function");
        assert_eq!(analyzed.intent, QueryIntent::Code);
    }
    // "fix" and "error" both have synonyms, so expansions are non-empty.
    #[test]
    fn test_query_expansion() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("fix the error");
        assert!(!analyzed.expansions.is_empty());
    }
    // At minimum the original query itself is returned.
    #[test]
    fn test_generate_search_queries() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Rust programming");
        let queries = analyzer.generate_search_queries(&analyzed);
        assert!(queries.len() >= 1);
    }
    // Boundary checks on the CJK character classifier.
    #[test]
    fn test_cjk_detection() {
        assert!(is_cjk('中'));
        assert!(is_cjk('文'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('1'));
    }
    // Single CJK characters pass the byte-length filter and are extracted.
    #[test]
    fn test_chinese_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("我喜欢简洁的回复");
        // Chinese characters should be extracted
        assert!(!keywords.is_empty());
    }
}

View File

@@ -0,0 +1,374 @@
//! Semantic Similarity Scorer
//!
//! Provides TF-IDF based semantic similarity computation for memory retrieval.
//! This is a lightweight, dependency-free implementation suitable for
//! medium-scale memory systems.
use std::collections::{HashMap, HashSet};
use crate::types::MemoryEntry;
/// Semantic similarity scorer using TF-IDF
///
/// Maintains a small in-memory inverted-statistics index: per-term
/// document frequencies for IDF, plus one precomputed TF-IDF vector per
/// indexed entry (keyed by URI).
pub struct SemanticScorer {
    /// Document frequency for IDF computation (term -> number of docs)
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents indexed so far
    total_documents: usize,
    /// Precomputed TF-IDF vectors for entries, keyed by entry URI
    entry_vectors: HashMap<String, HashMap<String, f32>>,
    /// Stop words to ignore during tokenization
    stop_words: HashSet<String>,
}
impl SemanticScorer {
    /// Create a new, empty semantic scorer.
    pub fn new() -> Self {
        Self {
            document_frequencies: HashMap::new(),
            total_documents: 0,
            entry_vectors: HashMap::new(),
            stop_words: Self::default_stop_words(),
        }
    }
    /// Get default stop words
    ///
    /// English function words, apostrophe-free contraction forms (since
    /// `tokenize` splits on non-alphanumerics, "you're" becomes
    /// "you"+"re"-style fragments), and single letters. Duplicate
    /// literals in the array are harmless: the HashSet deduplicates.
    fn default_stop_words() -> HashSet<String> {
        [
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
            "have", "has", "had", "do", "does", "did", "will", "would", "could",
            "should", "may", "might", "must", "shall", "can", "need", "dare",
            "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
            "from", "as", "into", "through", "during", "before", "after",
            "above", "below", "between", "under", "again", "further", "then",
            "once", "here", "there", "when", "where", "why", "how", "all",
            "each", "few", "more", "most", "other", "some", "such", "no", "nor",
            "not", "only", "own", "same", "so", "than", "too", "very", "just",
            "and", "but", "if", "or", "because", "until", "while", "although",
            "though", "after", "before", "when", "whenever", "i", "you", "he",
            "she", "it", "we", "they", "what", "which", "who", "whom", "this",
            "that", "these", "those", "am", "im", "youre", "hes", "shes",
            "its", "were", "theyre", "ive", "youve", "weve", "theyve", "id",
            "youd", "hed", "shed", "wed", "theyd", "ill", "youll", "hell",
            "shell", "well", "theyll", "isnt", "arent", "wasnt", "werent",
            "hasnt", "havent", "hadnt", "doesnt", "dont", "didnt", "wont",
            "wouldnt", "shant", "shouldnt", "cant", "cannot", "couldnt",
            "mustnt", "lets", "thats", "whos", "whats", "heres", "theres",
            "whens", "wheres", "whys", "hows", "a", "b", "c", "d", "e", "f",
            "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s",
            "t", "u", "v", "w", "x", "y", "z",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }
    /// Tokenize text into lowercase words, splitting on any
    /// non-alphanumeric character.
    ///
    /// The `len() > 1` filter is in bytes, so multi-byte characters
    /// (e.g. CJK, which `is_alphanumeric` treats as alphanumeric) are
    /// kept even when they are a single character.
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| !s.is_empty() && s.len() > 1)
            .map(|s| s.to_string())
            .collect()
    }
    /// Remove stop words from tokens, preserving order and duplicates.
    fn remove_stop_words(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|t| !self.stop_words.contains(*t))
            .cloned()
            .collect()
    }
    /// Compute normalized term frequency (count / total tokens).
    ///
    /// Safe for empty input: the map stays empty, so the normalization
    /// loop never divides by the zero total.
    fn compute_tf(tokens: &[String]) -> HashMap<String, f32> {
        let mut tf = HashMap::new();
        let total = tokens.len() as f32;
        for token in tokens {
            *tf.entry(token.clone()).or_insert(0.0) += 1.0;
        }
        // Normalize by total tokens
        for count in tf.values_mut() {
            *count /= total;
        }
        tf
    }
    /// Compute smoothed IDF for a term: ln((N+1)/(df+1)) + 1.
    /// Returns 0.0 for unseen terms or an empty corpus.
    fn compute_idf(&self, term: &str) -> f32 {
        let df = self.document_frequencies.get(term).copied().unwrap_or(0);
        if df == 0 || self.total_documents == 0 {
            return 0.0;
        }
        ((self.total_documents as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
    }
    /// Index an entry for semantic search.
    ///
    /// Tokenizes content plus keywords, updates corpus statistics, and
    /// stores a TF-IDF vector keyed by the entry URI. IDF values are
    /// snapshotted at index time: vectors of previously indexed entries
    /// are NOT recomputed as the corpus grows, so scores drift slightly
    /// over time. Re-indexing the same URI overwrites the vector but
    /// still increments the document counters.
    pub fn index_entry(&mut self, entry: &MemoryEntry) {
        // Tokenize content and keywords
        let mut all_tokens = Self::tokenize(&entry.content);
        for keyword in &entry.keywords {
            all_tokens.extend(Self::tokenize(keyword));
        }
        all_tokens = self.remove_stop_words(&all_tokens);
        // Update document frequencies
        let unique_terms: HashSet<_> = all_tokens.iter().cloned().collect();
        for term in &unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
        }
        self.total_documents += 1;
        // Compute TF-IDF vector
        let tf = Self::compute_tf(&all_tokens);
        let mut tfidf = HashMap::new();
        for (term, tf_val) in tf {
            let idf = self.compute_idf(&term);
            tfidf.insert(term, tf_val * idf);
        }
        self.entry_vectors.insert(entry.uri.clone(), tfidf);
    }
    /// Remove an entry's vector from the index.
    ///
    /// NOTE(review): `document_frequencies` and `total_documents` are not
    /// decremented here, so IDF statistics keep reflecting removed
    /// entries until `clear()` — confirm this staleness is acceptable.
    pub fn remove_entry(&mut self, uri: &str) {
        self.entry_vectors.remove(uri);
    }
    /// Compute cosine similarity between two sparse vectors, clamped to
    /// [0.0, 1.0]. Empty vectors and zero norms yield 0.0.
    fn cosine_similarity(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
        if v1.is_empty() || v2.is_empty() {
            return 0.0;
        }
        // Dot product over v1's keys; only terms present in both
        // contribute. Norms are accumulated over each full vector.
        let mut dot_product = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;
        for (k, v) in v1 {
            norm1 += v * v;
            if let Some(v2_val) = v2.get(k) {
                dot_product += v * v2_val;
            }
        }
        for v in v2.values() {
            norm2 += v * v;
        }
        let denom = (norm1 * norm2).sqrt();
        if denom == 0.0 {
            0.0
        } else {
            (dot_product / denom).clamp(0.0, 1.0)
        }
    }
    /// Score similarity between a query string and an entry.
    ///
    /// Builds a TF-IDF vector for the query and blends cosine similarity
    /// against the entry's indexed vector (weight 0.7) with a keyword
    /// match score (weight 0.3). Returns a neutral 0.5 for queries that
    /// are all stop words, and falls back to substring matching for
    /// entries that were never indexed.
    pub fn score_similarity(&self, query: &str, entry: &MemoryEntry) -> f32 {
        // Tokenize query
        let query_tokens = self.remove_stop_words(&Self::tokenize(query));
        if query_tokens.is_empty() {
            return 0.5; // Neutral score for empty query
        }
        // Compute query TF-IDF
        let query_tf = Self::compute_tf(&query_tokens);
        let mut query_vec = HashMap::new();
        for (term, tf_val) in query_tf {
            let idf = self.compute_idf(&term);
            query_vec.insert(term, tf_val * idf);
        }
        // Get entry vector
        let entry_vec = match self.entry_vectors.get(&entry.uri) {
            Some(v) => v,
            None => {
                // Fall back to simple matching if not indexed
                return self.fallback_similarity(&query_tokens, entry);
            }
        };
        // Compute cosine similarity
        let cosine = Self::cosine_similarity(&query_vec, entry_vec);
        // Combine with keyword matching for better results
        let keyword_boost = self.keyword_match_score(&query_tokens, entry);
        // Weighted combination
        cosine * 0.7 + keyword_boost * 0.3
    }
    /// Fallback similarity when entry is not indexed.
    ///
    /// Substring matching against the content and keywords; each query
    /// token can score at most 2 (one content match + one keyword
    /// match), hence the `len * 2` denominator.
    fn fallback_similarity(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        let content_lower = entry.content.to_lowercase();
        let mut matches = 0;
        for token in query_tokens {
            if content_lower.contains(token) {
                matches += 1;
            }
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(token) {
                    matches += 1;
                    break;
                }
            }
        }
        (matches as f32) / (query_tokens.len() * 2).max(1) as f32
    }
    /// Fraction of query tokens that appear (as substrings) in at least
    /// one of the entry's keywords; 0.0 when the entry has no keywords.
    fn keyword_match_score(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        if entry.keywords.is_empty() {
            return 0.0;
        }
        let mut matches = 0;
        for token in query_tokens {
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(&token.to_lowercase()) {
                    matches += 1;
                    break;
                }
            }
        }
        (matches as f32) / query_tokens.len().max(1) as f32
    }
    /// Clear the index, including all corpus statistics.
    pub fn clear(&mut self) {
        self.document_frequencies.clear();
        self.total_documents = 0;
        self.entry_vectors.clear();
    }
    /// Get a snapshot of index statistics.
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            total_documents: self.total_documents,
            unique_terms: self.document_frequencies.len(),
            indexed_entries: self.entry_vectors.len(),
        }
    }
}
impl Default for SemanticScorer {
fn default() -> Self {
Self::new()
}
}
/// Index statistics
///
/// Snapshot returned by [`SemanticScorer::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Documents indexed since creation (or the last `clear()`)
    pub total_documents: usize,
    /// Distinct terms observed across all indexed documents
    pub unique_terms: usize,
    /// Entries currently holding a precomputed TF-IDF vector
    pub indexed_entries: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;
    // Tokenization lowercases, splits on punctuation/whitespace, and
    // drops single-byte tokens ("a") — stop words are NOT removed here.
    #[test]
    fn test_tokenize() {
        let tokens = SemanticScorer::tokenize("Hello, World! This is a test.");
        assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
    }
    // "the" is in the default stop-word list and must be filtered.
    #[test]
    fn test_stop_words_removal() {
        let scorer = SemanticScorer::new();
        let tokens = vec!["hello".to_string(), "the".to_string(), "world".to_string()];
        let filtered = scorer.remove_stop_words(&tokens);
        assert_eq!(filtered, vec!["hello", "world"]);
    }
    // TF is count / total tokens: 2/3 and 1/3 here.
    #[test]
    fn test_tf_computation() {
        let tokens = vec!["hello".to_string(), "hello".to_string(), "world".to_string()];
        let tf = SemanticScorer::compute_tf(&tokens);
        let hello_tf = tf.get("hello").unwrap();
        let world_tf = tf.get("world").unwrap();
        // Allow for floating point comparison
        assert!((hello_tf - (2.0 / 3.0)).abs() < 0.001);
        assert!((world_tf - (1.0 / 3.0)).abs() < 0.001);
    }
    // Identical vectors score 1.0; vectors with no shared terms score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let mut v1 = HashMap::new();
        v1.insert("a".to_string(), 1.0);
        v1.insert("b".to_string(), 2.0);
        let mut v2 = HashMap::new();
        v2.insert("a".to_string(), 1.0);
        v2.insert("b".to_string(), 2.0);
        // Identical vectors should have similarity 1.0
        let sim = SemanticScorer::cosine_similarity(&v1, &v2);
        assert!((sim - 1.0).abs() < 0.001);
        // Orthogonal vectors should have similarity 0.0
        let mut v3 = HashMap::new();
        v3.insert("c".to_string(), 1.0);
        let sim2 = SemanticScorer::cosine_similarity(&v1, &v3);
        assert!((sim2 - 0.0).abs() < 0.001);
    }
    // End-to-end relevance check: a Rust-specific query must rank the
    // Rust entry above the Python entry.
    #[test]
    fn test_index_and_score() {
        let mut scorer = SemanticScorer::new();
        let entry1 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety and performance".to_string(),
        ).with_keywords(vec!["rust".to_string(), "programming".to_string(), "safety".to_string()]);
        let entry2 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        ).with_keywords(vec!["python".to_string(), "programming".to_string()]);
        scorer.index_entry(&entry1);
        scorer.index_entry(&entry2);
        // Query for Rust should score higher on entry1
        let score1 = scorer.score_similarity("rust safety", &entry1);
        let score2 = scorer.score_similarity("rust safety", &entry2);
        assert!(score1 > score2, "Rust query should score higher on Rust entry");
    }
    // Counters reflect a single indexed entry with at least one term.
    #[test]
    fn test_stats() {
        let mut scorer = SemanticScorer::new();
        let entry = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "test",
            "Hello world".to_string(),
        );
        scorer.index_entry(&entry);
        let stats = scorer.stats();
        assert_eq!(stats.total_documents, 1);
        assert_eq!(stats.indexed_entries, 1);
        assert!(stats.unique_terms > 0);
    }
}
}