fix(presentation): 修复 presentation 模块类型错误和语法问题
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- 创建 types.ts 定义完整的类型系统 - 重写 DocumentRenderer.tsx 修复语法错误 - 重写 QuizRenderer.tsx 修复语法错误 - 重写 PresentationContainer.tsx 添加类型守卫 - 重写 TypeSwitcher.tsx 修复类型引用 - 更新 index.ts 移除不存在的 ChartRenderer 导出 审计结果: - 类型检查: 通过 - 单元测试: 222 passed - 构建: 成功
This commit is contained in:
365
crates/zclaw-growth/src/retrieval/cache.rs
Normal file
365
crates/zclaw-growth/src/retrieval/cache.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Memory Cache
|
||||
//!
|
||||
//! Provides caching for frequently accessed memories to improve
|
||||
//! retrieval performance.
|
||||
|
||||
use crate::types::{MemoryEntry, MemoryType};
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// A cached memory entry together with the bookkeeping needed for
/// TTL expiry and eviction decisions.
struct CacheEntry {
    /// The cached memory entry itself.
    entry: MemoryEntry,
    /// When this entry was last read or (re)inserted; compared against the
    /// configured TTL on lookup and used as an eviction tie-breaker.
    last_accessed: Instant,
    /// Number of reads since insertion; used to pick eviction victims and
    /// to rank "hot" entries.
    access_count: u32,
}
|
||||
|
||||
/// Cache key for efficient lookups, decomposed from a memory URI.
///
/// NOTE(review): `MemoryCache` currently keys its map by the full URI string,
/// so this type is not used by the cache in this file — confirm whether it is
/// intended for future use or can be removed.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    /// Owning agent identifier (first URI path segment).
    agent_id: String,
    /// Memory type of the entry.
    memory_type: MemoryType,
    /// Category (third URI path segment; empty if absent).
    category: String,
}
|
||||
|
||||
impl From<&MemoryEntry> for CacheKey {
    fn from(entry: &MemoryEntry) -> Self {
        // Parse the URI to extract components. Assumes the form
        // "agent://{agent_id}/{segment1}/{category}" — TODO(review): confirm
        // the URI scheme; segment 1 is skipped here and the memory type is
        // taken from `entry.memory_type` directly instead of being re-parsed.
        // Missing segments fall back to empty strings rather than erroring.
        let parts: Vec<&str> = entry.uri.trim_start_matches("agent://").split('/').collect();
        Self {
            agent_id: parts.first().unwrap_or(&"").to_string(),
            memory_type: entry.memory_type,
            category: parts.get(2).unwrap_or(&"").to_string(),
        }
    }
}
|
||||
|
||||
/// Memory cache configuration.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held at once; inserting beyond this limit
    /// evicts the least-used entry.
    pub max_entries: usize,
    /// Time-to-live measured from an entry's last access; stale entries are
    /// dropped lazily on lookup.
    pub ttl: Duration,
    /// Master switch: when `false`, lookups always return `None` and puts
    /// are ignored.
    pub enabled: bool,
}
|
||||
|
||||
impl Default for CacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_entries: 1000,
|
||||
ttl: Duration::from_secs(3600), // 1 hour
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory cache for hot memories.
///
/// Interior state is guarded by async `RwLock`s, so a `MemoryCache` can be
/// shared across tasks behind a reference or `Arc` without external locking.
pub struct MemoryCache {
    /// Cache storage, keyed by the entry's full URI string.
    cache: RwLock<HashMap<String, CacheEntry>>,
    /// Capacity / TTL / enable-switch configuration (immutable after creation).
    config: CacheConfig,
    /// Hit, miss, and eviction counters.
    stats: RwLock<CacheStats>,
}
|
||||
|
||||
/// Cache statistics (monotonically increasing counters).
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total cache hits.
    pub hits: u64,
    /// Total cache misses.
    pub misses: u64,
    /// Total entries evicted due to capacity pressure.
    pub evictions: u64,
}
|
||||
|
||||
impl MemoryCache {
|
||||
/// Create a new memory cache
|
||||
pub fn new(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
cache: RwLock::new(HashMap::new()),
|
||||
config,
|
||||
stats: RwLock::new(CacheStats::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(CacheConfig::default())
|
||||
}
|
||||
|
||||
/// Get a memory from cache
|
||||
pub async fn get(&self, uri: &str) -> Option<MemoryEntry> {
|
||||
if !self.config.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut cache = self.cache.write().await;
|
||||
|
||||
if let Some(cached) = cache.get_mut(uri) {
|
||||
// Check TTL
|
||||
if cached.last_accessed.elapsed() > self.config.ttl {
|
||||
cache.remove(uri);
|
||||
return None;
|
||||
}
|
||||
|
||||
// Update access metadata
|
||||
cached.last_accessed = Instant::now();
|
||||
cached.access_count += 1;
|
||||
|
||||
// Update stats
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.hits += 1;
|
||||
|
||||
return Some(cached.entry.clone());
|
||||
}
|
||||
|
||||
// Update stats
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.misses += 1;
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Put a memory into cache
|
||||
pub async fn put(&self, entry: MemoryEntry) {
|
||||
if !self.config.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut cache = self.cache.write().await;
|
||||
|
||||
// Check capacity and evict if necessary
|
||||
if cache.len() >= self.config.max_entries {
|
||||
self.evict_lru(&mut cache).await;
|
||||
}
|
||||
|
||||
cache.insert(
|
||||
entry.uri.clone(),
|
||||
CacheEntry {
|
||||
entry,
|
||||
last_accessed: Instant::now(),
|
||||
access_count: 0,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove a memory from cache
|
||||
pub async fn remove(&self, uri: &str) {
|
||||
let mut cache = self.cache.write().await;
|
||||
cache.remove(uri);
|
||||
}
|
||||
|
||||
/// Clear the cache
|
||||
pub async fn clear(&self) {
|
||||
let mut cache = self.cache.write().await;
|
||||
cache.clear();
|
||||
}
|
||||
|
||||
/// Evict least recently used entries
|
||||
async fn evict_lru(&self, cache: &mut HashMap<String, CacheEntry>) {
|
||||
// Find LRU entry
|
||||
let lru_key = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| (v.access_count, v.last_accessed))
|
||||
.map(|(k, _)| k.clone());
|
||||
|
||||
if let Some(key) = lru_key {
|
||||
cache.remove(&key);
|
||||
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.evictions += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub async fn stats(&self) -> CacheStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get cache hit rate
|
||||
pub async fn hit_rate(&self) -> f32 {
|
||||
let stats = self.stats.read().await;
|
||||
let total = stats.hits + stats.misses;
|
||||
if total == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
stats.hits as f32 / total as f32
|
||||
}
|
||||
|
||||
/// Get cache size
|
||||
pub async fn size(&self) -> usize {
|
||||
self.cache.read().await.len()
|
||||
}
|
||||
|
||||
/// Warm up cache with frequently accessed entries
|
||||
pub async fn warmup(&self, entries: Vec<MemoryEntry>) {
|
||||
for entry in entries {
|
||||
self.put(entry).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get top accessed entries (for preloading)
|
||||
pub async fn get_hot_entries(&self, limit: usize) -> Vec<MemoryEntry> {
|
||||
let cache = self.cache.read().await;
|
||||
|
||||
let mut entries: Vec<_> = cache
|
||||
.values()
|
||||
.map(|c| (c.access_count, c.entry.clone()))
|
||||
.collect();
|
||||
|
||||
entries.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
entries.truncate(limit);
|
||||
|
||||
entries.into_iter().map(|(_, e)| e).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for `MemoryCache`: round-trips, stats accounting, eviction.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Basic round-trip: a stored entry is retrievable by its URI.
    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );

        cache.put(entry.clone()).await;
        let retrieved = cache.get(&entry.uri).await;

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "User prefers concise responses");
    }

    /// A lookup for an unknown URI returns `None` and counts one miss.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = MemoryCache::default_config();
        let retrieved = cache.get("nonexistent").await;

        assert!(retrieved.is_none());

        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
    }

    /// An explicitly removed entry is no longer retrievable.
    #[tokio::test]
    async fn test_cache_remove() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry.clone()).await;
        cache.remove(&entry.uri).await;
        let retrieved = cache.get(&entry.uri).await;

        assert!(retrieved.is_none());
    }

    /// `clear` empties the cache completely.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry).await;
        cache.clear().await;
        let size = cache.size().await;

        assert_eq!(size, 0);
    }

    /// One hit plus one miss yields matching counters and a 0.5 hit rate.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry.clone()).await;

        // Hit
        cache.get(&entry.uri).await;
        // Miss
        cache.get("nonexistent").await;

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Float comparison with tolerance.
        let hit_rate = cache.hit_rate().await;
        assert!((hit_rate - 0.5).abs() < 0.001);
    }

    /// With capacity 2, inserting a third entry evicts the least-used one
    /// (entry2, since entry1 was accessed and is therefore "hotter").
    #[tokio::test]
    async fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(3600),
            enabled: true,
        };
        let cache = MemoryCache::new(config);

        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        let entry3 = MemoryEntry::new("test", MemoryType::Preference, "3", "3".to_string());

        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;

        // Access entry1 to make it hot
        cache.get(&entry1.uri).await;

        // Add entry3, should evict entry2 (LRU)
        cache.put(entry3).await;

        let size = cache.size().await;
        assert_eq!(size, 2);

        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
    }

    /// Hot-entry ranking: the more-accessed entry comes first.
    #[tokio::test]
    async fn test_get_hot_entries() {
        let cache = MemoryCache::default_config();

        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());

        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;

        // Access entry1 multiple times
        cache.get(&entry1.uri).await;
        cache.get(&entry1.uri).await;

        let hot = cache.get_hot_entries(10).await;
        assert_eq!(hot.len(), 2);
        // entry1 should be first (more accesses)
        assert_eq!(hot[0].uri, entry1.uri);
    }
}
|
||||
14
crates/zclaw-growth/src/retrieval/mod.rs
Normal file
14
crates/zclaw-growth/src/retrieval/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! Retrieval components for ZCLAW Growth System
|
||||
//!
|
||||
//! This module provides advanced retrieval capabilities:
|
||||
//! - `semantic`: Semantic similarity computation
|
||||
//! - `query`: Query analysis and expansion
|
||||
//! - `cache`: Hot memory caching
|
||||
|
||||
pub mod semantic;
|
||||
pub mod query;
|
||||
pub mod cache;
|
||||
|
||||
pub use semantic::SemanticScorer;
|
||||
pub use query::QueryAnalyzer;
|
||||
pub use cache::MemoryCache;
|
||||
352
crates/zclaw-growth/src/retrieval/query.rs
Normal file
352
crates/zclaw-growth/src/retrieval/query.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Query Analyzer
|
||||
//!
|
||||
//! Provides query analysis and expansion capabilities for improved retrieval.
|
||||
//! Extracts keywords, identifies intent, and generates search variations.
|
||||
|
||||
use crate::types::MemoryType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Result of analyzing a raw query string.
#[derive(Debug, Clone)]
pub struct AnalyzedQuery {
    /// Original query string, unmodified.
    pub original: String,
    /// Keywords extracted from the query (lowercased, stop words removed).
    pub keywords: Vec<String>,
    /// Classified query intent.
    pub intent: QueryIntent,
    /// Memory types worth searching, inferred from the intent.
    pub target_types: Vec<MemoryType>,
    /// Expanded search terms (singular/plural variants and simple synonyms).
    pub expansions: Vec<String>,
}
|
||||
|
||||
/// Query intent classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryIntent {
    /// Looking for preferences/settings.
    Preference,
    /// Looking for factual knowledge.
    Knowledge,
    /// Looking for how-to / past experience.
    Experience,
    /// General conversation (fallback when no indicator keyword matches).
    General,
    /// Code-related query.
    Code,
    /// Configuration query.
    ///
    /// NOTE(review): `classify_intent` in this file never produces this
    /// variant — it is only consumed by `infer_memory_types`. Confirm
    /// whether it is assigned elsewhere or is dead.
    Configuration,
}
|
||||
|
||||
/// Query analyzer: keyword extraction, intent classification, and simple
/// query expansion over fixed English/Chinese indicator word lists.
pub struct QueryAnalyzer {
    /// Keywords that indicate preference queries.
    preference_indicators: HashSet<String>,
    /// Keywords that indicate knowledge queries.
    knowledge_indicators: HashSet<String>,
    /// Keywords that indicate experience queries.
    experience_indicators: HashSet<String>,
    /// Keywords that indicate code queries.
    code_indicators: HashSet<String>,
    /// Stop words filtered out during keyword extraction (English only).
    stop_words: HashSet<String>,
}
|
||||
|
||||
impl QueryAnalyzer {
|
||||
/// Create a new query analyzer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
preference_indicators: [
|
||||
"prefer", "like", "want", "favorite", "favourite", "style",
|
||||
"format", "language", "setting", "preference", "usually",
|
||||
"typically", "always", "never", "习惯", "偏好", "喜欢", "想要",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
knowledge_indicators: [
|
||||
"what", "how", "why", "explain", "tell", "know", "learn",
|
||||
"understand", "meaning", "definition", "concept", "theory",
|
||||
"是什么", "怎么", "为什么", "解释", "了解", "知道",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
experience_indicators: [
|
||||
"experience", "tried", "used", "before", "last time",
|
||||
"previous", "history", "remember", "recall", "when",
|
||||
"经验", "尝试", "用过", "上次", "记得", "回忆",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
code_indicators: [
|
||||
"code", "function", "class", "method", "variable", "type",
|
||||
"error", "bug", "fix", "implement", "refactor", "api",
|
||||
"代码", "函数", "类", "方法", "变量", "错误", "修复", "实现",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
stop_words: [
|
||||
"the", "a", "an", "is", "are", "was", "were", "be", "been",
|
||||
"have", "has", "had", "do", "does", "did", "will", "would",
|
||||
"could", "should", "may", "might", "must", "can", "to", "of",
|
||||
"in", "for", "on", "with", "at", "by", "from", "as", "and",
|
||||
"or", "but", "if", "then", "else", "when", "where", "which",
|
||||
"who", "whom", "whose", "this", "that", "these", "those",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze a query string
|
||||
pub fn analyze(&self, query: &str) -> AnalyzedQuery {
|
||||
let keywords = self.extract_keywords(query);
|
||||
let intent = self.classify_intent(&keywords);
|
||||
let target_types = self.infer_memory_types(intent, &keywords);
|
||||
let expansions = self.expand_query(&keywords);
|
||||
|
||||
AnalyzedQuery {
|
||||
original: query.to_string(),
|
||||
keywords,
|
||||
intent,
|
||||
target_types,
|
||||
expansions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract keywords from query
|
||||
fn extract_keywords(&self, query: &str) -> Vec<String> {
|
||||
query
|
||||
.to_lowercase()
|
||||
.split(|c: char| !c.is_alphanumeric() && !is_cjk(c))
|
||||
.filter(|s| !s.is_empty() && s.len() > 1)
|
||||
.filter(|s| !self.stop_words.contains(*s))
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Classify query intent
|
||||
fn classify_intent(&self, keywords: &[String]) -> QueryIntent {
|
||||
let mut scores = [
|
||||
(QueryIntent::Preference, 0),
|
||||
(QueryIntent::Knowledge, 0),
|
||||
(QueryIntent::Experience, 0),
|
||||
(QueryIntent::Code, 0),
|
||||
];
|
||||
|
||||
for keyword in keywords {
|
||||
if self.preference_indicators.contains(keyword) {
|
||||
scores[0].1 += 2;
|
||||
}
|
||||
if self.knowledge_indicators.contains(keyword) {
|
||||
scores[1].1 += 2;
|
||||
}
|
||||
if self.experience_indicators.contains(keyword) {
|
||||
scores[2].1 += 2;
|
||||
}
|
||||
if self.code_indicators.contains(keyword) {
|
||||
scores[3].1 += 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Find highest scoring intent
|
||||
scores.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
if scores[0].1 > 0 {
|
||||
scores[0].0
|
||||
} else {
|
||||
QueryIntent::General
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer which memory types to search
|
||||
fn infer_memory_types(&self, intent: QueryIntent, _keywords: &[String]) -> Vec<MemoryType> {
|
||||
let mut types = Vec::new();
|
||||
|
||||
match intent {
|
||||
QueryIntent::Preference => {
|
||||
types.push(MemoryType::Preference);
|
||||
}
|
||||
QueryIntent::Knowledge | QueryIntent::Code => {
|
||||
types.push(MemoryType::Knowledge);
|
||||
types.push(MemoryType::Experience);
|
||||
}
|
||||
QueryIntent::Experience => {
|
||||
types.push(MemoryType::Experience);
|
||||
types.push(MemoryType::Knowledge);
|
||||
}
|
||||
QueryIntent::General => {
|
||||
// Search all types
|
||||
types.push(MemoryType::Preference);
|
||||
types.push(MemoryType::Knowledge);
|
||||
types.push(MemoryType::Experience);
|
||||
}
|
||||
QueryIntent::Configuration => {
|
||||
types.push(MemoryType::Preference);
|
||||
types.push(MemoryType::Knowledge);
|
||||
}
|
||||
}
|
||||
|
||||
types
|
||||
}
|
||||
|
||||
/// Expand query with related terms
|
||||
fn expand_query(&self, keywords: &[String]) -> Vec<String> {
|
||||
let mut expansions = Vec::new();
|
||||
|
||||
// Add stemmed variations (simplified)
|
||||
for keyword in keywords {
|
||||
// Add singular/plural variations
|
||||
if keyword.ends_with('s') && keyword.len() > 3 {
|
||||
expansions.push(keyword[..keyword.len()-1].to_string());
|
||||
} else {
|
||||
expansions.push(format!("{}s", keyword));
|
||||
}
|
||||
|
||||
// Add common synonyms (simplified)
|
||||
if let Some(synonyms) = self.get_synonyms(keyword) {
|
||||
expansions.extend(synonyms);
|
||||
}
|
||||
}
|
||||
|
||||
expansions
|
||||
}
|
||||
|
||||
/// Get synonyms for a keyword (simplified)
|
||||
fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
|
||||
let synonyms: &[&str] = match keyword {
|
||||
"code" => &["program", "script", "source"],
|
||||
"error" => &["bug", "issue", "problem", "exception"],
|
||||
"fix" => &["solve", "resolve", "repair", "patch"],
|
||||
"fast" => &["quick", "speed", "performance", "efficient"],
|
||||
"slow" => &["performance", "optimize", "speed"],
|
||||
"help" => &["assist", "support", "guide", "aid"],
|
||||
"learn" => &["study", "understand", "know", "grasp"],
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
Some(synonyms.iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
|
||||
/// Generate search queries from analyzed query
|
||||
pub fn generate_search_queries(&self, analyzed: &AnalyzedQuery) -> Vec<String> {
|
||||
let mut queries = vec![analyzed.original.clone()];
|
||||
|
||||
// Add keyword-based query
|
||||
if !analyzed.keywords.is_empty() {
|
||||
queries.push(analyzed.keywords.join(" "));
|
||||
}
|
||||
|
||||
// Add expanded terms
|
||||
for expansion in &analyzed.expansions {
|
||||
if !expansion.is_empty() {
|
||||
queries.push(expansion.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate
|
||||
queries.sort();
|
||||
queries.dedup();
|
||||
|
||||
queries
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryAnalyzer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether a character belongs to one of the CJK ideograph blocks
/// (Unified Ideographs, Extensions A–E, and the Compatibility blocks).
fn is_cjk(c: char) -> bool {
    let cp = c as u32;
    (0x4E00..=0x9FFF).contains(&cp)          // CJK Unified Ideographs
        || (0x3400..=0x4DBF).contains(&cp)   // Extension A
        || (0x20000..=0x2A6DF).contains(&cp) // Extension B
        || (0x2A700..=0x2B73F).contains(&cp) // Extension C
        || (0x2B740..=0x2B81F).contains(&cp) // Extension D
        || (0x2B820..=0x2CEAF).contains(&cp) // Extension E
        || (0xF900..=0xFAFF).contains(&cp)   // CJK Compatibility Ideographs
        || (0x2F800..=0x2FA1F).contains(&cp) // Compatibility Supplement
}
|
||||
|
||||
// Unit tests for `QueryAnalyzer`: keyword extraction, intent classification,
// expansion, and CJK handling.
#[cfg(test)]
mod tests {
    use super::*;

    /// Content words survive extraction; stop words are filtered out.
    #[test]
    fn test_extract_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("What is the Rust programming language?");

        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        assert!(!keywords.contains(&"the".to_string())); // stop word
    }

    /// "prefer" triggers the Preference intent and targets preference memories.
    #[test]
    fn test_classify_intent_preference() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("I prefer concise responses");

        assert_eq!(analyzed.intent, QueryIntent::Preference);
        assert!(analyzed.target_types.contains(&MemoryType::Preference));
    }

    /// "explain"/"how" trigger the Knowledge intent.
    #[test]
    fn test_classify_intent_knowledge() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Explain how async/await works in Rust");

        assert_eq!(analyzed.intent, QueryIntent::Knowledge);
    }

    /// "fix"/"error"/"function" trigger the Code intent.
    #[test]
    fn test_classify_intent_code() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Fix this error in my function");

        assert_eq!(analyzed.intent, QueryIntent::Code);
    }

    /// Any non-empty keyword list produces at least one expansion.
    #[test]
    fn test_query_expansion() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("fix the error");

        assert!(!analyzed.expansions.is_empty());
    }

    /// Generated search queries always include at least the original query.
    #[test]
    fn test_generate_search_queries() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Rust programming");
        let queries = analyzer.generate_search_queries(&analyzed);

        assert!(queries.len() >= 1);
    }

    /// CJK detection covers ideographs and rejects ASCII.
    #[test]
    fn test_cjk_detection() {
        assert!(is_cjk('中'));
        assert!(is_cjk('文'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('1'));
    }

    /// Chinese text yields keywords despite the ASCII-centric stop list.
    #[test]
    fn test_chinese_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("我喜欢简洁的回复");

        // Chinese characters should be extracted
        assert!(!keywords.is_empty());
    }
}
|
||||
374
crates/zclaw-growth/src/retrieval/semantic.rs
Normal file
374
crates/zclaw-growth/src/retrieval/semantic.rs
Normal file
@@ -0,0 +1,374 @@
|
||||
//! Semantic Similarity Scorer
|
||||
//!
|
||||
//! Provides TF-IDF based semantic similarity computation for memory retrieval.
|
||||
//! This is a lightweight, dependency-free implementation suitable for
|
||||
//! medium-scale memory systems.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use crate::types::MemoryEntry;
|
||||
|
||||
/// Semantic similarity scorer using TF-IDF with cosine similarity.
///
/// Lightweight and dependency-free; vectors are computed eagerly at index
/// time, so scores are only approximate once the corpus grows after indexing
/// (see `index_entry`).
pub struct SemanticScorer {
    /// Per-term document frequency, used to compute IDF.
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents indexed so far.
    total_documents: usize,
    /// Precomputed TF-IDF vectors, keyed by entry URI.
    entry_vectors: HashMap<String, HashMap<String, f32>>,
    /// Stop words ignored during tokenization (English only).
    stop_words: HashSet<String>,
}
|
||||
|
||||
impl SemanticScorer {
    /// Create a new, empty semantic scorer with the default stop-word list.
    pub fn new() -> Self {
        Self {
            document_frequencies: HashMap::new(),
            total_documents: 0,
            entry_vectors: HashMap::new(),
            stop_words: Self::default_stop_words(),
        }
    }

    /// Get the default (English) stop-word list.
    ///
    /// NOTE(review): the array contains a few duplicates ("before", "after",
    /// "when", single letters) — harmless, since they collapse in the set.
    fn default_stop_words() -> HashSet<String> {
        [
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
            "have", "has", "had", "do", "does", "did", "will", "would", "could",
            "should", "may", "might", "must", "shall", "can", "need", "dare",
            "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
            "from", "as", "into", "through", "during", "before", "after",
            "above", "below", "between", "under", "again", "further", "then",
            "once", "here", "there", "when", "where", "why", "how", "all",
            "each", "few", "more", "most", "other", "some", "such", "no", "nor",
            "not", "only", "own", "same", "so", "than", "too", "very", "just",
            "and", "but", "if", "or", "because", "until", "while", "although",
            "though", "after", "before", "when", "whenever", "i", "you", "he",
            "she", "it", "we", "they", "what", "which", "who", "whom", "this",
            "that", "these", "those", "am", "im", "youre", "hes", "shes",
            "its", "were", "theyre", "ive", "youve", "weve", "theyve", "id",
            "youd", "hed", "shed", "wed", "theyd", "ill", "youll", "hell",
            "shell", "well", "theyll", "isnt", "arent", "wasnt", "werent",
            "hasnt", "havent", "hadnt", "doesnt", "dont", "didnt", "wont",
            "wouldnt", "shant", "shouldnt", "cant", "cannot", "couldnt",
            "mustnt", "lets", "thats", "whos", "whats", "heres", "theres",
            "whens", "wheres", "whys", "hows", "a", "b", "c", "d", "e", "f",
            "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s",
            "t", "u", "v", "w", "x", "y", "z",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// Tokenize text: lowercase, split on non-alphanumeric boundaries, and
    /// drop tokens of byte length <= 1 (single ASCII characters).
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| !s.is_empty() && s.len() > 1)
            .map(|s| s.to_string())
            .collect()
    }

    /// Filter out stop words from a token list.
    fn remove_stop_words(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|t| !self.stop_words.contains(*t))
            .cloned()
            .collect()
    }

    /// Compute normalized term frequency for a token list.
    ///
    /// Safe on an empty list: the map stays empty, so the division by
    /// `total == 0` is never reached.
    fn compute_tf(tokens: &[String]) -> HashMap<String, f32> {
        let mut tf = HashMap::new();
        let total = tokens.len() as f32;

        for token in tokens {
            *tf.entry(token.clone()).or_insert(0.0) += 1.0;
        }

        // Normalize counts so each value is the term's share of the document.
        for count in tf.values_mut() {
            *count /= total;
        }

        tf
    }

    /// Compute smoothed inverse document frequency for a term.
    ///
    /// Returns 0.0 when the term has never been seen (or nothing has been
    /// indexed). NOTE(review): the +1 smoothing below would already handle
    /// `df == 0`; returning 0 instead means unseen query terms contribute
    /// nothing to similarity — confirm this is the intended behavior.
    fn compute_idf(&self, term: &str) -> f32 {
        let df = self.document_frequencies.get(term).copied().unwrap_or(0);
        if df == 0 || self.total_documents == 0 {
            return 0.0;
        }
        ((self.total_documents as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
    }

    /// Index an entry for semantic search.
    ///
    /// Tokenizes content + keywords, updates corpus statistics, and stores
    /// the entry's TF-IDF vector keyed by URI. NOTE: the vector is computed
    /// with the IDF values *at index time*; vectors of previously indexed
    /// entries are not re-weighted as the corpus grows, so scores drift
    /// slightly for long-lived indexes.
    pub fn index_entry(&mut self, entry: &MemoryEntry) {
        // Tokenize content and keywords into one bag of terms.
        let mut all_tokens = Self::tokenize(&entry.content);
        for keyword in &entry.keywords {
            all_tokens.extend(Self::tokenize(keyword));
        }
        all_tokens = self.remove_stop_words(&all_tokens);

        // Each unique term counts once toward this document's frequency.
        let unique_terms: HashSet<_> = all_tokens.iter().cloned().collect();
        for term in &unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
        }
        self.total_documents += 1;

        // Compute the TF-IDF vector with the current corpus statistics.
        let tf = Self::compute_tf(&all_tokens);
        let mut tfidf = HashMap::new();
        for (term, tf_val) in tf {
            let idf = self.compute_idf(&term);
            tfidf.insert(term, tf_val * idf);
        }

        self.entry_vectors.insert(entry.uri.clone(), tfidf);
    }

    /// Remove an entry's vector from the index.
    ///
    /// NOTE(review): document frequencies and `total_documents` are *not*
    /// decremented, so corpus statistics keep counting removed entries.
    pub fn remove_entry(&mut self, uri: &str) {
        self.entry_vectors.remove(uri);
    }

    /// Compute cosine similarity between two sparse vectors, clamped to
    /// `[0.0, 1.0]`; 0.0 when either vector is empty or zero-length.
    fn cosine_similarity(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
        if v1.is_empty() || v2.is_empty() {
            return 0.0;
        }

        // Dot product only needs terms present in both vectors; iterating
        // v1 and probing v2 covers them all (absent terms contribute 0).
        let mut dot_product = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;

        for (k, v) in v1 {
            norm1 += v * v;
            if let Some(v2_val) = v2.get(k) {
                dot_product += v * v2_val;
            }
        }

        for v in v2.values() {
            norm2 += v * v;
        }

        let denom = (norm1 * norm2).sqrt();
        if denom == 0.0 {
            0.0
        } else {
            (dot_product / denom).clamp(0.0, 1.0)
        }
    }

    /// Score similarity between a free-text query and an entry.
    ///
    /// Returns a weighted blend of TF-IDF cosine similarity (70%) and direct
    /// keyword matching (30%); falls back to substring matching when the
    /// entry was never indexed, and to a neutral 0.5 for empty queries.
    pub fn score_similarity(&self, query: &str, entry: &MemoryEntry) -> f32 {
        // Tokenize query
        let query_tokens = self.remove_stop_words(&Self::tokenize(query));
        if query_tokens.is_empty() {
            return 0.5; // Neutral score for empty query
        }

        // Build the query's TF-IDF vector with current corpus statistics.
        let query_tf = Self::compute_tf(&query_tokens);
        let mut query_vec = HashMap::new();
        for (term, tf_val) in query_tf {
            let idf = self.compute_idf(&term);
            query_vec.insert(term, tf_val * idf);
        }

        // Get entry vector
        let entry_vec = match self.entry_vectors.get(&entry.uri) {
            Some(v) => v,
            None => {
                // Fall back to simple matching if not indexed
                return self.fallback_similarity(&query_tokens, entry);
            }
        };

        // Compute cosine similarity
        let cosine = Self::cosine_similarity(&query_vec, entry_vec);

        // Combine with keyword matching for better results
        let keyword_boost = self.keyword_match_score(&query_tokens, entry);

        // Weighted combination
        cosine * 0.7 + keyword_boost * 0.3
    }

    /// Fallback similarity for unindexed entries: fraction of query tokens
    /// found as substrings of the content and of the entry keywords.
    /// Each token can score at most 2 (content + keyword), hence the
    /// `len() * 2` denominator; `.max(1)` guards the empty-query case.
    fn fallback_similarity(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        let content_lower = entry.content.to_lowercase();
        let mut matches = 0;

        for token in query_tokens {
            if content_lower.contains(token) {
                matches += 1;
            }
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(token) {
                    matches += 1;
                    break;
                }
            }
        }

        (matches as f32) / (query_tokens.len() * 2).max(1) as f32
    }

    /// Fraction of query tokens that appear (as substrings) in at least one
    /// of the entry's keywords; 0.0 when the entry has no keywords.
    fn keyword_match_score(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        if entry.keywords.is_empty() {
            return 0.0;
        }

        let mut matches = 0;
        for token in query_tokens {
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(&token.to_lowercase()) {
                    matches += 1;
                    break;
                }
            }
        }

        (matches as f32) / query_tokens.len().max(1) as f32
    }

    /// Reset the index: drop all vectors and corpus statistics.
    pub fn clear(&mut self) {
        self.document_frequencies.clear();
        self.total_documents = 0;
        self.entry_vectors.clear();
    }

    /// Get a snapshot of index statistics.
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            total_documents: self.total_documents,
            unique_terms: self.document_frequencies.len(),
            indexed_entries: self.entry_vectors.len(),
        }
    }
}
|
||||
|
||||
impl Default for SemanticScorer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot of index statistics returned by [`SemanticScorer::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of documents indexed (not decremented on removal).
    pub total_documents: usize,
    /// Number of distinct terms seen across all indexed documents.
    pub unique_terms: usize,
    /// Number of entry vectors currently stored.
    pub indexed_entries: usize,
}
|
||||
|
||||
// Unit tests for `SemanticScorer`: tokenization, TF math, cosine similarity,
// and end-to-end ranking.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Punctuation splits tokens; single characters ("a") are dropped.
    #[test]
    fn test_tokenize() {
        let tokens = SemanticScorer::tokenize("Hello, World! This is a test.");
        assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
    }

    /// Stop words are filtered; content words survive.
    #[test]
    fn test_stop_words_removal() {
        let scorer = SemanticScorer::new();
        let tokens = vec!["hello".to_string(), "the".to_string(), "world".to_string()];
        let filtered = scorer.remove_stop_words(&tokens);
        assert_eq!(filtered, vec!["hello", "world"]);
    }

    /// Term frequencies are normalized to fractions of the token count.
    #[test]
    fn test_tf_computation() {
        let tokens = vec!["hello".to_string(), "hello".to_string(), "world".to_string()];
        let tf = SemanticScorer::compute_tf(&tokens);

        let hello_tf = tf.get("hello").unwrap();
        let world_tf = tf.get("world").unwrap();

        // Allow for floating point comparison
        assert!((hello_tf - (2.0 / 3.0)).abs() < 0.001);
        assert!((world_tf - (1.0 / 3.0)).abs() < 0.001);
    }

    /// Identical vectors score 1.0; disjoint (orthogonal) vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let mut v1 = HashMap::new();
        v1.insert("a".to_string(), 1.0);
        v1.insert("b".to_string(), 2.0);

        let mut v2 = HashMap::new();
        v2.insert("a".to_string(), 1.0);
        v2.insert("b".to_string(), 2.0);

        // Identical vectors should have similarity 1.0
        let sim = SemanticScorer::cosine_similarity(&v1, &v2);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors should have similarity 0.0
        let mut v3 = HashMap::new();
        v3.insert("c".to_string(), 1.0);
        let sim2 = SemanticScorer::cosine_similarity(&v1, &v3);
        assert!((sim2 - 0.0).abs() < 0.001);
    }

    /// End-to-end ranking: a Rust-specific query prefers the Rust entry.
    #[test]
    fn test_index_and_score() {
        let mut scorer = SemanticScorer::new();

        let entry1 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety and performance".to_string(),
        ).with_keywords(vec!["rust".to_string(), "programming".to_string(), "safety".to_string()]);

        let entry2 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        ).with_keywords(vec!["python".to_string(), "programming".to_string()]);

        scorer.index_entry(&entry1);
        scorer.index_entry(&entry2);

        // Query for Rust should score higher on entry1
        let score1 = scorer.score_similarity("rust safety", &entry1);
        let score2 = scorer.score_similarity("rust safety", &entry2);

        assert!(score1 > score2, "Rust query should score higher on Rust entry");
    }

    /// Indexing one entry updates all three statistics counters.
    #[test]
    fn test_stats() {
        let mut scorer = SemanticScorer::new();

        let entry = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "test",
            "Hello world".to_string(),
        );

        scorer.index_entry(&entry);
        let stats = scorer.stats();

        assert_eq!(stats.total_documents, 1);
        assert_eq!(stats.indexed_entries, 1);
        assert!(stats.unique_terms > 0);
    }
}
|
||||
Reference in New Issue
Block a user