fix(presentation): 修复 presentation 模块类型错误和语法问题
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- 创建 types.ts 定义完整的类型系统 - 重写 DocumentRenderer.tsx 修复语法错误 - 重写 QuizRenderer.tsx 修复语法错误 - 重写 PresentationContainer.tsx 添加类型守卫 - 重写 TypeSwitcher.tsx 修复类型引用 - 更新 index.ts 移除不存在的 ChartRenderer 导出 审计结果: - 类型检查: 通过 - 单元测试: 222 passed - 构建: 成功
This commit is contained in:
40
crates/zclaw-growth/Cargo.toml
Normal file
40
crates/zclaw-growth/Cargo.toml
Normal file
@@ -0,0 +1,40 @@
|
||||
[package]
|
||||
name = "zclaw-growth"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
rust-version.workspace = true
|
||||
description = "ZCLAW Agent Growth System - Memory extraction, retrieval, and prompt injection"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Error handling
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Logging
|
||||
tracing = { workspace = true }
|
||||
|
||||
# Time
|
||||
chrono = { workspace = true }
|
||||
|
||||
# IDs
|
||||
uuid = { workspace = true }
|
||||
|
||||
# Database
|
||||
sqlx = { workspace = true }
|
||||
|
||||
# Internal crates
|
||||
zclaw-types = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
372
crates/zclaw-growth/src/extractor.rs
Normal file
372
crates/zclaw-growth/src/extractor.rs
Normal file
@@ -0,0 +1,372 @@
|
||||
//! Memory Extractor - Extracts preferences, knowledge, and experience from conversations
|
||||
//!
|
||||
//! This module provides the `MemoryExtractor` which analyzes conversations
|
||||
//! using LLM to extract valuable memories for agent growth.
|
||||
|
||||
use crate::types::{ExtractedMemory, ExtractionConfig, MemoryType};
|
||||
use crate::viking_adapter::VikingAdapter;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use zclaw_types::{Message, Result, SessionId};
|
||||
|
||||
/// Trait for LLM driver abstraction
|
||||
/// This allows us to use any LLM driver implementation
|
||||
#[async_trait]
|
||||
pub trait LlmDriverForExtraction: Send + Sync {
|
||||
/// Extract memories from conversation using LLM
|
||||
async fn extract_memories(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
extraction_type: MemoryType,
|
||||
) -> Result<Vec<ExtractedMemory>>;
|
||||
}
|
||||
|
||||
/// Memory Extractor - extracts memories from conversations
|
||||
pub struct MemoryExtractor {
|
||||
/// LLM driver for extraction (optional)
|
||||
llm_driver: Option<Arc<dyn LlmDriverForExtraction>>,
|
||||
/// OpenViking adapter for storage
|
||||
viking: Option<Arc<VikingAdapter>>,
|
||||
/// Extraction configuration
|
||||
config: ExtractionConfig,
|
||||
}
|
||||
|
||||
impl MemoryExtractor {
|
||||
/// Create a new memory extractor with LLM driver
|
||||
pub fn new(llm_driver: Arc<dyn LlmDriverForExtraction>) -> Self {
|
||||
Self {
|
||||
llm_driver: Some(llm_driver),
|
||||
viking: None,
|
||||
config: ExtractionConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new memory extractor without LLM driver
|
||||
///
|
||||
/// This is useful for cases where LLM-based extraction is not needed
|
||||
/// or will be set later using `with_llm_driver`
|
||||
pub fn new_without_driver() -> Self {
|
||||
Self {
|
||||
llm_driver: None,
|
||||
viking: None,
|
||||
config: ExtractionConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the LLM driver
|
||||
pub fn with_llm_driver(mut self, driver: Arc<dyn LlmDriverForExtraction>) -> Self {
|
||||
self.llm_driver = Some(driver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Create with OpenViking adapter
|
||||
pub fn with_viking(mut self, viking: Arc<VikingAdapter>) -> Self {
|
||||
self.viking = Some(viking);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set extraction configuration
|
||||
pub fn with_config(mut self, config: ExtractionConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Extract memories from a conversation
|
||||
///
|
||||
/// This method analyzes the conversation and extracts:
|
||||
/// - Preferences: User's communication style, format preferences, language preferences
|
||||
/// - Knowledge: User-related facts, domain knowledge, lessons learned
|
||||
/// - Experience: Skill/tool usage patterns and outcomes
|
||||
///
|
||||
/// Returns an empty Vec if no LLM driver is configured
|
||||
pub async fn extract(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session_id: SessionId,
|
||||
) -> Result<Vec<ExtractedMemory>> {
|
||||
// Check if LLM driver is available
|
||||
let _llm_driver = match &self.llm_driver {
|
||||
Some(driver) => driver,
|
||||
None => {
|
||||
tracing::debug!("[MemoryExtractor] No LLM driver configured, skipping extraction");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Extract preferences if enabled
|
||||
if self.config.extract_preferences {
|
||||
tracing::debug!("[MemoryExtractor] Extracting preferences...");
|
||||
let prefs = self.extract_preferences(messages, session_id).await?;
|
||||
results.extend(prefs);
|
||||
}
|
||||
|
||||
// Extract knowledge if enabled
|
||||
if self.config.extract_knowledge {
|
||||
tracing::debug!("[MemoryExtractor] Extracting knowledge...");
|
||||
let knowledge = self.extract_knowledge(messages, session_id).await?;
|
||||
results.extend(knowledge);
|
||||
}
|
||||
|
||||
// Extract experience if enabled
|
||||
if self.config.extract_experience {
|
||||
tracing::debug!("[MemoryExtractor] Extracting experience...");
|
||||
let experience = self.extract_experience(messages, session_id).await?;
|
||||
results.extend(experience);
|
||||
}
|
||||
|
||||
// Filter by confidence threshold
|
||||
results.retain(|m| m.confidence >= self.config.min_confidence);
|
||||
|
||||
tracing::info!(
|
||||
"[MemoryExtractor] Extracted {} memories (confidence >= {})",
|
||||
results.len(),
|
||||
self.config.min_confidence
|
||||
);
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Extract user preferences from conversation
|
||||
async fn extract_preferences(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session_id: SessionId,
|
||||
) -> Result<Vec<ExtractedMemory>> {
|
||||
let llm_driver = match &self.llm_driver {
|
||||
Some(driver) => driver,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let mut results = llm_driver
|
||||
.extract_memories(messages, MemoryType::Preference)
|
||||
.await?;
|
||||
|
||||
// Set source session
|
||||
for memory in &mut results {
|
||||
memory.source_session = session_id;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Extract knowledge from conversation
|
||||
async fn extract_knowledge(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session_id: SessionId,
|
||||
) -> Result<Vec<ExtractedMemory>> {
|
||||
let llm_driver = match &self.llm_driver {
|
||||
Some(driver) => driver,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let mut results = llm_driver
|
||||
.extract_memories(messages, MemoryType::Knowledge)
|
||||
.await?;
|
||||
|
||||
for memory in &mut results {
|
||||
memory.source_session = session_id;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Extract experience from conversation
|
||||
async fn extract_experience(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session_id: SessionId,
|
||||
) -> Result<Vec<ExtractedMemory>> {
|
||||
let llm_driver = match &self.llm_driver {
|
||||
Some(driver) => driver,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let mut results = llm_driver
|
||||
.extract_memories(messages, MemoryType::Experience)
|
||||
.await?;
|
||||
|
||||
for memory in &mut results {
|
||||
memory.source_session = session_id;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Store extracted memories to OpenViking
|
||||
pub async fn store_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
memories: &[ExtractedMemory],
|
||||
) -> Result<usize> {
|
||||
let viking = match &self.viking {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
tracing::warn!("[MemoryExtractor] No VikingAdapter configured, memories not stored");
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
|
||||
let mut stored = 0;
|
||||
for memory in memories {
|
||||
let entry = memory.to_memory_entry(agent_id);
|
||||
match viking.store(&entry).await {
|
||||
Ok(_) => stored += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"[MemoryExtractor] Failed to store memory {}: {}",
|
||||
memory.category,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("[MemoryExtractor] Stored {} memories to OpenViking", stored);
|
||||
Ok(stored)
|
||||
}
|
||||
}
|
||||
|
||||
/// Default extraction prompts for LLM
|
||||
pub mod prompts {
|
||||
use crate::types::MemoryType;
|
||||
|
||||
/// Get the extraction prompt for a memory type
|
||||
pub fn get_extraction_prompt(memory_type: MemoryType) -> &'static str {
|
||||
match memory_type {
|
||||
MemoryType::Preference => PREFERENCE_EXTRACTION_PROMPT,
|
||||
MemoryType::Knowledge => KNOWLEDGE_EXTRACTION_PROMPT,
|
||||
MemoryType::Experience => EXPERIENCE_EXTRACTION_PROMPT,
|
||||
MemoryType::Session => SESSION_SUMMARY_PROMPT,
|
||||
}
|
||||
}
|
||||
|
||||
const PREFERENCE_EXTRACTION_PROMPT: &str = r#"
|
||||
分析以下对话,提取用户的偏好设置。关注:
|
||||
- 沟通风格偏好(简洁/详细、正式/随意)
|
||||
- 回复格式偏好(列表/段落、代码块风格)
|
||||
- 语言偏好
|
||||
- 主题兴趣
|
||||
|
||||
请以 JSON 格式返回,格式如下:
|
||||
[
|
||||
{
|
||||
"category": "communication-style",
|
||||
"content": "用户偏好简洁的回复",
|
||||
"confidence": 0.9,
|
||||
"keywords": ["简洁", "回复风格"]
|
||||
}
|
||||
]
|
||||
|
||||
对话内容:
|
||||
"#;
|
||||
|
||||
const KNOWLEDGE_EXTRACTION_PROMPT: &str = r#"
|
||||
分析以下对话,提取有价值的知识。关注:
|
||||
- 用户相关事实(职业、项目、背景)
|
||||
- 领域知识(技术栈、工具、最佳实践)
|
||||
- 经验教训(成功/失败案例)
|
||||
|
||||
请以 JSON 格式返回,格式如下:
|
||||
[
|
||||
{
|
||||
"category": "user-facts",
|
||||
"content": "用户是一名 Rust 开发者",
|
||||
"confidence": 0.85,
|
||||
"keywords": ["Rust", "开发者"]
|
||||
}
|
||||
]
|
||||
|
||||
对话内容:
|
||||
"#;
|
||||
|
||||
const EXPERIENCE_EXTRACTION_PROMPT: &str = r#"
|
||||
分析以下对话,提取技能/工具使用经验。关注:
|
||||
- 使用的技能或工具
|
||||
- 执行结果(成功/失败)
|
||||
- 改进建议
|
||||
|
||||
请以 JSON 格式返回,格式如下:
|
||||
[
|
||||
{
|
||||
"category": "skill-browser",
|
||||
"content": "浏览器技能在搜索技术文档时效果很好",
|
||||
"confidence": 0.8,
|
||||
"keywords": ["浏览器", "搜索", "文档"]
|
||||
}
|
||||
]
|
||||
|
||||
对话内容:
|
||||
"#;
|
||||
|
||||
const SESSION_SUMMARY_PROMPT: &str = r#"
|
||||
总结以下对话会话。关注:
|
||||
- 主要话题
|
||||
- 关键决策
|
||||
- 未解决问题
|
||||
|
||||
请以 JSON 格式返回,格式如下:
|
||||
{
|
||||
"summary": "会话摘要内容",
|
||||
"keywords": ["关键词1", "关键词2"],
|
||||
"topics": ["主题1", "主题2"]
|
||||
}
|
||||
|
||||
对话内容:
|
||||
"#;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct MockLlmDriver;
|
||||
|
||||
#[async_trait]
|
||||
impl LlmDriverForExtraction for MockLlmDriver {
|
||||
async fn extract_memories(
|
||||
&self,
|
||||
_messages: &[Message],
|
||||
extraction_type: MemoryType,
|
||||
) -> Result<Vec<ExtractedMemory>> {
|
||||
Ok(vec![ExtractedMemory::new(
|
||||
extraction_type,
|
||||
"test-category",
|
||||
"test content",
|
||||
SessionId::new(),
|
||||
)])
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extractor_creation() {
|
||||
let driver = Arc::new(MockLlmDriver);
|
||||
let extractor = MemoryExtractor::new(driver);
|
||||
assert!(extractor.viking.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_memories() {
|
||||
let driver = Arc::new(MockLlmDriver);
|
||||
let extractor = MemoryExtractor::new(driver);
|
||||
let messages = vec![Message::user("Hello")];
|
||||
|
||||
let result = extractor
|
||||
.extract(&messages, SessionId::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should extract preferences, knowledge, and experience
|
||||
assert!(!result.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prompts_available() {
|
||||
assert!(!prompts::get_extraction_prompt(MemoryType::Preference).is_empty());
|
||||
assert!(!prompts::get_extraction_prompt(MemoryType::Knowledge).is_empty());
|
||||
assert!(!prompts::get_extraction_prompt(MemoryType::Experience).is_empty());
|
||||
assert!(!prompts::get_extraction_prompt(MemoryType::Session).is_empty());
|
||||
}
|
||||
}
|
||||
537
crates/zclaw-growth/src/injector.rs
Normal file
537
crates/zclaw-growth/src/injector.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
//! Prompt Injector - Injects retrieved memories into system prompts
|
||||
//!
|
||||
//! This module provides the `PromptInjector` which formats and injects
|
||||
//! retrieved memories into the agent's system prompt for context enhancement.
|
||||
//!
|
||||
//! # Formatting Options
|
||||
//!
|
||||
//! - `inject()` - Standard markdown format with sections
|
||||
//! - `inject_compact()` - Compact format for limited token budgets
|
||||
//! - `inject_json()` - JSON format for structured processing
|
||||
//! - `inject_custom()` - Custom template with placeholders
|
||||
|
||||
use crate::types::{MemoryEntry, RetrievalConfig, RetrievalResult};
|
||||
|
||||
/// Output format for memory injection
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum InjectionFormat {
|
||||
/// Standard markdown with sections (default)
|
||||
Markdown,
|
||||
/// Compact inline format
|
||||
Compact,
|
||||
/// JSON structured format
|
||||
Json,
|
||||
}
|
||||
|
||||
/// Prompt Injector - injects memories into system prompts
|
||||
pub struct PromptInjector {
|
||||
/// Retrieval configuration for token budgets
|
||||
config: RetrievalConfig,
|
||||
/// Output format
|
||||
format: InjectionFormat,
|
||||
/// Custom template (uses {{preferences}}, {{knowledge}}, {{experience}} placeholders)
|
||||
custom_template: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for PromptInjector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptInjector {
|
||||
/// Create a new prompt injector
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: RetrievalConfig::default(),
|
||||
format: InjectionFormat::Markdown,
|
||||
custom_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom configuration
|
||||
pub fn with_config(config: RetrievalConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
format: InjectionFormat::Markdown,
|
||||
custom_template: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the output format
|
||||
pub fn with_format(mut self, format: InjectionFormat) -> Self {
|
||||
self.format = format;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a custom template for injection
|
||||
///
|
||||
/// Template placeholders:
|
||||
/// - `{{preferences}}` - Formatted preferences section
|
||||
/// - `{{knowledge}}` - Formatted knowledge section
|
||||
/// - `{{experience}}` - Formatted experience section
|
||||
/// - `{{all}}` - All memories combined
|
||||
pub fn with_custom_template(mut self, template: impl Into<String>) -> Self {
|
||||
self.custom_template = Some(template.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Inject memories into a base system prompt
|
||||
///
|
||||
/// This method constructs an enhanced system prompt by:
|
||||
/// 1. Starting with the base prompt
|
||||
/// 2. Adding a "用户偏好" section if preferences exist
|
||||
/// 3. Adding a "相关知识" section if knowledge exists
|
||||
/// 4. Adding an "经验参考" section if experience exists
|
||||
///
|
||||
/// Each section respects the token budget configuration.
|
||||
pub fn inject(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
|
||||
// If no memories, return base prompt unchanged
|
||||
if memories.is_empty() {
|
||||
return base_prompt.to_string();
|
||||
}
|
||||
|
||||
let mut result = base_prompt.to_string();
|
||||
|
||||
// Inject preferences section
|
||||
if !memories.preferences.is_empty() {
|
||||
let section = self.format_section(
|
||||
"## 用户偏好",
|
||||
&memories.preferences,
|
||||
self.config.preference_budget,
|
||||
|entry| format!("- {}", entry.content),
|
||||
);
|
||||
result.push_str("\n\n");
|
||||
result.push_str(§ion);
|
||||
}
|
||||
|
||||
// Inject knowledge section
|
||||
if !memories.knowledge.is_empty() {
|
||||
let section = self.format_section(
|
||||
"## 相关知识",
|
||||
&memories.knowledge,
|
||||
self.config.knowledge_budget,
|
||||
|entry| format!("- {}", entry.content),
|
||||
);
|
||||
result.push_str("\n\n");
|
||||
result.push_str(§ion);
|
||||
}
|
||||
|
||||
// Inject experience section
|
||||
if !memories.experience.is_empty() {
|
||||
let section = self.format_section(
|
||||
"## 经验参考",
|
||||
&memories.experience,
|
||||
self.config.experience_budget,
|
||||
|entry| format!("- {}", entry.content),
|
||||
);
|
||||
result.push_str("\n\n");
|
||||
result.push_str(§ion);
|
||||
}
|
||||
|
||||
// Add memory context footer
|
||||
result.push_str("\n\n");
|
||||
result.push_str("<!-- 以上内容基于历史对话自动提取的记忆 -->");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Format a section of memories with token budget
|
||||
fn format_section<F>(
|
||||
&self,
|
||||
header: &str,
|
||||
entries: &[MemoryEntry],
|
||||
token_budget: usize,
|
||||
formatter: F,
|
||||
) -> String
|
||||
where
|
||||
F: Fn(&MemoryEntry) -> String,
|
||||
{
|
||||
let mut result = String::new();
|
||||
result.push_str(header);
|
||||
result.push('\n');
|
||||
|
||||
let mut used_tokens = 0;
|
||||
let header_tokens = header.len() / 4;
|
||||
used_tokens += header_tokens;
|
||||
|
||||
for entry in entries {
|
||||
let line = formatter(entry);
|
||||
let line_tokens = line.len() / 4;
|
||||
|
||||
if used_tokens + line_tokens > token_budget {
|
||||
// Add truncation indicator
|
||||
result.push_str("- ... (更多内容已省略)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
result.push_str(&line);
|
||||
result.push('\n');
|
||||
used_tokens += line_tokens;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Build a minimal context string for token-limited scenarios
|
||||
pub fn build_minimal_context(&self, memories: &RetrievalResult) -> String {
|
||||
if memories.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut context = String::new();
|
||||
|
||||
// Only include top preference
|
||||
if let Some(pref) = memories.preferences.first() {
|
||||
context.push_str(&format!("[偏好] {}\n", pref.content));
|
||||
}
|
||||
|
||||
// Only include top knowledge
|
||||
if let Some(knowledge) = memories.knowledge.first() {
|
||||
context.push_str(&format!("[知识] {}\n", knowledge.content));
|
||||
}
|
||||
|
||||
context
|
||||
}
|
||||
|
||||
/// Inject memories in compact format
|
||||
///
|
||||
/// Compact format uses inline notation: [P] for preferences, [K] for knowledge, [E] for experience
|
||||
pub fn inject_compact(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
|
||||
if memories.is_empty() {
|
||||
return base_prompt.to_string();
|
||||
}
|
||||
|
||||
let mut result = base_prompt.to_string();
|
||||
let mut context_parts = Vec::new();
|
||||
|
||||
// Add compact preferences
|
||||
for entry in &memories.preferences {
|
||||
context_parts.push(format!("[P] {}", entry.content));
|
||||
}
|
||||
|
||||
// Add compact knowledge
|
||||
for entry in &memories.knowledge {
|
||||
context_parts.push(format!("[K] {}", entry.content));
|
||||
}
|
||||
|
||||
// Add compact experience
|
||||
for entry in &memories.experience {
|
||||
context_parts.push(format!("[E] {}", entry.content));
|
||||
}
|
||||
|
||||
if !context_parts.is_empty() {
|
||||
result.push_str("\n\n[记忆上下文]\n");
|
||||
result.push_str(&context_parts.join("\n"));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Inject memories as JSON structure
|
||||
///
|
||||
/// Returns a JSON object with preferences, knowledge, and experience arrays
|
||||
pub fn inject_json(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
|
||||
if memories.is_empty() {
|
||||
return base_prompt.to_string();
|
||||
}
|
||||
|
||||
let preferences: Vec<_> = memories.preferences.iter()
|
||||
.map(|e| serde_json::json!({
|
||||
"content": e.content,
|
||||
"importance": e.importance,
|
||||
"keywords": e.keywords,
|
||||
}))
|
||||
.collect();
|
||||
|
||||
let knowledge: Vec<_> = memories.knowledge.iter()
|
||||
.map(|e| serde_json::json!({
|
||||
"content": e.content,
|
||||
"importance": e.importance,
|
||||
"keywords": e.keywords,
|
||||
}))
|
||||
.collect();
|
||||
|
||||
let experience: Vec<_> = memories.experience.iter()
|
||||
.map(|e| serde_json::json!({
|
||||
"content": e.content,
|
||||
"importance": e.importance,
|
||||
"keywords": e.keywords,
|
||||
}))
|
||||
.collect();
|
||||
|
||||
let memories_json = serde_json::json!({
|
||||
"preferences": preferences,
|
||||
"knowledge": knowledge,
|
||||
"experience": experience,
|
||||
});
|
||||
|
||||
format!("{}\n\n[记忆上下文]\n{}", base_prompt, serde_json::to_string_pretty(&memories_json).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Inject using custom template
|
||||
///
|
||||
/// Template placeholders:
|
||||
/// - `{{preferences}}` - Formatted preferences section
|
||||
/// - `{{knowledge}}` - Formatted knowledge section
|
||||
/// - `{{experience}}` - Formatted experience section
|
||||
/// - `{{all}}` - All memories combined
|
||||
pub fn inject_custom(&self, template: &str, memories: &RetrievalResult) -> String {
|
||||
let mut result = template.to_string();
|
||||
|
||||
// Format each section
|
||||
let prefs = if !memories.preferences.is_empty() {
|
||||
memories.preferences.iter()
|
||||
.map(|e| format!("- {}", e.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let knowledge = if !memories.knowledge.is_empty() {
|
||||
memories.knowledge.iter()
|
||||
.map(|e| format!("- {}", e.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let experience = if !memories.experience.is_empty() {
|
||||
memories.experience.iter()
|
||||
.map(|e| format!("- {}", e.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Combine all
|
||||
let all = format!(
|
||||
"用户偏好:\n{}\n\n相关知识:\n{}\n\n经验参考:\n{}",
|
||||
if prefs.is_empty() { "无" } else { &prefs },
|
||||
if knowledge.is_empty() { "无" } else { &knowledge },
|
||||
if experience.is_empty() { "无" } else { &experience },
|
||||
);
|
||||
|
||||
// Replace placeholders
|
||||
result = result.replace("{{preferences}}", &prefs);
|
||||
result = result.replace("{{knowledge}}", &knowledge);
|
||||
result = result.replace("{{experience}}", &experience);
|
||||
result = result.replace("{{all}}", &all);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Inject memories using the configured format
|
||||
pub fn inject_with_format(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
|
||||
match self.format {
|
||||
InjectionFormat::Markdown => self.inject(base_prompt, memories),
|
||||
InjectionFormat::Compact => self.inject_compact(base_prompt, memories),
|
||||
InjectionFormat::Json => self.inject_json(base_prompt, memories),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate total tokens that will be injected
|
||||
pub fn estimate_injection_tokens(&self, memories: &RetrievalResult) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Count preference tokens
|
||||
for entry in &memories.preferences {
|
||||
total += entry.estimated_tokens();
|
||||
if total > self.config.preference_budget {
|
||||
total = self.config.preference_budget;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Count knowledge tokens
|
||||
let mut knowledge_tokens = 0;
|
||||
for entry in &memories.knowledge {
|
||||
knowledge_tokens += entry.estimated_tokens();
|
||||
if knowledge_tokens > self.config.knowledge_budget {
|
||||
knowledge_tokens = self.config.knowledge_budget;
|
||||
break;
|
||||
}
|
||||
}
|
||||
total += knowledge_tokens;
|
||||
|
||||
// Count experience tokens
|
||||
let mut experience_tokens = 0;
|
||||
for entry in &memories.experience {
|
||||
experience_tokens += entry.estimated_tokens();
|
||||
if experience_tokens > self.config.experience_budget {
|
||||
experience_tokens = self.config.experience_budget;
|
||||
break;
|
||||
}
|
||||
}
|
||||
total += experience_tokens;
|
||||
|
||||
total
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::MemoryType;
|
||||
use chrono::Utc;
|
||||
|
||||
fn create_test_entry(content: &str) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
uri: "test://uri".to_string(),
|
||||
memory_type: MemoryType::Preference,
|
||||
content: content.to_string(),
|
||||
keywords: vec![],
|
||||
importance: 5,
|
||||
access_count: 0,
|
||||
created_at: Utc::now(),
|
||||
last_accessed: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_injector_empty_memories() {
|
||||
let injector = PromptInjector::new();
|
||||
let base = "You are a helpful assistant.";
|
||||
let memories = RetrievalResult::default();
|
||||
|
||||
let result = injector.inject(base, &memories);
|
||||
assert_eq!(result, base);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_injector_with_preferences() {
|
||||
let injector = PromptInjector::new();
|
||||
let base = "You are a helpful assistant.";
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("User prefers concise responses")],
|
||||
knowledge: vec![],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let result = injector.inject(base, &memories);
|
||||
assert!(result.contains("用户偏好"));
|
||||
assert!(result.contains("User prefers concise responses"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_injector_with_all_types() {
|
||||
let injector = PromptInjector::new();
|
||||
let base = "You are a helpful assistant.";
|
||||
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Prefers concise")],
|
||||
knowledge: vec![create_test_entry("Knows Rust")],
|
||||
experience: vec![create_test_entry("Browser skill works well")],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let result = injector.inject(base, &memories);
|
||||
assert!(result.contains("用户偏好"));
|
||||
assert!(result.contains("相关知识"));
|
||||
assert!(result.contains("经验参考"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_minimal_context() {
|
||||
let injector = PromptInjector::new();
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Prefers concise")],
|
||||
knowledge: vec![create_test_entry("Knows Rust")],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let context = injector.build_minimal_context(&memories);
|
||||
assert!(context.contains("[偏好]"));
|
||||
assert!(context.contains("[知识]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens() {
|
||||
let injector = PromptInjector::new();
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Short text")],
|
||||
knowledge: vec![],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let estimate = injector.estimate_injection_tokens(&memories);
|
||||
assert!(estimate > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_compact() {
|
||||
let injector = PromptInjector::new();
|
||||
let base = "You are a helpful assistant.";
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Prefers concise")],
|
||||
knowledge: vec![create_test_entry("Knows Rust")],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let result = injector.inject_compact(base, &memories);
|
||||
assert!(result.contains("[P]"));
|
||||
assert!(result.contains("[K]"));
|
||||
assert!(result.contains("[记忆上下文]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_json() {
|
||||
let injector = PromptInjector::new();
|
||||
let base = "You are a helpful assistant.";
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Prefers concise")],
|
||||
knowledge: vec![],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let result = injector.inject_json(base, &memories);
|
||||
assert!(result.contains("\"preferences\""));
|
||||
assert!(result.contains("Prefers concise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_custom() {
|
||||
let injector = PromptInjector::new();
|
||||
let template = "Context:\n{{all}}";
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Prefers concise")],
|
||||
knowledge: vec![create_test_entry("Knows Rust")],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
let result = injector.inject_custom(template, &memories);
|
||||
assert!(result.contains("用户偏好"));
|
||||
assert!(result.contains("相关知识"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_selection() {
|
||||
let base = "Base";
|
||||
|
||||
let memories = RetrievalResult {
|
||||
preferences: vec![create_test_entry("Test")],
|
||||
knowledge: vec![],
|
||||
experience: vec![],
|
||||
total_tokens: 0,
|
||||
};
|
||||
|
||||
// Test markdown format
|
||||
let injector_md = PromptInjector::new().with_format(InjectionFormat::Markdown);
|
||||
let result_md = injector_md.inject_with_format(base, &memories);
|
||||
assert!(result_md.contains("## 用户偏好"));
|
||||
|
||||
// Test compact format
|
||||
let injector_compact = PromptInjector::new().with_format(InjectionFormat::Compact);
|
||||
let result_compact = injector_compact.inject_with_format(base, &memories);
|
||||
assert!(result_compact.contains("[P]"));
|
||||
}
|
||||
}
|
||||
141
crates/zclaw-growth/src/lib.rs
Normal file
141
crates/zclaw-growth/src/lib.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
//! ZCLAW Agent Growth System
|
||||
//!
|
||||
//! This crate provides the agent growth functionality for ZCLAW,
|
||||
//! enabling agents to learn and evolve from conversations.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The growth system consists of four main components:
|
||||
//!
|
||||
//! 1. **MemoryExtractor** (`extractor`) - Analyzes conversations and extracts
|
||||
//! preferences, knowledge, and experience using LLM.
|
||||
//!
|
||||
//! 2. **MemoryRetriever** (`retriever`) - Performs semantic search over
|
||||
//! stored memories to find contextually relevant information.
|
||||
//!
|
||||
//! 3. **PromptInjector** (`injector`) - Injects retrieved memories into
|
||||
//! the system prompt with token budget control.
|
||||
//!
|
||||
//! 4. **GrowthTracker** (`tracker`) - Tracks growth metrics and evolution
|
||||
//! over time.
|
||||
//!
|
||||
//! # Storage
|
||||
//!
|
||||
//! All memories are stored in OpenViking with a URI structure:
|
||||
//!
|
||||
//! ```text
|
||||
//! agent://{agent_id}/
|
||||
//! ├── preferences/{category} - User preferences
|
||||
//! ├── knowledge/{domain} - Accumulated knowledge
|
||||
//! ├── experience/{skill} - Skill/tool experience
|
||||
//! └── sessions/{session_id}/ - Conversation history
|
||||
//! ├── raw - Original conversation (L0)
|
||||
//! ├── summary - Summary (L1)
|
||||
//! └── keywords - Keywords (L2)
|
||||
//! ```
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use zclaw_growth::{MemoryExtractor, MemoryRetriever, PromptInjector, VikingAdapter};
|
||||
//!
|
||||
//! // Create components
|
||||
//! let viking = VikingAdapter::in_memory();
|
||||
//! let retriever = MemoryRetriever::new(Arc::new(viking.clone()));
|
||||
//! let injector = PromptInjector::new();
|
||||
//!
|
||||
//! // Before conversation: retrieve relevant memories
|
||||
//! let memories = retriever.retrieve(&agent_id, &user_input).await?;
|
||||
//!
|
||||
//! // Inject into system prompt
|
||||
//! let enhanced_prompt = injector.inject(&base_prompt, &memories);
|
||||
//!
|
||||
//! // After conversation: extract and store new memories
|
||||
//! let extracted = extractor.extract(&messages, session_id).await?;
|
||||
//! extractor.store_memories(&agent_id, &extracted).await?;
|
||||
//! ```
|
||||
|
||||
pub mod types;
|
||||
pub mod extractor;
|
||||
pub mod retriever;
|
||||
pub mod injector;
|
||||
pub mod tracker;
|
||||
pub mod viking_adapter;
|
||||
pub mod storage;
|
||||
pub mod retrieval;
|
||||
|
||||
// Re-export main types for convenience
|
||||
pub use types::{
|
||||
ExtractedMemory,
|
||||
ExtractionConfig,
|
||||
GrowthStats,
|
||||
MemoryEntry,
|
||||
MemoryType,
|
||||
RetrievalConfig,
|
||||
RetrievalResult,
|
||||
UriBuilder,
|
||||
};
|
||||
|
||||
pub use extractor::{LlmDriverForExtraction, MemoryExtractor};
|
||||
pub use retriever::{MemoryRetriever, MemoryStats};
|
||||
pub use injector::{InjectionFormat, PromptInjector};
|
||||
pub use tracker::{AgentMetadata, GrowthTracker, LearningEvent};
|
||||
pub use viking_adapter::{FindOptions, VikingAdapter, VikingLevel, VikingStorage};
|
||||
pub use storage::SqliteStorage;
|
||||
pub use retrieval::{MemoryCache, QueryAnalyzer, SemanticScorer};
|
||||
|
||||
/// Top-level configuration for the agent growth system.
#[derive(Debug, Clone)]
pub struct GrowthConfig {
    /// Master switch: when `false` the growth system is bypassed entirely.
    pub enabled: bool,
    /// Settings controlling memory retrieval (token budget, etc.).
    pub retrieval: RetrievalConfig,
    /// Settings controlling memory extraction from conversations.
    pub extraction: ExtractionConfig,
    /// When `true`, memories are extracted automatically after each conversation.
    pub auto_extract: bool,
}
|
||||
|
||||
impl Default for GrowthConfig {
    /// Growth enabled, auto-extraction on, and default retrieval/extraction
    /// settings (defined by the respective `Default` impls).
    fn default() -> Self {
        Self {
            enabled: true,
            retrieval: RetrievalConfig::default(),
            extraction: ExtractionConfig::default(),
            auto_extract: true,
        }
    }
}
|
||||
|
||||
/// Convenience function to create a complete growth system.
///
/// Wires all four growth components to a shared [`VikingAdapter`] storage
/// backend and the given LLM driver, and returns them as
/// `(extractor, retriever, injector, tracker)`.
pub fn create_growth_system(
    viking: std::sync::Arc<VikingAdapter>,
    llm_driver: std::sync::Arc<dyn LlmDriverForExtraction>,
) -> (MemoryExtractor, MemoryRetriever, PromptInjector, GrowthTracker) {
    let extractor = MemoryExtractor::new(llm_driver).with_viking(viking.clone());
    let retriever = MemoryRetriever::new(viking.clone());
    let injector = PromptInjector::new();
    // `tracker` takes the last clone of the Arc by value.
    let tracker = GrowthTracker::new(viking);

    (extractor, retriever, injector, tracker)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: growth enabled and auto-extraction on.
    #[test]
    fn test_growth_config_default() {
        let config = GrowthConfig::default();
        assert!(config.enabled);
        assert!(config.auto_extract);
        // Assumes RetrievalConfig::default().max_tokens == 500 (defined in
        // `types`) — keep in sync with that module.
        assert_eq!(config.retrieval.max_tokens, 500);
    }

    /// The re-exported `MemoryType` keeps its Display form ("preferences").
    #[test]
    fn test_memory_type_reexport() {
        let mt = MemoryType::Preference;
        assert_eq!(format!("{}", mt), "preferences");
    }
}
|
||||
365
crates/zclaw-growth/src/retrieval/cache.rs
Normal file
365
crates/zclaw-growth/src/retrieval/cache.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Memory Cache
|
||||
//!
|
||||
//! Provides caching for frequently accessed memories to improve
|
||||
//! retrieval performance.
|
||||
|
||||
use crate::types::{MemoryEntry, MemoryType};
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Internal cache slot: a memory entry plus its LRU bookkeeping.
struct CacheEntry {
    /// The cached memory entry.
    entry: MemoryEntry,
    /// When this slot was last read or written (used for TTL and LRU tie-break).
    last_accessed: Instant,
    /// Number of `get` hits on this slot (primary LRU eviction key).
    access_count: u32,
}
|
||||
|
||||
/// Composite key identifying a memory by agent, type, and category.
///
/// NOTE(review): not referenced by `MemoryCache` in this file (the cache is
/// keyed by URI string) — presumably used by other retrieval code; confirm
/// before removing.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    agent_id: String,
    memory_type: MemoryType,
    category: String,
}
|
||||
|
||||
impl From<&MemoryEntry> for CacheKey {
    /// Derive a key from an entry's URI.
    ///
    /// Assumes the URI shape `agent://{agent_id}/{type}/{category}`:
    /// parts[0] is the agent id and parts[2] the category; missing segments
    /// fall back to the empty string. TODO confirm against `UriBuilder`.
    fn from(entry: &MemoryEntry) -> Self {
        // Parse URI to extract components
        let parts: Vec<&str> = entry.uri.trim_start_matches("agent://").split('/').collect();
        Self {
            agent_id: parts.first().unwrap_or(&"").to_string(),
            memory_type: entry.memory_type,
            category: parts.get(2).unwrap_or(&"").to_string(),
        }
    }
}
|
||||
|
||||
/// Tuning knobs for [`MemoryCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached entries before LRU eviction kicks in.
    pub max_entries: usize,
    /// How long an entry may go unaccessed before it is considered stale.
    pub ttl: Duration,
    /// When `false`, `get` always misses and `put` is a no-op.
    pub enabled: bool,
}
|
||||
|
||||
impl Default for CacheConfig {
    /// 1000 entries, 1-hour TTL, caching enabled.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(3600), // 1 hour
            enabled: true,
        }
    }
}
|
||||
|
||||
/// In-process cache for frequently accessed ("hot") memories.
///
/// All state lives behind `tokio::sync::RwLock`, so a `MemoryCache` can be
/// shared across tasks via `Arc` without additional synchronization.
pub struct MemoryCache {
    /// URI -> cached entry with LRU metadata.
    cache: RwLock<HashMap<String, CacheEntry>>,
    /// Immutable configuration (capacity, TTL, enable flag).
    config: CacheConfig,
    /// Hit/miss/eviction counters.
    stats: RwLock<CacheStats>,
}
|
||||
|
||||
/// Monotonic counters describing cache effectiveness.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a live cached entry.
    pub hits: u64,
    /// Lookups that returned nothing.
    pub misses: u64,
    /// Entries removed by LRU eviction (not by `remove`/`clear`).
    pub evictions: u64,
}
|
||||
|
||||
impl MemoryCache {
|
||||
/// Create a new memory cache
|
||||
pub fn new(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
cache: RwLock::new(HashMap::new()),
|
||||
config,
|
||||
stats: RwLock::new(CacheStats::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(CacheConfig::default())
|
||||
}
|
||||
|
||||
/// Get a memory from cache
|
||||
pub async fn get(&self, uri: &str) -> Option<MemoryEntry> {
|
||||
if !self.config.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut cache = self.cache.write().await;
|
||||
|
||||
if let Some(cached) = cache.get_mut(uri) {
|
||||
// Check TTL
|
||||
if cached.last_accessed.elapsed() > self.config.ttl {
|
||||
cache.remove(uri);
|
||||
return None;
|
||||
}
|
||||
|
||||
// Update access metadata
|
||||
cached.last_accessed = Instant::now();
|
||||
cached.access_count += 1;
|
||||
|
||||
// Update stats
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.hits += 1;
|
||||
|
||||
return Some(cached.entry.clone());
|
||||
}
|
||||
|
||||
// Update stats
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.misses += 1;
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Put a memory into cache
|
||||
pub async fn put(&self, entry: MemoryEntry) {
|
||||
if !self.config.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut cache = self.cache.write().await;
|
||||
|
||||
// Check capacity and evict if necessary
|
||||
if cache.len() >= self.config.max_entries {
|
||||
self.evict_lru(&mut cache).await;
|
||||
}
|
||||
|
||||
cache.insert(
|
||||
entry.uri.clone(),
|
||||
CacheEntry {
|
||||
entry,
|
||||
last_accessed: Instant::now(),
|
||||
access_count: 0,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove a memory from cache
|
||||
pub async fn remove(&self, uri: &str) {
|
||||
let mut cache = self.cache.write().await;
|
||||
cache.remove(uri);
|
||||
}
|
||||
|
||||
/// Clear the cache
|
||||
pub async fn clear(&self) {
|
||||
let mut cache = self.cache.write().await;
|
||||
cache.clear();
|
||||
}
|
||||
|
||||
/// Evict least recently used entries
|
||||
async fn evict_lru(&self, cache: &mut HashMap<String, CacheEntry>) {
|
||||
// Find LRU entry
|
||||
let lru_key = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| (v.access_count, v.last_accessed))
|
||||
.map(|(k, _)| k.clone());
|
||||
|
||||
if let Some(key) = lru_key {
|
||||
cache.remove(&key);
|
||||
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.evictions += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub async fn stats(&self) -> CacheStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get cache hit rate
|
||||
pub async fn hit_rate(&self) -> f32 {
|
||||
let stats = self.stats.read().await;
|
||||
let total = stats.hits + stats.misses;
|
||||
if total == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
stats.hits as f32 / total as f32
|
||||
}
|
||||
|
||||
/// Get cache size
|
||||
pub async fn size(&self) -> usize {
|
||||
self.cache.read().await.len()
|
||||
}
|
||||
|
||||
/// Warm up cache with frequently accessed entries
|
||||
pub async fn warmup(&self, entries: Vec<MemoryEntry>) {
|
||||
for entry in entries {
|
||||
self.put(entry).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get top accessed entries (for preloading)
|
||||
pub async fn get_hot_entries(&self, limit: usize) -> Vec<MemoryEntry> {
|
||||
let cache = self.cache.read().await;
|
||||
|
||||
let mut entries: Vec<_> = cache
|
||||
.values()
|
||||
.map(|c| (c.access_count, c.entry.clone()))
|
||||
.collect();
|
||||
|
||||
entries.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
entries.truncate(limit);
|
||||
|
||||
entries.into_iter().map(|(_, e)| e).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Round-trip: a stored entry is retrievable with identical content.
    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );

        cache.put(entry.clone()).await;
        let retrieved = cache.get(&entry.uri).await;

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "User prefers concise responses");
    }

    /// A lookup for an unknown URI misses and is counted as such.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = MemoryCache::default_config();
        let retrieved = cache.get("nonexistent").await;

        assert!(retrieved.is_none());

        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
    }

    /// `remove` makes a previously cached entry unreachable.
    #[tokio::test]
    async fn test_cache_remove() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry.clone()).await;
        cache.remove(&entry.uri).await;
        let retrieved = cache.get(&entry.uri).await;

        assert!(retrieved.is_none());
    }

    /// `clear` empties the cache.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry).await;
        cache.clear().await;
        let size = cache.size().await;

        assert_eq!(size, 0);
    }

    /// One hit plus one miss yields a 50% hit rate.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        cache.put(entry.clone()).await;

        // Hit
        cache.get(&entry.uri).await;
        // Miss
        cache.get("nonexistent").await;

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        let hit_rate = cache.hit_rate().await;
        assert!((hit_rate - 0.5).abs() < 0.001);
    }

    /// At capacity 2, inserting a third entry evicts the LRU one
    /// (entry2: never accessed) and keeps size at the cap.
    #[tokio::test]
    async fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(3600),
            enabled: true,
        };
        let cache = MemoryCache::new(config);

        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        let entry3 = MemoryEntry::new("test", MemoryType::Preference, "3", "3".to_string());

        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;

        // Access entry1 to make it hot
        cache.get(&entry1.uri).await;

        // Add entry3, should evict entry2 (LRU)
        cache.put(entry3).await;

        let size = cache.size().await;
        assert_eq!(size, 2);

        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
    }

    /// Hot-entry listing is ordered by descending access count.
    #[tokio::test]
    async fn test_get_hot_entries() {
        let cache = MemoryCache::default_config();

        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());

        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;

        // Access entry1 multiple times
        cache.get(&entry1.uri).await;
        cache.get(&entry1.uri).await;

        let hot = cache.get_hot_entries(10).await;
        assert_eq!(hot.len(), 2);
        // entry1 should be first (more accesses)
        assert_eq!(hot[0].uri, entry1.uri);
    }
}
|
||||
14
crates/zclaw-growth/src/retrieval/mod.rs
Normal file
14
crates/zclaw-growth/src/retrieval/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! Retrieval components for ZCLAW Growth System
|
||||
//!
|
||||
//! This module provides advanced retrieval capabilities:
|
||||
//! - `semantic`: Semantic similarity computation
|
||||
//! - `query`: Query analysis and expansion
|
||||
//! - `cache`: Hot memory caching
|
||||
|
||||
pub mod semantic;
|
||||
pub mod query;
|
||||
pub mod cache;
|
||||
|
||||
pub use semantic::SemanticScorer;
|
||||
pub use query::QueryAnalyzer;
|
||||
pub use cache::MemoryCache;
|
||||
352
crates/zclaw-growth/src/retrieval/query.rs
Normal file
352
crates/zclaw-growth/src/retrieval/query.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Query Analyzer
|
||||
//!
|
||||
//! Provides query analysis and expansion capabilities for improved retrieval.
|
||||
//! Extracts keywords, identifies intent, and generates search variations.
|
||||
|
||||
use crate::types::MemoryType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Result of analyzing a raw user query.
#[derive(Debug, Clone)]
pub struct AnalyzedQuery {
    /// The query exactly as provided by the caller.
    pub original: String,
    /// Lowercased keywords with stop words removed.
    pub keywords: Vec<String>,
    /// Coarse intent classification derived from the keywords.
    pub intent: QueryIntent,
    /// Memory types worth searching, inferred from the intent.
    pub target_types: Vec<MemoryType>,
    /// Additional search terms (singular/plural variants, simple synonyms).
    pub expansions: Vec<String>,
}
|
||||
|
||||
/// Coarse classification of what a query is asking for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryIntent {
    /// Looking for preferences/settings.
    Preference,
    /// Looking for factual knowledge.
    Knowledge,
    /// Looking for how-to/experience.
    Experience,
    /// General conversation (no strong signal).
    General,
    /// Code-related query.
    Code,
    /// Configuration query.
    ///
    /// NOTE(review): `QueryAnalyzer::classify_intent` in this file never
    /// produces this variant — confirm whether other callers construct it.
    Configuration,
}
|
||||
|
||||
/// Keyword-table based query analyzer.
///
/// Classification works by matching extracted keywords against the indicator
/// sets below; the sets mix English and Chinese terms.
pub struct QueryAnalyzer {
    /// Keywords that indicate preference queries.
    preference_indicators: HashSet<String>,
    /// Keywords that indicate knowledge queries.
    knowledge_indicators: HashSet<String>,
    /// Keywords that indicate experience queries.
    experience_indicators: HashSet<String>,
    /// Keywords that indicate code queries.
    code_indicators: HashSet<String>,
    /// English stop words filtered out before classification.
    stop_words: HashSet<String>,
}
|
||||
|
||||
impl QueryAnalyzer {
    /// Create a new query analyzer with built-in English/Chinese indicator
    /// word lists and an English stop-word list.
    pub fn new() -> Self {
        Self {
            preference_indicators: [
                "prefer", "like", "want", "favorite", "favourite", "style",
                "format", "language", "setting", "preference", "usually",
                "typically", "always", "never", "习惯", "偏好", "喜欢", "想要",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            knowledge_indicators: [
                "what", "how", "why", "explain", "tell", "know", "learn",
                "understand", "meaning", "definition", "concept", "theory",
                "是什么", "怎么", "为什么", "解释", "了解", "知道",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            // NOTE(review): "last time" contains a space and can never match a
            // single token from extract_keywords; "when" is also a stop word
            // and is filtered before classification — confirm both are intended.
            experience_indicators: [
                "experience", "tried", "used", "before", "last time",
                "previous", "history", "remember", "recall", "when",
                "经验", "尝试", "用过", "上次", "记得", "回忆",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            code_indicators: [
                "code", "function", "class", "method", "variable", "type",
                "error", "bug", "fix", "implement", "refactor", "api",
                "代码", "函数", "类", "方法", "变量", "错误", "修复", "实现",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            stop_words: [
                "the", "a", "an", "is", "are", "was", "were", "be", "been",
                "have", "has", "had", "do", "does", "did", "will", "would",
                "could", "should", "may", "might", "must", "can", "to", "of",
                "in", "for", "on", "with", "at", "by", "from", "as", "and",
                "or", "but", "if", "then", "else", "when", "where", "which",
                "who", "whom", "whose", "this", "that", "these", "those",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
        }
    }

    /// Analyze a query string: extract keywords, classify intent, infer the
    /// memory types to search, and generate expansion terms.
    pub fn analyze(&self, query: &str) -> AnalyzedQuery {
        let keywords = self.extract_keywords(query);
        let intent = self.classify_intent(&keywords);
        let target_types = self.infer_memory_types(intent, &keywords);
        let expansions = self.expand_query(&keywords);

        AnalyzedQuery {
            original: query.to_string(),
            keywords,
            intent,
            target_types,
            expansions,
        }
    }

    /// Lowercase the query, split on characters that are neither alphanumeric
    /// nor CJK, and drop stop words and 1-byte tokens.
    ///
    /// NOTE(review): the length filter uses byte length, so a single CJK
    /// character (3 bytes in UTF-8) passes while a single ASCII letter does
    /// not — presumably intentional; confirm.
    fn extract_keywords(&self, query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split(|c: char| !c.is_alphanumeric() && !is_cjk(c))
            .filter(|s| !s.is_empty() && s.len() > 1)
            .filter(|s| !self.stop_words.contains(*s))
            .map(|s| s.to_string())
            .collect()
    }

    /// Classify intent by counting indicator-word matches per category
    /// (+2 per match; a keyword may score several categories at once).
    /// Ties are broken by the original array order (`sort_by` is stable).
    fn classify_intent(&self, keywords: &[String]) -> QueryIntent {
        let mut scores = [
            (QueryIntent::Preference, 0),
            (QueryIntent::Knowledge, 0),
            (QueryIntent::Experience, 0),
            (QueryIntent::Code, 0),
        ];

        for keyword in keywords {
            if self.preference_indicators.contains(keyword) {
                scores[0].1 += 2;
            }
            if self.knowledge_indicators.contains(keyword) {
                scores[1].1 += 2;
            }
            if self.experience_indicators.contains(keyword) {
                scores[2].1 += 2;
            }
            if self.code_indicators.contains(keyword) {
                scores[3].1 += 2;
            }
        }

        // Find highest scoring intent
        scores.sort_by(|a, b| b.1.cmp(&a.1));

        if scores[0].1 > 0 {
            scores[0].0
        } else {
            QueryIntent::General
        }
    }

    /// Map an intent to the memory types worth searching; `General` searches
    /// everything.
    fn infer_memory_types(&self, intent: QueryIntent, _keywords: &[String]) -> Vec<MemoryType> {
        let mut types = Vec::new();

        match intent {
            QueryIntent::Preference => {
                types.push(MemoryType::Preference);
            }
            QueryIntent::Knowledge | QueryIntent::Code => {
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Experience => {
                types.push(MemoryType::Experience);
                types.push(MemoryType::Knowledge);
            }
            QueryIntent::General => {
                // Search all types
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Configuration => {
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
            }
        }

        types
    }

    /// Expand keywords with naive singular/plural variants and a small
    /// synonym table.
    ///
    /// NOTE(review): the plural rule also appends "s" to CJK keywords
    /// (e.g. "习惯s"), producing terms that will never match — harmless but
    /// noisy; confirm whether CJK tokens should be skipped here.
    fn expand_query(&self, keywords: &[String]) -> Vec<String> {
        let mut expansions = Vec::new();

        // Add stemmed variations (simplified)
        for keyword in keywords {
            // Add singular/plural variations
            if keyword.ends_with('s') && keyword.len() > 3 {
                expansions.push(keyword[..keyword.len()-1].to_string());
            } else {
                expansions.push(format!("{}s", keyword));
            }

            // Add common synonyms (simplified)
            if let Some(synonyms) = self.get_synonyms(keyword) {
                expansions.extend(synonyms);
            }
        }

        expansions
    }

    /// Fixed synonym table for a handful of common technical terms;
    /// returns `None` for anything not listed.
    fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
        let synonyms: &[&str] = match keyword {
            "code" => &["program", "script", "source"],
            "error" => &["bug", "issue", "problem", "exception"],
            "fix" => &["solve", "resolve", "repair", "patch"],
            "fast" => &["quick", "speed", "performance", "efficient"],
            "slow" => &["performance", "optimize", "speed"],
            "help" => &["assist", "support", "guide", "aid"],
            "learn" => &["study", "understand", "know", "grasp"],
            _ => return None,
        };

        Some(synonyms.iter().map(|s| s.to_string()).collect())
    }

    /// Build the deduplicated list of search strings: the original query,
    /// the joined keywords, and every expansion term.
    ///
    /// NOTE(review): the final sort means the original query is not
    /// necessarily first — confirm callers do not rely on ordering.
    pub fn generate_search_queries(&self, analyzed: &AnalyzedQuery) -> Vec<String> {
        let mut queries = vec![analyzed.original.clone()];

        // Add keyword-based query
        if !analyzed.keywords.is_empty() {
            queries.push(analyzed.keywords.join(" "));
        }

        // Add expanded terms
        for expansion in &analyzed.expansions {
            if !expansion.is_empty() {
                queries.push(expansion.clone());
            }
        }

        // Deduplicate
        queries.sort();
        queries.dedup();

        queries
    }
}
|
||||
|
||||
impl Default for QueryAnalyzer {
    /// Same as [`QueryAnalyzer::new`]: built-in indicator and stop-word lists.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Returns `true` when `c` falls in one of the CJK ideograph Unicode blocks
/// (Unified Ideographs, Extensions A–E, and both compatibility blocks).
fn is_cjk(c: char) -> bool {
    // Inclusive code-point ranges of the recognized CJK blocks.
    const CJK_RANGES: [(u32, u32); 8] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0x3400, 0x4DBF),   // CJK Unified Ideographs Extension A
        (0x2_0000, 0x2_A6DF), // CJK Unified Ideographs Extension B
        (0x2_A700, 0x2_B73F), // CJK Unified Ideographs Extension C
        (0x2_B740, 0x2_B81F), // CJK Unified Ideographs Extension D
        (0x2_B820, 0x2_CEAF), // CJK Unified Ideographs Extension E
        (0xF900, 0xFAFF),   // CJK Compatibility Ideographs
        (0x2_F800, 0x2_FA1F), // CJK Compatibility Ideographs Supplement
    ];

    let cp = c as u32;
    CJK_RANGES.iter().any(|&(lo, hi)| (lo..=hi).contains(&cp))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Content words survive extraction; stop words are dropped.
    #[test]
    fn test_extract_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("What is the Rust programming language?");

        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        assert!(!keywords.contains(&"the".to_string())); // stop word
    }

    /// "prefer" drives the Preference intent and target type.
    #[test]
    fn test_classify_intent_preference() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("I prefer concise responses");

        assert_eq!(analyzed.intent, QueryIntent::Preference);
        assert!(analyzed.target_types.contains(&MemoryType::Preference));
    }

    /// "explain"/"how" drive the Knowledge intent.
    #[test]
    fn test_classify_intent_knowledge() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Explain how async/await works in Rust");

        assert_eq!(analyzed.intent, QueryIntent::Knowledge);
    }

    /// "fix"/"error"/"function" drive the Code intent.
    #[test]
    fn test_classify_intent_code() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Fix this error in my function");

        assert_eq!(analyzed.intent, QueryIntent::Code);
    }

    /// Any keyword produces at least one expansion (plural variant / synonym).
    #[test]
    fn test_query_expansion() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("fix the error");

        assert!(!analyzed.expansions.is_empty());
    }

    /// The original query always yields at least one search string.
    #[test]
    fn test_generate_search_queries() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Rust programming");
        let queries = analyzer.generate_search_queries(&analyzed);

        assert!(queries.len() >= 1);
    }

    /// CJK range check accepts ideographs and rejects ASCII.
    #[test]
    fn test_cjk_detection() {
        assert!(is_cjk('中'));
        assert!(is_cjk('文'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('1'));
    }

    /// Chinese input still yields keywords (CJK chars pass the byte-length filter).
    #[test]
    fn test_chinese_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("我喜欢简洁的回复");

        // Chinese characters should be extracted
        assert!(!keywords.is_empty());
    }
}
|
||||
374
crates/zclaw-growth/src/retrieval/semantic.rs
Normal file
374
crates/zclaw-growth/src/retrieval/semantic.rs
Normal file
@@ -0,0 +1,374 @@
|
||||
//! Semantic Similarity Scorer
|
||||
//!
|
||||
//! Provides TF-IDF based semantic similarity computation for memory retrieval.
|
||||
//! This is a lightweight, dependency-free implementation suitable for
|
||||
//! medium-scale memory systems.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use crate::types::MemoryEntry;
|
||||
|
||||
/// Semantic similarity scorer using TF-IDF over a lightweight in-memory index.
///
/// Dependency-free and intended for medium-scale memory systems; entries must
/// be registered via `index_entry` for full TF-IDF scoring, otherwise a
/// substring-matching fallback is used.
pub struct SemanticScorer {
    /// Per-term document frequency, used for IDF computation.
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents indexed so far.
    total_documents: usize,
    /// Precomputed TF-IDF vectors keyed by entry URI.
    entry_vectors: HashMap<String, HashMap<String, f32>>,
    /// Stop words ignored during tokenization.
    stop_words: HashSet<String>,
}
|
||||
|
||||
impl SemanticScorer {
|
||||
    /// Create an empty scorer with no indexed documents and the default
    /// English stop-word list.
    pub fn new() -> Self {
        Self {
            document_frequencies: HashMap::new(),
            total_documents: 0,
            entry_vectors: HashMap::new(),
            stop_words: Self::default_stop_words(),
        }
    }
|
||||
|
||||
    /// Build the default English stop-word set.
    ///
    /// The literal array contains a few duplicates ("a", "i", "were",
    /// "before", "after", "when") — harmless since collection into a
    /// `HashSet` deduplicates. The trailing single letters are also redundant
    /// with `tokenize`'s length filter, which drops 1-byte tokens.
    fn default_stop_words() -> HashSet<String> {
        [
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
            "have", "has", "had", "do", "does", "did", "will", "would", "could",
            "should", "may", "might", "must", "shall", "can", "need", "dare",
            "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
            "from", "as", "into", "through", "during", "before", "after",
            "above", "below", "between", "under", "again", "further", "then",
            "once", "here", "there", "when", "where", "why", "how", "all",
            "each", "few", "more", "most", "other", "some", "such", "no", "nor",
            "not", "only", "own", "same", "so", "than", "too", "very", "just",
            "and", "but", "if", "or", "because", "until", "while", "although",
            "though", "after", "before", "when", "whenever", "i", "you", "he",
            "she", "it", "we", "they", "what", "which", "who", "whom", "this",
            "that", "these", "those", "am", "im", "youre", "hes", "shes",
            "its", "were", "theyre", "ive", "youve", "weve", "theyve", "id",
            "youd", "hed", "shed", "wed", "theyd", "ill", "youll", "hell",
            "shell", "well", "theyll", "isnt", "arent", "wasnt", "werent",
            "hasnt", "havent", "hadnt", "doesnt", "dont", "didnt", "wont",
            "wouldnt", "shant", "shouldnt", "cant", "cannot", "couldnt",
            "mustnt", "lets", "thats", "whos", "whats", "heres", "theres",
            "whens", "wheres", "whys", "hows", "a", "b", "c", "d", "e", "f",
            "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s",
            "t", "u", "v", "w", "x", "y", "z",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }
|
||||
|
||||
/// Tokenize text into words
|
||||
fn tokenize(text: &str) -> Vec<String> {
|
||||
text.to_lowercase()
|
||||
.split(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty() && s.len() > 1)
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Remove stop words from tokens
|
||||
fn remove_stop_words(&self, tokens: &[String]) -> Vec<String> {
|
||||
tokens
|
||||
.iter()
|
||||
.filter(|t| !self.stop_words.contains(*t))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute term frequency for a list of tokens
|
||||
fn compute_tf(tokens: &[String]) -> HashMap<String, f32> {
|
||||
let mut tf = HashMap::new();
|
||||
let total = tokens.len() as f32;
|
||||
|
||||
for token in tokens {
|
||||
*tf.entry(token.clone()).or_insert(0.0) += 1.0;
|
||||
}
|
||||
|
||||
// Normalize by total tokens
|
||||
for count in tf.values_mut() {
|
||||
*count /= total;
|
||||
}
|
||||
|
||||
tf
|
||||
}
|
||||
|
||||
/// Compute IDF for a term
|
||||
fn compute_idf(&self, term: &str) -> f32 {
|
||||
let df = self.document_frequencies.get(term).copied().unwrap_or(0);
|
||||
if df == 0 || self.total_documents == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
((self.total_documents as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
|
||||
}
|
||||
|
||||
    /// Index an entry for semantic search.
    ///
    /// Tokenizes the entry's content and keywords, updates the corpus
    /// document frequencies, and stores the resulting TF-IDF vector under the
    /// entry's URI (re-indexing the same URI overwrites the old vector).
    ///
    /// NOTE(review): IDF is computed with corpus statistics as of indexing
    /// time; vectors of previously indexed entries are not refreshed when new
    /// documents arrive, so scores drift slightly as the corpus grows.
    pub fn index_entry(&mut self, entry: &MemoryEntry) {
        // Tokenize content and keywords
        let mut all_tokens = Self::tokenize(&entry.content);
        for keyword in &entry.keywords {
            all_tokens.extend(Self::tokenize(keyword));
        }
        all_tokens = self.remove_stop_words(&all_tokens);

        // Update document frequencies (each unique term counts once per doc).
        let unique_terms: HashSet<_> = all_tokens.iter().cloned().collect();
        for term in &unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
        }
        self.total_documents += 1;

        // Compute TF-IDF vector — intentionally AFTER the frequency update so
        // this document's own occurrence is reflected in its IDF.
        let tf = Self::compute_tf(&all_tokens);
        let mut tfidf = HashMap::new();
        for (term, tf_val) in tf {
            let idf = self.compute_idf(&term);
            tfidf.insert(term, tf_val * idf);
        }

        self.entry_vectors.insert(entry.uri.clone(), tfidf);
    }
|
||||
|
||||
    /// Remove an entry's TF-IDF vector from the index (no-op if absent).
    ///
    /// NOTE(review): document frequencies and `total_documents` are NOT
    /// decremented, so IDF statistics retain the removed document's
    /// contribution — confirm this drift is acceptable.
    pub fn remove_entry(&mut self, uri: &str) {
        self.entry_vectors.remove(uri);
    }
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
fn cosine_similarity(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
|
||||
if v1.is_empty() || v2.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Find common keys
|
||||
let mut dot_product = 0.0;
|
||||
let mut norm1 = 0.0;
|
||||
let mut norm2 = 0.0;
|
||||
|
||||
for (k, v) in v1 {
|
||||
norm1 += v * v;
|
||||
if let Some(v2_val) = v2.get(k) {
|
||||
dot_product += v * v2_val;
|
||||
}
|
||||
}
|
||||
|
||||
for v in v2.values() {
|
||||
norm2 += v * v;
|
||||
}
|
||||
|
||||
let denom = (norm1 * norm2).sqrt();
|
||||
if denom == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
(dot_product / denom).clamp(0.0, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Score similarity between query and entry
|
||||
pub fn score_similarity(&self, query: &str, entry: &MemoryEntry) -> f32 {
|
||||
// Tokenize query
|
||||
let query_tokens = self.remove_stop_words(&Self::tokenize(query));
|
||||
if query_tokens.is_empty() {
|
||||
return 0.5; // Neutral score for empty query
|
||||
}
|
||||
|
||||
// Compute query TF-IDF
|
||||
let query_tf = Self::compute_tf(&query_tokens);
|
||||
let mut query_vec = HashMap::new();
|
||||
for (term, tf_val) in query_tf {
|
||||
let idf = self.compute_idf(&term);
|
||||
query_vec.insert(term, tf_val * idf);
|
||||
}
|
||||
|
||||
// Get entry vector
|
||||
let entry_vec = match self.entry_vectors.get(&entry.uri) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// Fall back to simple matching if not indexed
|
||||
return self.fallback_similarity(&query_tokens, entry);
|
||||
}
|
||||
};
|
||||
|
||||
// Compute cosine similarity
|
||||
let cosine = Self::cosine_similarity(&query_vec, entry_vec);
|
||||
|
||||
// Combine with keyword matching for better results
|
||||
let keyword_boost = self.keyword_match_score(&query_tokens, entry);
|
||||
|
||||
// Weighted combination
|
||||
cosine * 0.7 + keyword_boost * 0.3
|
||||
}
|
||||
|
||||
/// Fallback similarity when entry is not indexed
|
||||
fn fallback_similarity(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let mut matches = 0;
|
||||
|
||||
for token in query_tokens {
|
||||
if content_lower.contains(token) {
|
||||
matches += 1;
|
||||
}
|
||||
for keyword in &entry.keywords {
|
||||
if keyword.to_lowercase().contains(token) {
|
||||
matches += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(matches as f32) / (query_tokens.len() * 2).max(1) as f32
|
||||
}
|
||||
|
||||
/// Compute keyword match score
|
||||
fn keyword_match_score(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
|
||||
if entry.keywords.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut matches = 0;
|
||||
for token in query_tokens {
|
||||
for keyword in &entry.keywords {
|
||||
if keyword.to_lowercase().contains(&token.to_lowercase()) {
|
||||
matches += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(matches as f32) / query_tokens.len().max(1) as f32
|
||||
}
|
||||
|
||||
    /// Clear the index
    ///
    /// Resets the scorer to its freshly-constructed state: all document
    /// frequencies, the document counter, and every stored TF-IDF vector
    /// are discarded.
    pub fn clear(&mut self) {
        self.document_frequencies.clear();
        self.total_documents = 0;
        self.entry_vectors.clear();
    }
|
||||
|
||||
    /// Get statistics about the index
    ///
    /// Returns a point-in-time snapshot: how many documents have been
    /// indexed, how many distinct terms are tracked, and how many entries
    /// currently have a stored TF-IDF vector.
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            total_documents: self.total_documents,
            unique_terms: self.document_frequencies.len(),
            indexed_entries: self.entry_vectors.len(),
        }
    }
|
||||
}
|
||||
|
||||
impl Default for SemanticScorer {
    /// Equivalent to [`SemanticScorer::new`]: an empty index with no
    /// documents, terms, or vectors.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Index statistics
///
/// Snapshot produced by [`SemanticScorer::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of documents indexed since creation (or the last `clear`).
    pub total_documents: usize,
    /// Number of distinct terms tracked in the document-frequency table.
    pub unique_terms: usize,
    /// Number of entries with a stored TF-IDF vector.
    pub indexed_entries: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    #[test]
    fn test_tokenize() {
        // Punctuation is stripped and tokens are lowercased; note that "a"
        // does not appear in the output — the tokenizer filters it
        // (presumably a minimum-length rule; confirm in `tokenize`).
        let tokens = SemanticScorer::tokenize("Hello, World! This is a test.");
        assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
    }

    #[test]
    fn test_stop_words_removal() {
        // "the" is a stop word and must be dropped; content words survive.
        let scorer = SemanticScorer::new();
        let tokens = vec!["hello".to_string(), "the".to_string(), "world".to_string()];
        let filtered = scorer.remove_stop_words(&tokens);
        assert_eq!(filtered, vec!["hello", "world"]);
    }

    #[test]
    fn test_tf_computation() {
        // Term frequency is count / total token count.
        let tokens = vec!["hello".to_string(), "hello".to_string(), "world".to_string()];
        let tf = SemanticScorer::compute_tf(&tokens);

        let hello_tf = tf.get("hello").unwrap();
        let world_tf = tf.get("world").unwrap();

        // Allow for floating point comparison
        assert!((hello_tf - (2.0 / 3.0)).abs() < 0.001);
        assert!((world_tf - (1.0 / 3.0)).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity() {
        let mut v1 = HashMap::new();
        v1.insert("a".to_string(), 1.0);
        v1.insert("b".to_string(), 2.0);

        let mut v2 = HashMap::new();
        v2.insert("a".to_string(), 1.0);
        v2.insert("b".to_string(), 2.0);

        // Identical vectors should have similarity 1.0
        let sim = SemanticScorer::cosine_similarity(&v1, &v2);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors should have similarity 0.0
        let mut v3 = HashMap::new();
        v3.insert("c".to_string(), 1.0);
        let sim2 = SemanticScorer::cosine_similarity(&v1, &v3);
        assert!((sim2 - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_index_and_score() {
        let mut scorer = SemanticScorer::new();

        let entry1 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety and performance".to_string(),
        ).with_keywords(vec!["rust".to_string(), "programming".to_string(), "safety".to_string()]);

        let entry2 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        ).with_keywords(vec!["python".to_string(), "programming".to_string()]);

        scorer.index_entry(&entry1);
        scorer.index_entry(&entry2);

        // Query for Rust should score higher on entry1
        let score1 = scorer.score_similarity("rust safety", &entry1);
        let score2 = scorer.score_similarity("rust safety", &entry2);

        assert!(score1 > score2, "Rust query should score higher on Rust entry");
    }

    #[test]
    fn test_stats() {
        let mut scorer = SemanticScorer::new();

        let entry = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "test",
            "Hello world".to_string(),
        );

        scorer.index_entry(&entry);
        let stats = scorer.stats();

        // One document indexed, one vector stored, at least one term tracked.
        assert_eq!(stats.total_documents, 1);
        assert_eq!(stats.indexed_entries, 1);
        assert!(stats.unique_terms > 0);
    }
}
|
||||
348
crates/zclaw-growth/src/retriever.rs
Normal file
348
crates/zclaw-growth/src/retriever.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
//! Memory Retriever - Retrieves relevant memories from OpenViking
|
||||
//!
|
||||
//! This module provides the `MemoryRetriever` which performs semantic search
|
||||
//! over stored memories to find contextually relevant information.
|
||||
//! Uses multiple retrieval strategies and intelligent reranking.
|
||||
|
||||
use crate::retrieval::{MemoryCache, QueryAnalyzer, SemanticScorer};
|
||||
use crate::types::{MemoryEntry, MemoryType, RetrievalConfig, RetrievalResult};
|
||||
use crate::viking_adapter::{FindOptions, VikingAdapter};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::{AgentId, Result};
|
||||
|
||||
/// Memory Retriever - retrieves relevant memories from OpenViking
///
/// Combines search over the OpenViking store with query analysis, TF-IDF
/// reranking, and an in-process cache of recently retrieved entries.
pub struct MemoryRetriever {
    /// OpenViking adapter
    viking: Arc<VikingAdapter>,
    /// Retrieval configuration
    config: RetrievalConfig,
    /// Semantic scorer for similarity computation
    // Behind an RwLock because reranking needs mutable access (indexing)
    // while the retriever itself is shared behind `&self`.
    scorer: RwLock<SemanticScorer>,
    /// Query analyzer
    analyzer: QueryAnalyzer,
    /// Memory cache
    cache: MemoryCache,
}
|
||||
|
||||
impl MemoryRetriever {
|
||||
/// Create a new memory retriever
|
||||
pub fn new(viking: Arc<VikingAdapter>) -> Self {
|
||||
Self {
|
||||
viking,
|
||||
config: RetrievalConfig::default(),
|
||||
scorer: RwLock::new(SemanticScorer::new()),
|
||||
analyzer: QueryAnalyzer::new(),
|
||||
cache: MemoryCache::default_config(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Create with custom configuration
    ///
    /// Builder-style: consumes and returns `self` so it can be chained
    /// directly after [`MemoryRetriever::new`].
    pub fn with_config(mut self, config: RetrievalConfig) -> Self {
        self.config = config;
        self
    }
|
||||
|
||||
    /// Retrieve relevant memories for a query
    ///
    /// This method:
    /// 1. Analyzes the query to determine intent and keywords
    /// 2. Searches for preferences matching the query
    /// 3. Searches for relevant knowledge
    /// 4. Searches for applicable experience
    /// 5. Reranks results using semantic similarity
    /// 6. Applies token budget constraints
    ///
    /// Retrieved entries are also inserted into the in-process cache.
    ///
    /// # Errors
    /// Propagates any error from the underlying OpenViking search.
    pub async fn retrieve(
        &self,
        agent_id: &AgentId,
        query: &str,
    ) -> Result<RetrievalResult> {
        tracing::debug!("[MemoryRetriever] Retrieving memories for query: {}", query);

        // Analyze query
        let analyzed = self.analyzer.analyze(query);
        tracing::debug!(
            "[MemoryRetriever] Query analysis: intent={:?}, keywords={:?}",
            analyzed.intent,
            analyzed.keywords
        );

        // Retrieve each type with budget constraints and reranking.
        // NOTE(review): the three searches run sequentially; they look
        // independent, so joining them concurrently may be possible — verify.
        let preferences = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Preference,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type,
                self.config.preference_budget,
            )
            .await?;

        let knowledge = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Knowledge,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type,
                self.config.knowledge_budget,
            )
            .await?;

        // Experience gets half the result slots of the other two types.
        let experience = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Experience,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type / 2,
                self.config.experience_budget,
            )
            .await?;

        let total_tokens = preferences.iter()
            .chain(knowledge.iter())
            .chain(experience.iter())
            .map(|m| m.estimated_tokens())
            .sum();

        // Update cache with retrieved entries
        for entry in preferences.iter().chain(knowledge.iter()).chain(experience.iter()) {
            self.cache.put(entry.clone()).await;
        }

        tracing::info!(
            "[MemoryRetriever] Retrieved {} preferences, {} knowledge, {} experience ({} tokens)",
            preferences.len(),
            knowledge.len(),
            experience.len(),
            total_tokens
        );

        Ok(RetrievalResult {
            preferences,
            knowledge,
            experience,
            total_tokens,
        })
    }
|
||||
|
||||
    /// Retrieve and rerank memories by type
    ///
    /// Searches OpenViking under the `agent://{agent_id}/{memory_type}` scope
    /// using the original query plus expanded variants, deduplicates by URI,
    /// reranks by semantic similarity, and finally keeps entries greedily
    /// while they fit within `token_budget` (up to `max_results`).
    async fn retrieve_and_rerank(
        &self,
        agent_id: &str,
        memory_type: MemoryType,
        query: &str,
        keywords: &[String],
        max_results: usize,
        token_budget: usize,
    ) -> Result<Vec<MemoryEntry>> {
        // Build scope for OpenViking search
        let scope = format!("agent://{}/{}", agent_id, memory_type);

        // Generate search queries (original + expanded)
        // NOTE(review): this re-fabricates an AnalyzedQuery with intent
        // hard-coded to General and empty expansions, discarding the analysis
        // already done in `retrieve()` — confirm whether that is intentional.
        let analyzed_for_search = crate::retrieval::query::AnalyzedQuery {
            original: query.to_string(),
            keywords: keywords.to_vec(),
            intent: crate::retrieval::query::QueryIntent::General,
            target_types: vec![],
            expansions: vec![],
        };
        let search_queries = self.analyzer.generate_search_queries(&analyzed_for_search);

        // Search with multiple queries and deduplicate
        let mut all_results = Vec::new();
        let mut seen_uris = std::collections::HashSet::new();

        for search_query in search_queries {
            let options = FindOptions {
                scope: Some(scope.clone()),
                // Over-fetch (2x) so reranking has candidates to discard.
                limit: Some(max_results * 2),
                min_similarity: Some(self.config.min_similarity),
            };

            let results = self.viking.find(&search_query, options).await?;

            for entry in results {
                // `insert` returns false for URIs we have already collected.
                if seen_uris.insert(entry.uri.clone()) {
                    all_results.push(entry);
                }
            }
        }

        // Rerank using semantic similarity
        let scored = self.rerank_entries(query, all_results).await;

        // Apply token budget: entries too large for the remaining budget are
        // skipped (not terminal), so smaller lower-ranked entries may still fit.
        let mut filtered = Vec::new();
        let mut used_tokens = 0;

        for entry in scored {
            let tokens = entry.estimated_tokens();
            if used_tokens + tokens <= token_budget {
                used_tokens += tokens;
                filtered.push(entry);
            }

            if filtered.len() >= max_results {
                break;
            }
        }

        Ok(filtered)
    }
|
||||
|
||||
/// Rerank entries using semantic similarity
|
||||
async fn rerank_entries(
|
||||
&self,
|
||||
query: &str,
|
||||
entries: Vec<MemoryEntry>,
|
||||
) -> Vec<MemoryEntry> {
|
||||
if entries.is_empty() {
|
||||
return entries;
|
||||
}
|
||||
|
||||
let mut scorer = self.scorer.write().await;
|
||||
|
||||
// Index entries for semantic search
|
||||
for entry in &entries {
|
||||
scorer.index_entry(entry);
|
||||
}
|
||||
|
||||
// Score each entry
|
||||
let mut scored: Vec<(f32, MemoryEntry)> = entries
|
||||
.into_iter()
|
||||
.map(|entry| {
|
||||
let score = scorer.score_similarity(query, &entry);
|
||||
(score, entry)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by score (descending), then by importance and access count
|
||||
scored.sort_by(|a, b| {
|
||||
b.0.partial_cmp(&a.0)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| b.1.importance.cmp(&a.1.importance))
|
||||
.then_with(|| b.1.access_count.cmp(&a.1.access_count))
|
||||
});
|
||||
|
||||
scored.into_iter().map(|(_, entry)| entry).collect()
|
||||
}
|
||||
|
||||
/// Retrieve a specific memory by URI (with cache)
|
||||
pub async fn get_by_uri(&self, uri: &str) -> Result<Option<MemoryEntry>> {
|
||||
// Check cache first
|
||||
if let Some(cached) = self.cache.get(uri).await {
|
||||
return Ok(Some(cached));
|
||||
}
|
||||
|
||||
// Fall back to storage
|
||||
let result = self.viking.get(uri).await?;
|
||||
|
||||
// Update cache
|
||||
if let Some(ref entry) = result {
|
||||
self.cache.put(entry.clone()).await;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get all memories for an agent (for debugging/admin)
|
||||
pub async fn get_all_memories(&self, agent_id: &AgentId) -> Result<Vec<MemoryEntry>> {
|
||||
let scope = format!("agent://{}", agent_id);
|
||||
let options = FindOptions {
|
||||
scope: Some(scope),
|
||||
limit: None,
|
||||
min_similarity: None,
|
||||
};
|
||||
|
||||
self.viking.find("", options).await
|
||||
}
|
||||
|
||||
/// Get memory statistics for an agent
|
||||
pub async fn get_stats(&self, agent_id: &AgentId) -> Result<MemoryStats> {
|
||||
let all = self.get_all_memories(agent_id).await?;
|
||||
|
||||
let preference_count = all.iter().filter(|m| m.memory_type == MemoryType::Preference).count();
|
||||
let knowledge_count = all.iter().filter(|m| m.memory_type == MemoryType::Knowledge).count();
|
||||
let experience_count = all.iter().filter(|m| m.memory_type == MemoryType::Experience).count();
|
||||
|
||||
Ok(MemoryStats {
|
||||
total_count: all.len(),
|
||||
preference_count,
|
||||
knowledge_count,
|
||||
experience_count,
|
||||
cache_hit_rate: self.cache.hit_rate().await,
|
||||
})
|
||||
}
|
||||
|
||||
    /// Clear the semantic index
    ///
    /// Resets the TF-IDF scorer (document frequencies, document count, and
    /// entry vectors); does not touch the cache or the persistent store.
    pub async fn clear_index(&self) {
        let mut scorer = self.scorer.write().await;
        scorer.clear();
    }
|
||||
|
||||
/// Get cache statistics
|
||||
pub async fn cache_stats(&self) -> (usize, f32) {
|
||||
let size = self.cache.size().await;
|
||||
let hit_rate = self.cache.hit_rate().await;
|
||||
(size, hit_rate)
|
||||
}
|
||||
|
||||
/// Warm up cache with hot entries
|
||||
pub async fn warmup_cache(&self, agent_id: &AgentId) -> Result<usize> {
|
||||
let all = self.get_all_memories(agent_id).await?;
|
||||
|
||||
// Sort by access count to get hot entries
|
||||
let mut sorted = all;
|
||||
sorted.sort_by(|a, b| b.access_count.cmp(&a.access_count));
|
||||
|
||||
// Take top 50 hot entries
|
||||
let hot: Vec<_> = sorted.into_iter().take(50).collect();
|
||||
let count = hot.len();
|
||||
|
||||
self.cache.warmup(hot).await;
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory statistics
///
/// Per-agent snapshot produced by [`MemoryRetriever::get_stats`].
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Total number of memories stored under the agent's scope.
    pub total_count: usize,
    /// Number of `Preference` memories.
    pub preference_count: usize,
    /// Number of `Knowledge` memories.
    pub knowledge_count: usize,
    /// Number of `Experience` memories.
    pub experience_count: usize,
    /// Current cache hit rate (hits / lookups).
    pub cache_hit_rate: f32,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retrieval_config_default() {
        // Pins the documented default budgets so silent changes are caught.
        let config = RetrievalConfig::default();
        assert_eq!(config.max_tokens, 500);
        assert_eq!(config.preference_budget, 200);
        assert_eq!(config.knowledge_budget, 200);
    }

    #[test]
    fn test_memory_type_scope() {
        // The scope string relies on MemoryType's Display impl rendering
        // Preference as "preferences".
        let scope = format!("agent://test-agent/{}", MemoryType::Preference);
        assert!(scope.contains("test-agent"));
        assert!(scope.contains("preferences"));
    }

    #[tokio::test]
    async fn test_retriever_creation() {
        // A freshly created retriever must start with an empty cache.
        let viking = Arc::new(VikingAdapter::in_memory());
        let retriever = MemoryRetriever::new(viking);

        let stats = retriever.cache_stats().await;
        assert_eq!(stats.0, 0); // Cache size should be 0
    }
}
|
||||
9
crates/zclaw-growth/src/storage/mod.rs
Normal file
9
crates/zclaw-growth/src/storage/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Storage backends for ZCLAW Growth System
//!
//! This module currently provides the persistent backend:
//! - `SqliteStorage`: Persistent SQLite storage for production use
//!
//! In-memory storage for testing and development is available via
//! `VikingAdapter::in_memory`; no `InMemoryStorage` type is exported here.
|
||||
|
||||
mod sqlite;
|
||||
|
||||
pub use sqlite::SqliteStorage;
|
||||
563
crates/zclaw-growth/src/storage/sqlite.rs
Normal file
563
crates/zclaw-growth/src/storage/sqlite.rs
Normal file
@@ -0,0 +1,563 @@
|
||||
//! SQLite Storage Backend
|
||||
//!
|
||||
//! Persistent storage backend using SQLite for production use.
|
||||
//! Provides efficient querying and full-text search capabilities.
|
||||
|
||||
use crate::retrieval::semantic::SemanticScorer;
|
||||
use crate::types::MemoryEntry;
|
||||
use crate::viking_adapter::{FindOptions, VikingStorage};
|
||||
use async_trait::async_trait;
|
||||
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
|
||||
use sqlx::Row;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::Result;
|
||||
use zclaw_types::ZclawError;
|
||||
|
||||
/// SQLite storage backend with TF-IDF semantic scoring
///
/// Persists memories in a SQLite database (with an FTS5 companion table)
/// and keeps an in-process [`SemanticScorer`] synchronized with the rows.
pub struct SqliteStorage {
    /// Database connection pool
    pool: SqlitePool,
    /// Semantic scorer for similarity computation
    // RwLock: scoring reads concurrently; store/delete take the write lock.
    scorer: Arc<RwLock<SemanticScorer>>,
    /// Database path (for reference)
    #[allow(dead_code)]
    path: PathBuf,
}
|
||||
|
||||
/// Database row structure for memory entry
///
/// Mirrors the `memories` table column-for-column; conversion to the domain
/// type happens in `SqliteStorage::row_to_entry`.
struct MemoryRow {
    // Primary key.
    uri: String,
    // Memory type as stored text (parsed via `MemoryType::parse`).
    memory_type: String,
    content: String,
    // JSON-encoded `Vec<String>` of keywords.
    keywords: String,
    importance: i32,
    access_count: i32,
    // RFC 3339 timestamp string.
    created_at: String,
    // RFC 3339 timestamp string.
    last_accessed: String,
}
|
||||
|
||||
impl SqliteStorage {
|
||||
/// Create a new SQLite storage at the given path
|
||||
pub async fn new(path: impl Into<PathBuf>) -> Result<Self> {
|
||||
let path = path.into();
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = path.parent() {
|
||||
if parent.to_str() != Some(":memory:") {
|
||||
tokio::fs::create_dir_all(parent).await.map_err(|e| {
|
||||
ZclawError::StorageError(format!("Failed to create storage directory: {}", e))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// Build connection string
|
||||
let db_url = if path.to_str() == Some(":memory:") {
|
||||
"sqlite::memory:".to_string()
|
||||
} else {
|
||||
format!("sqlite:{}?mode=rwc", path.to_string_lossy())
|
||||
};
|
||||
|
||||
// Create connection pool
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(format!("Failed to connect to database: {}", e)))?;
|
||||
|
||||
let storage = Self {
|
||||
pool,
|
||||
scorer: Arc::new(RwLock::new(SemanticScorer::new())),
|
||||
path,
|
||||
};
|
||||
|
||||
storage.initialize_schema().await?;
|
||||
storage.warmup_scorer().await?;
|
||||
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
    /// Create an in-memory SQLite database (for testing)
    ///
    /// # Panics
    /// Panics if the in-memory database cannot be created — acceptable for
    /// the test-only contexts this helper is intended for.
    pub async fn in_memory() -> Self {
        Self::new(":memory:").await.expect("Failed to create in-memory database")
    }
|
||||
|
||||
    /// Initialize database schema with FTS5
    ///
    /// Creates the `memories` table, its FTS5 companion, two secondary
    /// indexes, and the `metadata` key/value table. Every statement uses
    /// `IF NOT EXISTS`, so this is safe to call on an existing database.
    async fn initialize_schema(&self) -> Result<()> {
        // Create main memories table
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS memories (
                uri TEXT PRIMARY KEY,
                memory_type TEXT NOT NULL,
                content TEXT NOT NULL,
                keywords TEXT NOT NULL DEFAULT '[]',
                importance INTEGER NOT NULL DEFAULT 5,
                access_count INTEGER NOT NULL DEFAULT 0,
                created_at TEXT NOT NULL,
                last_accessed TEXT NOT NULL
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create memories table: {}", e)))?;

        // Create FTS5 virtual table for full-text search
        sqlx::query(
            r#"
            CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
                uri,
                content,
                keywords,
                tokenize='unicode61'
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create FTS5 table: {}", e)))?;

        // Create index on memory_type for filtering
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_memory_type ON memories(memory_type)")
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to create index: {}", e)))?;

        // Create index on importance for sorting
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_importance ON memories(importance DESC)")
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to create importance index: {}", e)))?;

        // Create metadata table
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                json TEXT NOT NULL
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create metadata table: {}", e)))?;

        tracing::info!("[SqliteStorage] Database schema initialized");
        Ok(())
    }
|
||||
|
||||
/// Warmup semantic scorer with existing entries
|
||||
async fn warmup_scorer(&self) -> Result<()> {
|
||||
let rows = sqlx::query_as::<_, MemoryRow>(
|
||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories"
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(format!("Failed to load memories for warmup: {}", e)))?;
|
||||
|
||||
let mut scorer = self.scorer.write().await;
|
||||
for row in rows {
|
||||
let entry = self.row_to_entry(&row);
|
||||
scorer.index_entry(&entry);
|
||||
}
|
||||
|
||||
let stats = scorer.stats();
|
||||
tracing::info!(
|
||||
"[SqliteStorage] Warmed up scorer with {} entries, {} terms",
|
||||
stats.indexed_entries,
|
||||
stats.unique_terms
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert database row to MemoryEntry
|
||||
fn row_to_entry(&self, row: &MemoryRow) -> MemoryEntry {
|
||||
let memory_type = crate::types::MemoryType::parse(&row.memory_type);
|
||||
let keywords: Vec<String> = serde_json::from_str(&row.keywords).unwrap_or_default();
|
||||
let created_at = chrono::DateTime::parse_from_rfc3339(&row.created_at)
|
||||
.map(|dt| dt.with_timezone(&chrono::Utc))
|
||||
.unwrap_or_else(|_| chrono::Utc::now());
|
||||
let last_accessed = chrono::DateTime::parse_from_rfc3339(&row.last_accessed)
|
||||
.map(|dt| dt.with_timezone(&chrono::Utc))
|
||||
.unwrap_or_else(|_| chrono::Utc::now());
|
||||
|
||||
MemoryEntry {
|
||||
uri: row.uri.clone(),
|
||||
memory_type,
|
||||
content: row.content.clone(),
|
||||
keywords,
|
||||
importance: row.importance as u8,
|
||||
access_count: row.access_count as u32,
|
||||
created_at,
|
||||
last_accessed,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Update access count and last accessed time
    ///
    /// Bumps `access_count` by one and stamps `last_accessed` with the
    /// current UTC time; a no-op if no row matches `uri`.
    async fn touch_entry(&self, uri: &str) -> Result<()> {
        let now = chrono::Utc::now().to_rfc3339();
        sqlx::query(
            "UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE uri = ?"
        )
        .bind(&now)
        .bind(uri)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to update access count: {}", e)))?;

        Ok(())
    }
|
||||
}
|
||||
|
||||
// Manual FromRow impl so MemoryRow can be fetched via `sqlx::query_as`
// without the `sqlx` derive; column names must match the SELECT lists used
// throughout this file.
impl sqlx::FromRow<'_, SqliteRow> for MemoryRow {
    fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
        Ok(MemoryRow {
            uri: row.try_get("uri")?,
            memory_type: row.try_get("memory_type")?,
            content: row.try_get("content")?,
            keywords: row.try_get("keywords")?,
            importance: row.try_get("importance")?,
            access_count: row.try_get("access_count")?,
            created_at: row.try_get("created_at")?,
            last_accessed: row.try_get("last_accessed")?,
        })
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl VikingStorage for SqliteStorage {
|
||||
    /// Insert or replace a memory, keeping the FTS table and the in-process
    /// semantic scorer in sync with the main `memories` table.
    async fn store(&self, entry: &MemoryEntry) -> Result<()> {
        let keywords_json = serde_json::to_string(&entry.keywords)
            .map_err(|e| ZclawError::StorageError(format!("Failed to serialize keywords: {}", e)))?;

        let created_at = entry.created_at.to_rfc3339();
        let last_accessed = entry.last_accessed.to_rfc3339();
        let memory_type = entry.memory_type.to_string();

        // Insert into main table
        sqlx::query(
            r#"
            INSERT OR REPLACE INTO memories
            (uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind(&entry.uri)
        .bind(&memory_type)
        .bind(&entry.content)
        .bind(&keywords_json)
        .bind(entry.importance as i32)
        .bind(entry.access_count as i32)
        .bind(&created_at)
        .bind(&last_accessed)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to store memory: {}", e)))?;

        // Update FTS index - delete old and insert new.
        // Errors are deliberately ignored (`let _ =`): FTS is best-effort and
        // must not fail the primary write.
        let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
            .bind(&entry.uri)
            .execute(&self.pool)
            .await;

        let keywords_text = entry.keywords.join(" ");
        let _ = sqlx::query(
            r#"
            INSERT INTO memories_fts (uri, content, keywords)
            VALUES (?, ?, ?)
            "#,
        )
        .bind(&entry.uri)
        .bind(&entry.content)
        .bind(&keywords_text)
        .execute(&self.pool)
        .await;

        // Update semantic scorer
        // NOTE(review): re-storing an existing URI re-indexes it and bumps the
        // scorer's document count again — confirm whether that drift matters.
        let mut scorer = self.scorer.write().await;
        scorer.index_entry(entry);

        tracing::debug!("[SqliteStorage] Stored memory: {}", entry.uri);
        Ok(())
    }
|
||||
|
||||
async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
|
||||
let row = sqlx::query_as::<_, MemoryRow>(
|
||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri = ?"
|
||||
)
|
||||
.bind(uri)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(format!("Failed to get memory: {}", e)))?;
|
||||
|
||||
if let Some(row) = row {
|
||||
let entry = self.row_to_entry(&row);
|
||||
|
||||
// Update access count
|
||||
self.touch_entry(&entry.uri).await?;
|
||||
|
||||
Ok(Some(entry))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
    /// Semantic search: load candidate rows (optionally scoped by URI
    /// prefix), score each against `query` with the TF-IDF scorer, filter by
    /// the optional similarity threshold, and return the best matches.
    ///
    /// NOTE(review): this path scans and scores rows directly — the
    /// `memories_fts` table is written by `store` but not queried here;
    /// confirm whether FTS-based candidate selection was intended.
    async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
        // Get all matching entries
        let rows = if let Some(ref scope) = options.scope {
            sqlx::query_as::<_, MemoryRow>(
                "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri LIKE ?"
            )
            .bind(format!("{}%", scope))
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
        } else {
            sqlx::query_as::<_, MemoryRow>(
                "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories"
            )
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
        };

        // Convert to entries and compute semantic scores
        let scorer = self.scorer.read().await;
        let mut scored_entries: Vec<(f32, MemoryEntry)> = Vec::new();

        for row in rows {
            let entry = self.row_to_entry(&row);

            // Compute semantic score using TF-IDF
            let semantic_score = scorer.score_similarity(query, &entry);

            // Apply similarity threshold
            if let Some(min_similarity) = options.min_similarity {
                if semantic_score < min_similarity {
                    continue;
                }
            }

            scored_entries.push((semantic_score, entry));
        }

        // Sort by score (descending), then by importance and access count
        scored_entries.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| b.1.importance.cmp(&a.1.importance))
                .then_with(|| b.1.access_count.cmp(&a.1.access_count))
        });

        // Apply limit
        if let Some(limit) = options.limit {
            scored_entries.truncate(limit);
        }

        Ok(scored_entries.into_iter().map(|(_, entry)| entry).collect())
    }
|
||||
|
||||
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
||||
let rows = sqlx::query_as::<_, MemoryRow>(
|
||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri LIKE ?"
|
||||
)
|
||||
.bind(format!("{}%", prefix))
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(format!("Failed to find by prefix: {}", e)))?;
|
||||
|
||||
let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
    /// Delete a memory from the main table, the FTS index (best-effort), and
    /// the in-process semantic scorer. A no-op for unknown URIs.
    async fn delete(&self, uri: &str) -> Result<()> {
        sqlx::query("DELETE FROM memories WHERE uri = ?")
            .bind(uri)
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;

        // Remove from FTS (best-effort: failure must not fail the delete)
        let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
            .bind(uri)
            .execute(&self.pool)
            .await;

        // Remove from scorer
        let mut scorer = self.scorer.write().await;
        scorer.remove_entry(uri);

        tracing::debug!("[SqliteStorage] Deleted memory: {}", uri);
        Ok(())
    }
|
||||
|
||||
    /// Upsert an arbitrary JSON blob under `key` in the `metadata` table.
    /// The JSON is stored as-is; no validation is performed here.
    async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()> {
        sqlx::query(
            r#"
            INSERT OR REPLACE INTO metadata (key, json)
            VALUES (?, ?)
            "#,
        )
        .bind(key)
        .bind(json)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to store metadata: {}", e)))?;

        Ok(())
    }
|
||||
|
||||
    /// Fetch the raw JSON blob stored under `key`, or `None` if absent.
    async fn get_metadata_json(&self, key: &str) -> Result<Option<String>> {
        // `fetch_optional` maps "no row" to `None` instead of an error.
        let result = sqlx::query_scalar::<_, String>("SELECT json FROM metadata WHERE key = ?")
            .bind(key)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to get metadata: {}", e)))?;

        Ok(result)
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Round-trip: a stored entry can be read back with identical content.
    #[tokio::test]
    async fn test_sqlite_storage_store_and_get() {
        let storage = SqliteStorage::in_memory().await;
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );

        storage.store(&entry).await.unwrap();
        let retrieved = storage.get(&entry.uri).await.unwrap();

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "User prefers concise responses");
    }

    /// Keyword search should rank the better-matching entry first.
    #[tokio::test]
    async fn test_sqlite_storage_semantic_search() {
        let storage = SqliteStorage::in_memory().await;

        // Store entries with different content
        let entry1 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety".to_string(),
        ).with_keywords(vec!["rust".to_string(), "programming".to_string(), "safety".to_string()]);

        let entry2 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        ).with_keywords(vec!["python".to_string(), "programming".to_string()]);

        storage.store(&entry1).await.unwrap();
        storage.store(&entry2).await.unwrap();

        // Search for "rust safety"
        let results = storage.find(
            "rust safety",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        ).await.unwrap();

        // Should find the Rust entry with higher score
        assert!(!results.is_empty());
        assert!(results[0].content.contains("Rust"));
    }

    /// Deleted entries are no longer retrievable.
    #[tokio::test]
    async fn test_sqlite_storage_delete() {
        let storage = SqliteStorage::in_memory().await;
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        storage.store(&entry).await.unwrap();
        storage.delete(&entry.uri).await.unwrap();

        let retrieved = storage.get(&entry.uri).await.unwrap();
        assert!(retrieved.is_none());
    }

    /// Data written through one connection survives reopening the DB file.
    #[tokio::test]
    async fn test_persistence() {
        let path = std::env::temp_dir().join("zclaw_test_memories.db");

        // Clean up any existing test db
        let _ = std::fs::remove_file(&path);

        // Create and store (storage dropped at end of scope, closing the pool)
        {
            let storage = SqliteStorage::new(&path).await.unwrap();
            let entry = MemoryEntry::new(
                "persist-test",
                MemoryType::Knowledge,
                "test",
                "This should persist".to_string(),
            );
            storage.store(&entry).await.unwrap();
        }

        // Reopen and verify
        {
            let storage = SqliteStorage::new(&path).await.unwrap();
            let results = storage.find_by_prefix("agent://persist-test").await.unwrap();
            assert!(!results.is_empty());
            assert_eq!(results[0].content, "This should persist");
        }

        // Clean up
        let _ = std::fs::remove_file(&path);
    }

    /// Metadata round-trips as the exact JSON string that was stored.
    #[tokio::test]
    async fn test_metadata_storage() {
        let storage = SqliteStorage::in_memory().await;

        let json = r#"{"test": "value"}"#;
        storage.store_metadata_json("test-key", json).await.unwrap();

        let retrieved = storage.get_metadata_json("test-key").await.unwrap();
        assert_eq!(retrieved, Some(json.to_string()));
    }

    /// `get` bumps the access counter on every read.
    #[tokio::test]
    async fn test_access_count() {
        let storage = SqliteStorage::in_memory().await;
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Knowledge,
            "test",
            "test content".to_string(),
        );

        storage.store(&entry).await.unwrap();

        // Access multiple times
        for _ in 0..3 {
            let _ = storage.get(&entry.uri).await.unwrap();
        }

        let retrieved = storage.get(&entry.uri).await.unwrap().unwrap();
        assert!(retrieved.access_count >= 3);
    }
}
|
||||
212
crates/zclaw-growth/src/tracker.rs
Normal file
212
crates/zclaw-growth/src/tracker.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Growth Tracker - Tracks agent growth metrics and evolution
|
||||
//!
|
||||
//! This module provides the `GrowthTracker` which monitors and records
|
||||
//! the evolution of an agent's capabilities and knowledge over time.
|
||||
|
||||
use crate::types::{GrowthStats, MemoryType};
|
||||
use crate::viking_adapter::VikingAdapter;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use zclaw_types::{AgentId, Result};
|
||||
|
||||
/// Growth Tracker - tracks agent growth metrics
///
/// Read side aggregates stored memories into [`GrowthStats`]; write side
/// records learning events through the [`VikingAdapter`] metadata store.
pub struct GrowthTracker {
    /// OpenViking adapter for storage
    viking: Arc<VikingAdapter>,
}
|
||||
|
||||
impl GrowthTracker {
|
||||
/// Create a new growth tracker
|
||||
pub fn new(viking: Arc<VikingAdapter>) -> Self {
|
||||
Self { viking }
|
||||
}
|
||||
|
||||
/// Get current growth statistics for an agent
|
||||
pub async fn get_stats(&self, agent_id: &AgentId) -> Result<GrowthStats> {
|
||||
// Query all memories for the agent
|
||||
let memories = self.viking.find_by_prefix(&format!("agent://{}", agent_id)).await?;
|
||||
|
||||
let mut stats = GrowthStats::default();
|
||||
stats.total_memories = memories.len();
|
||||
|
||||
for memory in &memories {
|
||||
match memory.memory_type {
|
||||
MemoryType::Preference => stats.preference_count += 1,
|
||||
MemoryType::Knowledge => stats.knowledge_count += 1,
|
||||
MemoryType::Experience => stats.experience_count += 1,
|
||||
MemoryType::Session => stats.sessions_processed += 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Get last learning time from metadata
|
||||
let meta: Option<AgentMetadata> = self.viking
|
||||
.get_metadata(&format!("agent://{}", agent_id))
|
||||
.await?;
|
||||
|
||||
if let Some(meta) = meta {
|
||||
stats.last_learning_time = meta.last_learning_time;
|
||||
}
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// Record a learning event
|
||||
pub async fn record_learning(
|
||||
&self,
|
||||
agent_id: &AgentId,
|
||||
session_id: &str,
|
||||
memories_extracted: usize,
|
||||
) -> Result<()> {
|
||||
let event = LearningEvent {
|
||||
agent_id: agent_id.to_string(),
|
||||
session_id: session_id.to_string(),
|
||||
memories_extracted,
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
|
||||
// Store learning event
|
||||
self.viking
|
||||
.store_metadata(
|
||||
&format!("agent://{}/events/{}", agent_id, session_id),
|
||||
&event,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Update last learning time
|
||||
self.viking
|
||||
.store_metadata(
|
||||
&format!("agent://{}", agent_id),
|
||||
&AgentMetadata {
|
||||
last_learning_time: Some(Utc::now()),
|
||||
total_learning_events: None, // Will be computed
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
"[GrowthTracker] Recorded learning event: agent={}, session={}, memories={}",
|
||||
agent_id,
|
||||
session_id,
|
||||
memories_extracted
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get growth timeline for an agent
|
||||
pub async fn get_timeline(&self, agent_id: &AgentId) -> Result<Vec<LearningEvent>> {
|
||||
let memories = self
|
||||
.viking
|
||||
.find_by_prefix(&format!("agent://{}/events/", agent_id))
|
||||
.await?;
|
||||
|
||||
// Parse events from stored memory content
|
||||
let mut timeline = Vec::new();
|
||||
for memory in memories {
|
||||
if let Ok(event) = serde_json::from_str::<LearningEvent>(&memory.content) {
|
||||
timeline.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by timestamp descending
|
||||
timeline.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
|
||||
|
||||
Ok(timeline)
|
||||
}
|
||||
|
||||
/// Calculate growth velocity (memories per day)
|
||||
pub async fn get_growth_velocity(&self, agent_id: &AgentId) -> Result<f64> {
|
||||
let timeline = self.get_timeline(agent_id).await?;
|
||||
|
||||
if timeline.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
// Get first and last event
|
||||
let first = timeline.iter().min_by_key(|e| e.timestamp);
|
||||
let last = timeline.iter().max_by_key(|e| e.timestamp);
|
||||
|
||||
match (first, last) {
|
||||
(Some(first), Some(last)) => {
|
||||
let days = (last.timestamp - first.timestamp).num_days().max(1) as f64;
|
||||
let total_memories: usize = timeline.iter().map(|e| e.memories_extracted).sum();
|
||||
Ok(total_memories as f64 / days)
|
||||
}
|
||||
_ => Ok(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get memory distribution by category
|
||||
pub async fn get_memory_distribution(
|
||||
&self,
|
||||
agent_id: &AgentId,
|
||||
) -> Result<HashMap<String, usize>> {
|
||||
let memories = self.viking.find_by_prefix(&format!("agent://{}", agent_id)).await?;
|
||||
|
||||
let mut distribution = HashMap::new();
|
||||
for memory in memories {
|
||||
*distribution.entry(memory.memory_type.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
Ok(distribution)
|
||||
}
|
||||
}
|
||||
|
||||
/// Learning event record
///
/// Serialized to JSON and stored per-session under
/// `agent://{agent_id}/events/{session_id}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningEvent {
    /// Agent ID
    pub agent_id: String,
    /// Session ID where learning occurred
    pub session_id: String,
    /// Number of memories extracted
    pub memories_extracted: usize,
    /// Event timestamp
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Agent metadata stored in OpenViking
///
/// Persisted under the bare `agent://{agent_id}` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMetadata {
    /// Last learning time
    pub last_learning_time: Option<DateTime<Utc>>,
    /// Total learning events (computed)
    ///
    /// NOTE(review): `record_learning` always writes this as `None` —
    /// confirm whether anything actually computes it.
    pub total_learning_events: Option<usize>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A `LearningEvent` survives a JSON round-trip with key fields intact.
    #[test]
    fn test_learning_event_serialization() {
        let original = LearningEvent {
            agent_id: "test-agent".to_string(),
            session_id: "test-session".to_string(),
            memories_extracted: 5,
            timestamp: Utc::now(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LearningEvent = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.agent_id, original.agent_id);
        assert_eq!(decoded.memories_extracted, original.memories_extracted);
    }

    /// `AgentMetadata` optional fields survive a JSON round-trip.
    #[test]
    fn test_agent_metadata_serialization() {
        let original = AgentMetadata {
            last_learning_time: Some(Utc::now()),
            total_learning_events: Some(10),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AgentMetadata = serde_json::from_str(&encoded).unwrap();

        assert!(decoded.last_learning_time.is_some());
        assert_eq!(decoded.total_learning_events, Some(10));
    }
}
|
||||
486
crates/zclaw-growth/src/types.rs
Normal file
486
crates/zclaw-growth/src/types.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
//! Core type definitions for the ZCLAW Growth System
|
||||
//!
|
||||
//! This module defines the fundamental types used for memory management,
|
||||
//! extraction, retrieval, and prompt injection.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zclaw_types::SessionId;
|
||||
|
||||
/// Memory type classification
///
/// Serialized with serde in `snake_case`. `Display` renders the canonical
/// URI path segment (`preferences`, `knowledge`, `experience`, `sessions`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// User preferences (communication style, format, language, etc.)
    Preference,
    /// Accumulated knowledge (user facts, domain knowledge, lessons learned)
    Knowledge,
    /// Skill/tool usage experience
    Experience,
    /// Conversation session history
    Session,
}
|
||||
|
||||
impl std::fmt::Display for MemoryType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MemoryType::Preference => write!(f, "preferences"),
|
||||
MemoryType::Knowledge => write!(f, "knowledge"),
|
||||
MemoryType::Experience => write!(f, "experience"),
|
||||
MemoryType::Session => write!(f, "sessions"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for MemoryType {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"preferences" | "preference" => Ok(MemoryType::Preference),
|
||||
"knowledge" => Ok(MemoryType::Knowledge),
|
||||
"experience" => Ok(MemoryType::Experience),
|
||||
"sessions" | "session" => Ok(MemoryType::Session),
|
||||
_ => Err(format!("Unknown memory type: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryType {
|
||||
/// Parse memory type from string (returns Knowledge as default)
|
||||
pub fn parse(s: &str) -> Self {
|
||||
s.parse().unwrap_or(MemoryType::Knowledge)
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory entry stored in OpenViking
///
/// The `uri` doubles as the primary storage key: backends insert, look up
/// and delete entries by it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// URI in OpenViking format: agent://{agent_id}/{type}/{category}
    pub uri: String,
    /// Type of memory
    pub memory_type: MemoryType,
    /// Memory content
    pub content: String,
    /// Keywords for semantic search
    pub keywords: Vec<String>,
    /// Importance score (1-10)
    pub importance: u8,
    /// Number of times accessed
    pub access_count: u32,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last access timestamp
    pub last_accessed: DateTime<Utc>,
}
|
||||
|
||||
impl MemoryEntry {
|
||||
/// Create a new memory entry
|
||||
pub fn new(
|
||||
agent_id: &str,
|
||||
memory_type: MemoryType,
|
||||
category: &str,
|
||||
content: String,
|
||||
) -> Self {
|
||||
let uri = format!("agent://{}/{}/{}", agent_id, memory_type, category);
|
||||
Self {
|
||||
uri,
|
||||
memory_type,
|
||||
content,
|
||||
keywords: Vec::new(),
|
||||
importance: 5,
|
||||
access_count: 0,
|
||||
created_at: Utc::now(),
|
||||
last_accessed: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add keywords to the memory entry
|
||||
pub fn with_keywords(mut self, keywords: Vec<String>) -> Self {
|
||||
self.keywords = keywords;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set importance score
|
||||
pub fn with_importance(mut self, importance: u8) -> Self {
|
||||
self.importance = importance.min(10).max(1);
|
||||
self
|
||||
}
|
||||
|
||||
/// Mark as accessed
|
||||
pub fn touch(&mut self) {
|
||||
self.access_count += 1;
|
||||
self.last_accessed = Utc::now();
|
||||
}
|
||||
|
||||
/// Estimate token count (roughly 4 characters per token for mixed content)
|
||||
/// More accurate estimation considering Chinese characters (1.5 tokens avg)
|
||||
pub fn estimated_tokens(&self) -> usize {
|
||||
let char_count = self.content.chars().count();
|
||||
let cjk_count = self.content.chars().filter(|c| is_cjk(*c)).count();
|
||||
let non_cjk_count = char_count - cjk_count;
|
||||
|
||||
// CJK: ~1.5 tokens per char, non-CJK: ~0.25 tokens per char
|
||||
(cjk_count as f32 * 1.5 + non_cjk_count as f32 * 0.25).ceil() as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracted memory from conversation analysis
///
/// An intermediate product of the extractor; converted to a [`MemoryEntry`]
/// via `to_memory_entry` before storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedMemory {
    /// Type of extracted memory
    pub memory_type: MemoryType,
    /// Category within the memory type
    pub category: String,
    /// Memory content
    pub content: String,
    /// Extraction confidence (0.0 - 1.0); `with_confidence` clamps to this range
    pub confidence: f32,
    /// Source session ID
    pub source_session: SessionId,
    /// Keywords extracted
    pub keywords: Vec<String>,
}
|
||||
|
||||
impl ExtractedMemory {
|
||||
/// Create a new extracted memory
|
||||
pub fn new(
|
||||
memory_type: MemoryType,
|
||||
category: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
source_session: SessionId,
|
||||
) -> Self {
|
||||
Self {
|
||||
memory_type,
|
||||
category: category.into(),
|
||||
content: content.into(),
|
||||
confidence: 0.8,
|
||||
source_session,
|
||||
keywords: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set confidence score
|
||||
pub fn with_confidence(mut self, confidence: f32) -> Self {
|
||||
self.confidence = confidence.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add keywords
|
||||
pub fn with_keywords(mut self, keywords: Vec<String>) -> Self {
|
||||
self.keywords = keywords;
|
||||
self
|
||||
}
|
||||
|
||||
/// Convert to MemoryEntry for storage
|
||||
pub fn to_memory_entry(&self, agent_id: &str) -> MemoryEntry {
|
||||
MemoryEntry::new(agent_id, self.memory_type, &self.category, self.content.clone())
|
||||
.with_keywords(self.keywords.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieval configuration
///
/// The per-type budgets partition `max_tokens` (Default: 200 + 200 + 100 = 500).
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Total token budget for retrieved memories
    pub max_tokens: usize,
    /// Token budget for preferences
    pub preference_budget: usize,
    /// Token budget for knowledge
    pub knowledge_budget: usize,
    /// Token budget for experience
    pub experience_budget: usize,
    /// Minimum similarity threshold (0.0 - 1.0)
    pub min_similarity: f32,
    /// Maximum number of results per type
    pub max_results_per_type: usize,
}
|
||||
|
||||
/// Check if character is CJK (Han ideographs, kana, or Hangul).
///
/// Used by token estimation: CJK characters are weighted more heavily
/// than Latin text.
fn is_cjk(c: char) -> bool {
    let code = c as u32;
    (0x4E00..=0x9FFF).contains(&code)       // CJK Unified Ideographs
        || (0x3400..=0x4DBF).contains(&code)    // CJK Extension A
        || (0x20000..=0x2A6DF).contains(&code)  // CJK Extension B
        || (0xF900..=0xFAFF).contains(&code)    // CJK Compatibility Ideographs
        || (0x3040..=0x309F).contains(&code)    // Hiragana
        || (0x30A0..=0x30FF).contains(&code)    // Katakana
        || (0xAC00..=0xD7AF).contains(&code)    // Hangul syllables
}
|
||||
|
||||
impl Default for RetrievalConfig {
    /// Default budget split: 200 preferences + 200 knowledge +
    /// 100 experience = 500 total tokens.
    fn default() -> Self {
        Self {
            max_tokens: 500,
            preference_budget: 200,
            knowledge_budget: 200,
            experience_budget: 100,
            min_similarity: 0.7,
            max_results_per_type: 5,
        }
    }
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
/// Create a config with custom token budget
|
||||
pub fn with_budget(max_tokens: usize) -> Self {
|
||||
let pref = (max_tokens as f32 * 0.4) as usize;
|
||||
let knowledge = (max_tokens as f32 * 0.4) as usize;
|
||||
let exp = max_tokens.saturating_sub(pref).saturating_sub(knowledge);
|
||||
|
||||
Self {
|
||||
max_tokens,
|
||||
preference_budget: pref,
|
||||
knowledge_budget: knowledge,
|
||||
experience_budget: exp,
|
||||
min_similarity: 0.7,
|
||||
max_results_per_type: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieval result containing memories by type
#[derive(Debug, Clone, Default)]
pub struct RetrievalResult {
    /// Retrieved preferences
    pub preferences: Vec<MemoryEntry>,
    /// Retrieved knowledge
    pub knowledge: Vec<MemoryEntry>,
    /// Retrieved experience
    pub experience: Vec<MemoryEntry>,
    /// Total tokens used
    ///
    /// NOTE(review): this field is set by the caller; `calculate_tokens`
    /// recomputes from the entries — confirm which one is authoritative.
    pub total_tokens: usize,
}
|
||||
|
||||
impl RetrievalResult {
|
||||
/// Check if result is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.preferences.is_empty()
|
||||
&& self.knowledge.is_empty()
|
||||
&& self.experience.is_empty()
|
||||
}
|
||||
|
||||
/// Get total memory count
|
||||
pub fn total_count(&self) -> usize {
|
||||
self.preferences.len() + self.knowledge.len() + self.experience.len()
|
||||
}
|
||||
|
||||
/// Calculate total tokens from entries
|
||||
pub fn calculate_tokens(&self) -> usize {
|
||||
let tokens: usize = self.preferences.iter()
|
||||
.chain(self.knowledge.iter())
|
||||
.chain(self.experience.iter())
|
||||
.map(|m| m.estimated_tokens())
|
||||
.sum();
|
||||
tokens
|
||||
}
|
||||
}
|
||||
|
||||
/// Extraction configuration
///
/// Toggles which memory categories the extractor produces and the
/// confidence floor below which candidates are dropped.
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Extract preferences from conversation
    pub extract_preferences: bool,
    /// Extract knowledge from conversation
    pub extract_knowledge: bool,
    /// Extract experience from conversation
    pub extract_experience: bool,
    /// Minimum confidence threshold for extraction
    pub min_confidence: f32,
}
|
||||
|
||||
impl Default for ExtractionConfig {
    /// Everything enabled, with a 0.6 confidence floor.
    fn default() -> Self {
        Self {
            extract_preferences: true,
            extract_knowledge: true,
            extract_experience: true,
            min_confidence: 0.6,
        }
    }
}
|
||||
|
||||
/// Growth statistics for an agent
///
/// Aggregated counts over the agent's stored memories; all counters start
/// at zero via `Default`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GrowthStats {
    /// Total number of memories
    pub total_memories: usize,
    /// Number of preferences
    pub preference_count: usize,
    /// Number of knowledge entries
    pub knowledge_count: usize,
    /// Number of experience entries
    pub experience_count: usize,
    /// Total sessions processed
    pub sessions_processed: usize,
    /// Last learning timestamp
    pub last_learning_time: Option<DateTime<Utc>>,
    /// Average extraction confidence
    pub avg_confidence: f32,
}
|
||||
|
||||
/// OpenViking URI builder
///
/// Stateless namespace for constructing and parsing `agent://` URIs.
pub struct UriBuilder;
|
||||
|
||||
impl UriBuilder {
|
||||
/// Build a preference URI
|
||||
pub fn preference(agent_id: &str, category: &str) -> String {
|
||||
format!("agent://{}/preferences/{}", agent_id, category)
|
||||
}
|
||||
|
||||
/// Build a knowledge URI
|
||||
pub fn knowledge(agent_id: &str, domain: &str) -> String {
|
||||
format!("agent://{}/knowledge/{}", agent_id, domain)
|
||||
}
|
||||
|
||||
/// Build an experience URI
|
||||
pub fn experience(agent_id: &str, skill_id: &str) -> String {
|
||||
format!("agent://{}/experience/{}", agent_id, skill_id)
|
||||
}
|
||||
|
||||
/// Build a session URI
|
||||
pub fn session(agent_id: &str, session_id: &str) -> String {
|
||||
format!("agent://{}/sessions/{}", agent_id, session_id)
|
||||
}
|
||||
|
||||
/// Parse agent ID from URI
|
||||
pub fn parse_agent_id(uri: &str) -> Option<&str> {
|
||||
uri.strip_prefix("agent://")?
|
||||
.split('/')
|
||||
.next()
|
||||
}
|
||||
|
||||
/// Parse memory type from URI
|
||||
pub fn parse_memory_type(uri: &str) -> Option<MemoryType> {
|
||||
let after_agent = uri.strip_prefix("agent://")?;
|
||||
let mut parts = after_agent.split('/');
|
||||
parts.next()?; // Skip agent_id
|
||||
|
||||
match parts.next()? {
|
||||
"preferences" => Some(MemoryType::Preference),
|
||||
"knowledge" => Some(MemoryType::Knowledge),
|
||||
"experience" => Some(MemoryType::Experience),
|
||||
"sessions" => Some(MemoryType::Session),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders the canonical plural URI segments.
    #[test]
    fn test_memory_type_display() {
        assert_eq!(format!("{}", MemoryType::Preference), "preferences");
        assert_eq!(format!("{}", MemoryType::Knowledge), "knowledge");
        assert_eq!(format!("{}", MemoryType::Experience), "experience");
        assert_eq!(format!("{}", MemoryType::Session), "sessions");
    }

    /// `new` builds the URI from its parts and applies the documented defaults.
    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "communication-style",
            "User prefers concise responses".to_string(),
        );

        assert_eq!(entry.uri, "agent://test-agent/preferences/communication-style");
        assert_eq!(entry.importance, 5);
        assert_eq!(entry.access_count, 0);
    }

    /// `touch` increments the access counter.
    #[test]
    fn test_memory_entry_touch() {
        let mut entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Knowledge,
            "domain",
            "content".to_string(),
        );

        entry.touch();
        assert_eq!(entry.access_count, 1);
    }

    /// Token estimate for short ASCII text lands in a plausible band.
    #[test]
    fn test_estimated_tokens() {
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "test",
            "This is a test content that should be around 10 tokens".to_string(),
        );

        // ~40 chars / 4 = ~10 tokens
        assert!(entry.estimated_tokens() > 5);
        assert!(entry.estimated_tokens() < 20);
    }

    /// Default budgets partition the 500-token total.
    #[test]
    fn test_retrieval_config_default() {
        let config = RetrievalConfig::default();
        assert_eq!(config.max_tokens, 500);
        assert_eq!(config.preference_budget, 200);
        assert_eq!(config.knowledge_budget, 200);
        assert_eq!(config.experience_budget, 100);
    }

    /// `with_budget` gives ~40% each to preferences and knowledge.
    #[test]
    fn test_retrieval_config_with_budget() {
        let config = RetrievalConfig::with_budget(1000);
        assert_eq!(config.max_tokens, 1000);
        assert!(config.preference_budget >= 350);
        assert!(config.knowledge_budget >= 350);
    }

    /// Each builder method produces the expected URI shape.
    #[test]
    fn test_uri_builder() {
        let pref_uri = UriBuilder::preference("agent-1", "style");
        assert_eq!(pref_uri, "agent://agent-1/preferences/style");

        let knowledge_uri = UriBuilder::knowledge("agent-1", "rust");
        assert_eq!(knowledge_uri, "agent://agent-1/knowledge/rust");

        let exp_uri = UriBuilder::experience("agent-1", "browser");
        assert_eq!(exp_uri, "agent://agent-1/experience/browser");

        let session_uri = UriBuilder::session("agent-1", "session-123");
        assert_eq!(session_uri, "agent://agent-1/sessions/session-123");
    }

    /// Parsing extracts the agent ID and memory type, and rejects non-URIs.
    #[test]
    fn test_uri_parser() {
        let uri = "agent://agent-1/preferences/style";
        assert_eq!(UriBuilder::parse_agent_id(uri), Some("agent-1"));
        assert_eq!(UriBuilder::parse_memory_type(uri), Some(MemoryType::Preference));

        let invalid_uri = "invalid-uri";
        assert!(UriBuilder::parse_agent_id(invalid_uri).is_none());
        assert!(UriBuilder::parse_memory_type(invalid_uri).is_none());
    }

    /// `is_empty`/`total_count` reflect the combined entry lists.
    #[test]
    fn test_retrieval_result() {
        let result = RetrievalResult::default();
        assert!(result.is_empty());
        assert_eq!(result.total_count(), 0);

        let result = RetrievalResult {
            preferences: vec![MemoryEntry::new(
                "agent-1",
                MemoryType::Preference,
                "style",
                "test".to_string(),
            )],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        assert!(!result.is_empty());
        assert_eq!(result.total_count(), 1);
    }
}
|
||||
362
crates/zclaw-growth/src/viking_adapter.rs
Normal file
362
crates/zclaw-growth/src/viking_adapter.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
//! OpenViking Adapter - Interface to the OpenViking memory system
|
||||
//!
|
||||
//! This module provides the `VikingAdapter` which wraps the OpenViking
|
||||
//! context database for storing and retrieving agent memories.
|
||||
|
||||
use crate::types::MemoryEntry;
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use zclaw_types::Result;
|
||||
|
||||
/// Search options for find operations
///
/// All fields are optional; `None` means "no constraint".
#[derive(Debug, Clone, Default)]
pub struct FindOptions {
    /// Scope to search within (URI prefix)
    pub scope: Option<String>,
    /// Maximum results to return
    pub limit: Option<usize>,
    /// Minimum similarity threshold
    ///
    /// NOTE(review): the in-memory backend ignores this field — confirm
    /// which backends actually honor it.
    pub min_similarity: Option<f32>,
}
|
||||
|
||||
/// VikingStorage trait - core storage operations (dyn-compatible)
///
/// Metadata is exchanged as raw JSON strings (rather than generic typed
/// methods) so the trait stays object-safe; [`VikingAdapter`] layers the
/// typed serde API on top.
#[async_trait]
pub trait VikingStorage: Send + Sync {
    /// Store a memory entry (keyed by its URI)
    async fn store(&self, entry: &MemoryEntry) -> Result<()>;

    /// Get a memory entry by URI; `Ok(None)` when absent
    async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>>;

    /// Find memories by query with options
    async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>>;

    /// Find memories by URI prefix
    async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>>;

    /// Delete a memory by URI
    async fn delete(&self, uri: &str) -> Result<()>;

    /// Store metadata as JSON string (upsert semantics)
    async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()>;

    /// Get metadata as JSON string; `Ok(None)` when the key is absent
    async fn get_metadata_json(&self, key: &str) -> Result<Option<String>>;
}
|
||||
|
||||
/// OpenViking adapter implementation
///
/// Cheap to clone: only the `Arc` handle to the backend is duplicated.
#[derive(Clone)]
pub struct VikingAdapter {
    /// Storage backend
    backend: Arc<dyn VikingStorage>,
}
|
||||
|
||||
impl VikingAdapter {
|
||||
/// Create a new Viking adapter with a storage backend
|
||||
pub fn new(backend: Arc<dyn VikingStorage>) -> Self {
|
||||
Self { backend }
|
||||
}
|
||||
|
||||
/// Create with in-memory storage (for testing)
|
||||
pub fn in_memory() -> Self {
|
||||
Self {
|
||||
backend: Arc::new(InMemoryStorage::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a memory entry
|
||||
pub async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
||||
self.backend.store(entry).await
|
||||
}
|
||||
|
||||
/// Get a memory entry by URI
|
||||
pub async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
|
||||
self.backend.get(uri).await
|
||||
}
|
||||
|
||||
/// Find memories by query
|
||||
pub async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
||||
self.backend.find(query, options).await
|
||||
}
|
||||
|
||||
/// Find memories by URI prefix
|
||||
pub async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
||||
self.backend.find_by_prefix(prefix).await
|
||||
}
|
||||
|
||||
/// Delete a memory
|
||||
pub async fn delete(&self, uri: &str) -> Result<()> {
|
||||
self.backend.delete(uri).await
|
||||
}
|
||||
|
||||
/// Store metadata (typed)
|
||||
pub async fn store_metadata<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
|
||||
let json = serde_json::to_string(value)?;
|
||||
self.backend.store_metadata_json(key, &json).await
|
||||
}
|
||||
|
||||
/// Get metadata (typed)
|
||||
pub async fn get_metadata<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
|
||||
match self.backend.get_metadata_json(key).await? {
|
||||
Some(json) => {
|
||||
let value: T = serde_json::from_str(&json)?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory storage backend (for testing and development)
///
/// Uses std (blocking) `RwLock`s — safe here because no method awaits
/// while a guard is held.
pub struct InMemoryStorage {
    // Memory entries keyed by URI.
    memories: std::sync::RwLock<HashMap<String, MemoryEntry>>,
    // Raw JSON metadata keyed by caller-supplied key.
    metadata: std::sync::RwLock<HashMap<String, String>>,
}
|
||||
|
||||
impl InMemoryStorage {
|
||||
/// Create a new in-memory storage
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
memories: std::sync::RwLock::new(HashMap::new()),
|
||||
metadata: std::sync::RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryStorage {
    /// Equivalent to [`InMemoryStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl VikingStorage for InMemoryStorage {
|
||||
async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
||||
let mut memories = self.memories.write().unwrap();
|
||||
memories.insert(entry.uri.clone(), entry.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
|
||||
let memories = self.memories.read().unwrap();
|
||||
Ok(memories.get(uri).cloned())
|
||||
}
|
||||
|
||||
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
||||
let memories = self.memories.read().unwrap();
|
||||
|
||||
let mut results: Vec<MemoryEntry> = memories
|
||||
.values()
|
||||
.filter(|entry| {
|
||||
// Apply scope filter
|
||||
if let Some(ref scope) = options.scope {
|
||||
if !entry.uri.starts_with(scope) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple text matching (in real implementation, use semantic search)
|
||||
if !query.is_empty() {
|
||||
let query_lower = query.to_lowercase();
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let keywords_match = entry.keywords.iter().any(|k| k.to_lowercase().contains(&query_lower));
|
||||
|
||||
content_lower.contains(&query_lower) || keywords_match
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Sort by importance and access count
|
||||
results.sort_by(|a, b| {
|
||||
b.importance
|
||||
.cmp(&a.importance)
|
||||
.then_with(|| b.access_count.cmp(&a.access_count))
|
||||
});
|
||||
|
||||
// Apply limit
|
||||
if let Some(limit) = options.limit {
|
||||
results.truncate(limit);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
||||
let memories = self.memories.read().unwrap();
|
||||
|
||||
let results: Vec<MemoryEntry> = memories
|
||||
.values()
|
||||
.filter(|entry| entry.uri.starts_with(prefix))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn delete(&self, uri: &str) -> Result<()> {
|
||||
let mut memories = self.memories.write().unwrap();
|
||||
memories.remove(uri);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()> {
|
||||
let mut metadata = self.metadata.write().unwrap();
|
||||
metadata.insert(key.to_string(), json.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_metadata_json(&self, key: &str) -> Result<Option<String>> {
|
||||
let metadata = self.metadata.read().unwrap();
|
||||
Ok(metadata.get(key).cloned())
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenViking levels for storage.
///
/// A level describes how condensed a stored piece of content is, from raw
/// data (L0) up to keywords/metadata (L2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VikingLevel {
    /// L0: Raw data (original content).
    L0,
    /// L1: Summarized content.
    L1,
    /// L2: Keywords and metadata.
    L2,
}
|
||||
|
||||
impl std::fmt::Display for VikingLevel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VikingLevel::L0 => write!(f, "L0"),
|
||||
VikingLevel::L1 => write!(f, "L1"),
|
||||
VikingLevel::L2 => write!(f, "L2"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    // Stored entries can be read back by their URI.
    #[tokio::test]
    async fn test_in_memory_storage_store_and_get() {
        let storage = InMemoryStorage::new();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test content".to_string(),
        );

        storage.store(&entry).await.unwrap();
        let retrieved = storage.get(&entry.uri).await.unwrap();

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "test content");
    }

    // `find` honors the scope filter and matches on content text.
    #[tokio::test]
    async fn test_in_memory_storage_find() {
        let storage = InMemoryStorage::new();

        let entry1 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust",
            "Rust programming tips".to_string(),
        );
        let entry2 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python",
            "Python programming tips".to_string(),
        );

        storage.store(&entry1).await.unwrap();
        storage.store(&entry2).await.unwrap();

        // Query "Rust" should match exactly one of the two stored entries.
        let results = storage
            .find(
                "Rust",
                FindOptions {
                    scope: Some("agent://agent-1".to_string()),
                    limit: Some(10),
                    min_similarity: None,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust"));
    }

    // Deleted entries are no longer retrievable by URI.
    #[tokio::test]
    async fn test_in_memory_storage_delete() {
        let storage = InMemoryStorage::new();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );

        storage.store(&entry).await.unwrap();
        storage.delete(&entry.uri).await.unwrap();

        let retrieved = storage.get(&entry.uri).await.unwrap();
        assert!(retrieved.is_none());
    }

    // Raw JSON metadata round-trips through the backend unchanged.
    #[tokio::test]
    async fn test_metadata_storage() {
        let storage = InMemoryStorage::new();

        #[derive(Serialize, serde::Deserialize)]
        struct TestData {
            value: String,
        }

        let data = TestData {
            value: "test".to_string(),
        };

        storage.store_metadata_json("test-key", &serde_json::to_string(&data).unwrap()).await.unwrap();
        let json = storage.get_metadata_json("test-key").await.unwrap();

        assert!(json.is_some());
        let retrieved: TestData = serde_json::from_str(&json.unwrap()).unwrap();
        assert_eq!(retrieved.value, "test");
    }

    // The adapter's typed metadata API serializes and deserializes through
    // the backend's raw JSON store transparently.
    #[tokio::test]
    async fn test_viking_adapter_typed_metadata() {
        let adapter = VikingAdapter::in_memory();

        #[derive(Serialize, serde::Deserialize)]
        struct TestData {
            value: String,
        }

        let data = TestData {
            value: "test".to_string(),
        };

        adapter.store_metadata("test-key", &data).await.unwrap();
        let retrieved: Option<TestData> = adapter.get_metadata("test-key").await.unwrap();

        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().value, "test");
    }

    // Display renders each level as its short label.
    #[test]
    fn test_viking_level_display() {
        assert_eq!(format!("{}", VikingLevel::L0), "L0");
        assert_eq!(format!("{}", VikingLevel::L1), "L1");
        assert_eq!(format!("{}", VikingLevel::L2), "L2");
    }
}
|
||||
412
crates/zclaw-growth/tests/integration_test.rs
Normal file
412
crates/zclaw-growth/tests/integration_test.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Integration tests for ZCLAW Growth System
|
||||
//!
|
||||
//! Tests the complete flow: store → find → inject
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_growth::{
|
||||
FindOptions, MemoryEntry, MemoryRetriever, MemoryType, PromptInjector,
|
||||
RetrievalConfig, RetrievalResult, SqliteStorage, VikingAdapter,
|
||||
};
|
||||
use zclaw_types::AgentId;
|
||||
|
||||
/// Test complete memory lifecycle: store three memory kinds, retrieve them
/// for an agent, and inject the retrieval result into a prompt.
#[tokio::test]
async fn test_memory_lifecycle() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));

    // Create agent ID and use its string form for storage
    let agent_id = AgentId::new();
    let agent_str = agent_id.to_string();

    // 1. Store a preference (Chinese content: "user prefers concise replies")
    let pref = MemoryEntry::new(
        &agent_str,
        MemoryType::Preference,
        "communication-style",
        "用户偏好简洁的回复,不喜欢冗长的解释".to_string(),
    )
    .with_keywords(vec!["简洁".to_string(), "沟通风格".to_string()])
    .with_importance(8);

    adapter.store(&pref).await.unwrap();

    // 2. Store knowledge ("user is a Rust developer")
    let knowledge = MemoryEntry::new(
        &agent_str,
        MemoryType::Knowledge,
        "rust-expertise",
        "用户是 Rust 开发者,熟悉 async/await 和 trait 系统".to_string(),
    )
    .with_keywords(vec!["Rust".to_string(), "开发者".to_string()]);

    adapter.store(&knowledge).await.unwrap();

    // 3. Store experience ("browser skill works well for docs search")
    let experience = MemoryEntry::new(
        &agent_str,
        MemoryType::Experience,
        "browser-skill",
        "浏览器技能在搜索技术文档时效果很好".to_string(),
    )
    .with_keywords(vec!["浏览器".to_string(), "技能".to_string()]);

    adapter.store(&experience).await.unwrap();

    // 4. Retrieve memories - directly from adapter first, to distinguish
    // adapter-level misses from retriever-level filtering in failures.
    let direct_results = adapter
        .find(
            "Rust",
            FindOptions {
                scope: Some(format!("agent://{}", agent_str)),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();
    println!("Direct find results: {:?}", direct_results.len());

    let retriever = MemoryRetriever::new(adapter.clone());
    // Use lower similarity threshold for testing
    let config = RetrievalConfig {
        min_similarity: 0.1,
        ..RetrievalConfig::default()
    };
    let retriever = retriever.with_config(config);
    let result = retriever
        .retrieve(&agent_id, "Rust 编程")
        .await
        .unwrap();

    println!("Knowledge results: {:?}", result.knowledge.len());
    println!("Preferences results: {:?}", result.preferences.len());
    println!("Experience results: {:?}", result.experience.len());

    // Should find the knowledge entry
    assert!(!result.knowledge.is_empty(), "Expected to find knowledge entries but found none. Direct results: {}", direct_results.len());
    assert!(result.knowledge[0].content.contains("Rust"));

    // 5. Inject into prompt
    let injector = PromptInjector::new();
    let base_prompt = "你是一个有帮助的 AI 助手。";
    let enhanced = injector.inject_with_format(base_prompt, &result);

    // Enhanced prompt should contain memory context
    assert!(enhanced.len() > base_prompt.len());
}
|
||||
|
||||
/// Test semantic search ranking: more relevant entries should surface first
/// when several candidates partially match the query.
#[tokio::test]
async fn test_semantic_search_ranking() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage.clone()));

    // Store multiple entries with different relevance to the query below:
    // Rust basics, Python basics, and Rust async.
    let entries = vec![
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust-basics",
            "Rust 是一门系统编程语言,注重安全性和性能".to_string(),
        )
        .with_keywords(vec!["Rust".to_string(), "系统编程".to_string()]),
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python-basics",
            "Python 是一门高级编程语言,易于学习".to_string(),
        )
        .with_keywords(vec!["Python".to_string(), "高级语言".to_string()]),
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust-async",
            "Rust 的 async/await 语法用于异步编程".to_string(),
        )
        .with_keywords(vec!["Rust".to_string(), "async".to_string(), "异步".to_string()]),
    ];

    for entry in &entries {
        adapter.store(entry).await.unwrap();
    }

    // Search for "Rust async programming"
    let results = adapter
        .find(
            "Rust 异步编程",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();

    // Rust async entry should rank highest; the assertion is deliberately
    // loose (either Rust entry in first place is accepted).
    assert!(!results.is_empty());
    assert!(results[0].content.contains("async") || results[0].content.contains("Rust"));
}
|
||||
|
||||
/// Test memory importance and access count: both stored entries remain
/// findable regardless of importance or how often they were read.
#[tokio::test]
async fn test_importance_and_access() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage.clone()));

    // Create entries with different importance
    let high_importance = MemoryEntry::new(
        "agent-1",
        MemoryType::Preference,
        "critical",
        "这是非常重要的偏好".to_string(),
    )
    .with_importance(10);

    let low_importance = MemoryEntry::new(
        "agent-1",
        MemoryType::Preference,
        "minor",
        "这是不太重要的偏好".to_string(),
    )
    .with_importance(2);

    adapter.store(&high_importance).await.unwrap();
    adapter.store(&low_importance).await.unwrap();

    // Access the low importance one multiple times.
    // NOTE(review): assumes the backend bumps access_count on get() —
    // verify against the SqliteStorage implementation.
    for _ in 0..5 {
        let _ = adapter.get(&low_importance.uri).await;
    }

    // Search should consider both importance and access count; only the
    // result count is asserted here, not the ordering.
    let results = adapter
        .find(
            "偏好",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: None,
            },
        )
        .await
        .unwrap();

    assert_eq!(results.len(), 2);
}
|
||||
|
||||
/// Test prompt injection with token budget
|
||||
#[tokio::test]
|
||||
async fn test_prompt_injection_token_budget() {
|
||||
let mut result = RetrievalResult::default();
|
||||
|
||||
// Add memories that exceed budget
|
||||
for i in 0..10 {
|
||||
result.preferences.push(
|
||||
MemoryEntry::new(
|
||||
"agent-1",
|
||||
MemoryType::Preference,
|
||||
&format!("pref-{}", i),
|
||||
"这是一个很长的偏好描述,用于测试 token 预算控制功能。".repeat(5),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
result.total_tokens = result.calculate_tokens();
|
||||
|
||||
// Budget is 500 tokens by default
|
||||
let injector = PromptInjector::new();
|
||||
let base = "Base prompt";
|
||||
let enhanced = injector.inject_with_format(base, &result);
|
||||
|
||||
// Should include memory context
|
||||
assert!(enhanced.len() > base.len());
|
||||
}
|
||||
|
||||
/// Test metadata storage
|
||||
#[tokio::test]
|
||||
async fn test_metadata_operations() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
|
||||
// Store metadata using typed API
|
||||
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
|
||||
struct Config {
|
||||
version: String,
|
||||
auto_extract: bool,
|
||||
}
|
||||
|
||||
let config = Config {
|
||||
version: "1.0.0".to_string(),
|
||||
auto_extract: true,
|
||||
};
|
||||
|
||||
adapter.store_metadata("agent-config", &config).await.unwrap();
|
||||
|
||||
// Retrieve metadata
|
||||
let retrieved: Option<Config> = adapter.get_metadata("agent-config").await.unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let parsed = retrieved.unwrap();
|
||||
assert_eq!(parsed.version, "1.0.0");
|
||||
assert_eq!(parsed.auto_extract, true);
|
||||
}
|
||||
|
||||
/// Test memory deletion and cleanup
|
||||
#[tokio::test]
|
||||
async fn test_memory_deletion() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
|
||||
let entry = MemoryEntry::new(
|
||||
"agent-1",
|
||||
MemoryType::Knowledge,
|
||||
"temp",
|
||||
"Temporary knowledge".to_string(),
|
||||
);
|
||||
|
||||
adapter.store(&entry).await.unwrap();
|
||||
|
||||
// Verify stored
|
||||
let retrieved = adapter.get(&entry.uri).await.unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
// Delete
|
||||
adapter.delete(&entry.uri).await.unwrap();
|
||||
|
||||
// Verify deleted
|
||||
let retrieved = adapter.get(&entry.uri).await.unwrap();
|
||||
assert!(retrieved.is_none());
|
||||
|
||||
// Verify not in search results
|
||||
let results = adapter
|
||||
.find(
|
||||
"Temporary",
|
||||
FindOptions {
|
||||
scope: Some("agent://agent-1".to_string()),
|
||||
limit: Some(10),
|
||||
min_similarity: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
/// Test cross-agent isolation
|
||||
#[tokio::test]
|
||||
async fn test_agent_isolation() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
|
||||
// Store memories for different agents
|
||||
let agent1_memory = MemoryEntry::new(
|
||||
"agent-1",
|
||||
MemoryType::Knowledge,
|
||||
"secret",
|
||||
"Agent 1 的秘密信息".to_string(),
|
||||
);
|
||||
|
||||
let agent2_memory = MemoryEntry::new(
|
||||
"agent-2",
|
||||
MemoryType::Knowledge,
|
||||
"secret",
|
||||
"Agent 2 的秘密信息".to_string(),
|
||||
);
|
||||
|
||||
adapter.store(&agent1_memory).await.unwrap();
|
||||
adapter.store(&agent2_memory).await.unwrap();
|
||||
|
||||
// Agent 1 should only see its own memories
|
||||
let results = adapter
|
||||
.find(
|
||||
"秘密",
|
||||
FindOptions {
|
||||
scope: Some("agent://agent-1".to_string()),
|
||||
limit: Some(10),
|
||||
min_similarity: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results[0].content.contains("Agent 1"));
|
||||
|
||||
// Agent 2 should only see its own memories
|
||||
let results = adapter
|
||||
.find(
|
||||
"秘密",
|
||||
FindOptions {
|
||||
scope: Some("agent://agent-2".to_string()),
|
||||
limit: Some(10),
|
||||
min_similarity: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results[0].content.contains("Agent 2"));
|
||||
}
|
||||
|
||||
/// Test Chinese text handling: storage, keyword matching, and scoped search
/// all work with multi-byte CJK content and URIs.
#[tokio::test]
async fn test_chinese_text_handling() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));

    // Agent name, entry name, content, and keywords are all Chinese.
    let entry = MemoryEntry::new(
        "中文测试",
        MemoryType::Knowledge,
        "中文知识",
        "这是一个中文测试,包含关键词:人工智能、机器学习、深度学习。".to_string(),
    )
    .with_keywords(vec!["人工智能".to_string(), "机器学习".to_string()]);

    adapter.store(&entry).await.unwrap();

    // Search with Chinese query ("artificial intelligence")
    let results = adapter
        .find(
            "人工智能",
            FindOptions {
                scope: Some("agent://中文测试".to_string()),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();

    assert!(!results.is_empty());
    assert!(results[0].content.contains("人工智能"));
}
|
||||
|
||||
/// Test find by prefix
|
||||
#[tokio::test]
|
||||
async fn test_find_by_prefix() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
|
||||
// Store multiple entries under same agent
|
||||
for i in 0..5 {
|
||||
let entry = MemoryEntry::new(
|
||||
"agent-1",
|
||||
MemoryType::Knowledge,
|
||||
&format!("topic-{}", i),
|
||||
format!("Content for topic {}", i),
|
||||
);
|
||||
adapter.store(&entry).await.unwrap();
|
||||
}
|
||||
|
||||
// Find all entries for agent-1
|
||||
let results = adapter
|
||||
.find_by_prefix("agent://agent-1")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 5);
|
||||
}
|
||||
Reference in New Issue
Block a user