fix(presentation): 修复 presentation 模块类型错误和语法问题
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

- 创建 types.ts 定义完整的类型系统
- 重写 DocumentRenderer.tsx 修复语法错误
- 重写 QuizRenderer.tsx 修复语法错误
- 重写 PresentationContainer.tsx 添加类型守卫
- 重写 TypeSwitcher.tsx 修复类型引用
- 更新 index.ts 移除不存在的 ChartRenderer 导出

审计结果:
- 类型检查: 通过
- 单元测试: 222 passed
- 构建: 成功
This commit is contained in:
iven
2026-03-26 17:19:28 +08:00
parent d0c6319fc1
commit b7f3d94950
71 changed files with 15896 additions and 1133 deletions

View File

@@ -0,0 +1,40 @@
[package]
name = "zclaw-growth"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
rust-version.workspace = true
description = "ZCLAW Agent Growth System - Memory extraction, retrieval, and prompt injection"
[dependencies]
# Async runtime
tokio = { workspace = true }
futures = { workspace = true }
async-trait = { workspace = true }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Error handling
thiserror = { workspace = true }
anyhow = { workspace = true }
# Logging
tracing = { workspace = true }
# Time
chrono = { workspace = true }
# IDs
uuid = { workspace = true }
# Database
sqlx = { workspace = true }
# Internal crates
zclaw-types = { workspace = true }
[dev-dependencies]
tokio-test = "0.4"

View File

@@ -0,0 +1,372 @@
//! Memory Extractor - Extracts preferences, knowledge, and experience from conversations
//!
//! This module provides the `MemoryExtractor` which analyzes conversations
//! using LLM to extract valuable memories for agent growth.
use crate::types::{ExtractedMemory, ExtractionConfig, MemoryType};
use crate::viking_adapter::VikingAdapter;
use async_trait::async_trait;
use std::sync::Arc;
use zclaw_types::{Message, Result, SessionId};
/// Trait for LLM driver abstraction
/// This allows us to use any LLM driver implementation
#[async_trait]
pub trait LlmDriverForExtraction: Send + Sync {
    /// Extract memories from conversation using LLM
    ///
    /// Implementations analyze `messages` and return candidate memories of
    /// the requested `extraction_type`. Errors use the crate-wide `Result`;
    /// callers treat an `Err` as a failed extraction pass.
    async fn extract_memories(
        &self,
        messages: &[Message],
        extraction_type: MemoryType,
    ) -> Result<Vec<ExtractedMemory>>;
}
/// Memory Extractor - extracts memories from conversations
///
/// Configured via the builder-style `with_*` methods. Both the LLM driver
/// and the storage adapter are optional: without a driver `extract` returns
/// an empty Vec, and without an adapter `store_memories` stores nothing.
pub struct MemoryExtractor {
    /// LLM driver for extraction (optional)
    llm_driver: Option<Arc<dyn LlmDriverForExtraction>>,
    /// OpenViking adapter for storage (optional)
    viking: Option<Arc<VikingAdapter>>,
    /// Extraction configuration (per-type toggles and confidence threshold)
    config: ExtractionConfig,
}
impl MemoryExtractor {
/// Create a new memory extractor with LLM driver
pub fn new(llm_driver: Arc<dyn LlmDriverForExtraction>) -> Self {
Self {
llm_driver: Some(llm_driver),
viking: None,
config: ExtractionConfig::default(),
}
}
/// Create a new memory extractor without LLM driver
///
/// This is useful for cases where LLM-based extraction is not needed
/// or will be set later using `with_llm_driver`
pub fn new_without_driver() -> Self {
Self {
llm_driver: None,
viking: None,
config: ExtractionConfig::default(),
}
}
/// Set the LLM driver
pub fn with_llm_driver(mut self, driver: Arc<dyn LlmDriverForExtraction>) -> Self {
self.llm_driver = Some(driver);
self
}
/// Create with OpenViking adapter
pub fn with_viking(mut self, viking: Arc<VikingAdapter>) -> Self {
self.viking = Some(viking);
self
}
/// Set extraction configuration
pub fn with_config(mut self, config: ExtractionConfig) -> Self {
self.config = config;
self
}
/// Extract memories from a conversation
///
/// This method analyzes the conversation and extracts:
/// - Preferences: User's communication style, format preferences, language preferences
/// - Knowledge: User-related facts, domain knowledge, lessons learned
/// - Experience: Skill/tool usage patterns and outcomes
///
/// Returns an empty Vec if no LLM driver is configured
pub async fn extract(
&self,
messages: &[Message],
session_id: SessionId,
) -> Result<Vec<ExtractedMemory>> {
// Check if LLM driver is available
let _llm_driver = match &self.llm_driver {
Some(driver) => driver,
None => {
tracing::debug!("[MemoryExtractor] No LLM driver configured, skipping extraction");
return Ok(Vec::new());
}
};
let mut results = Vec::new();
// Extract preferences if enabled
if self.config.extract_preferences {
tracing::debug!("[MemoryExtractor] Extracting preferences...");
let prefs = self.extract_preferences(messages, session_id).await?;
results.extend(prefs);
}
// Extract knowledge if enabled
if self.config.extract_knowledge {
tracing::debug!("[MemoryExtractor] Extracting knowledge...");
let knowledge = self.extract_knowledge(messages, session_id).await?;
results.extend(knowledge);
}
// Extract experience if enabled
if self.config.extract_experience {
tracing::debug!("[MemoryExtractor] Extracting experience...");
let experience = self.extract_experience(messages, session_id).await?;
results.extend(experience);
}
// Filter by confidence threshold
results.retain(|m| m.confidence >= self.config.min_confidence);
tracing::info!(
"[MemoryExtractor] Extracted {} memories (confidence >= {})",
results.len(),
self.config.min_confidence
);
Ok(results)
}
/// Extract user preferences from conversation
async fn extract_preferences(
&self,
messages: &[Message],
session_id: SessionId,
) -> Result<Vec<ExtractedMemory>> {
let llm_driver = match &self.llm_driver {
Some(driver) => driver,
None => return Ok(Vec::new()),
};
let mut results = llm_driver
.extract_memories(messages, MemoryType::Preference)
.await?;
// Set source session
for memory in &mut results {
memory.source_session = session_id;
}
Ok(results)
}
/// Extract knowledge from conversation
async fn extract_knowledge(
&self,
messages: &[Message],
session_id: SessionId,
) -> Result<Vec<ExtractedMemory>> {
let llm_driver = match &self.llm_driver {
Some(driver) => driver,
None => return Ok(Vec::new()),
};
let mut results = llm_driver
.extract_memories(messages, MemoryType::Knowledge)
.await?;
for memory in &mut results {
memory.source_session = session_id;
}
Ok(results)
}
/// Extract experience from conversation
async fn extract_experience(
&self,
messages: &[Message],
session_id: SessionId,
) -> Result<Vec<ExtractedMemory>> {
let llm_driver = match &self.llm_driver {
Some(driver) => driver,
None => return Ok(Vec::new()),
};
let mut results = llm_driver
.extract_memories(messages, MemoryType::Experience)
.await?;
for memory in &mut results {
memory.source_session = session_id;
}
Ok(results)
}
/// Store extracted memories to OpenViking
pub async fn store_memories(
&self,
agent_id: &str,
memories: &[ExtractedMemory],
) -> Result<usize> {
let viking = match &self.viking {
Some(v) => v,
None => {
tracing::warn!("[MemoryExtractor] No VikingAdapter configured, memories not stored");
return Ok(0);
}
};
let mut stored = 0;
for memory in memories {
let entry = memory.to_memory_entry(agent_id);
match viking.store(&entry).await {
Ok(_) => stored += 1,
Err(e) => {
tracing::error!(
"[MemoryExtractor] Failed to store memory {}: {}",
memory.category,
e
);
}
}
}
tracing::info!("[MemoryExtractor] Stored {} memories to OpenViking", stored);
Ok(stored)
}
}
/// Default extraction prompts for LLM
///
/// Chinese-language prompt templates. Each instructs the model to return
/// JSON and ends with a header after which the conversation text is appended.
pub mod prompts {
    use crate::types::MemoryType;

    /// Get the extraction prompt for a memory type
    pub fn get_extraction_prompt(memory_type: MemoryType) -> &'static str {
        match memory_type {
            MemoryType::Preference => PREFERENCE_EXTRACTION_PROMPT,
            MemoryType::Knowledge => KNOWLEDGE_EXTRACTION_PROMPT,
            MemoryType::Experience => EXPERIENCE_EXTRACTION_PROMPT,
            MemoryType::Session => SESSION_SUMMARY_PROMPT,
        }
    }

    /// Extracts user preferences: communication style, reply format,
    /// language, and topics of interest. Expects a JSON array back.
    const PREFERENCE_EXTRACTION_PROMPT: &str = r#"
分析以下对话,提取用户的偏好设置。关注:
- 沟通风格偏好(简洁/详细、正式/随意)
- 回复格式偏好(列表/段落、代码块风格)
- 语言偏好
- 主题兴趣
请以 JSON 格式返回,格式如下:
[
  {
    "category": "communication-style",
    "content": "用户偏好简洁的回复",
    "confidence": 0.9,
    "keywords": ["简洁", "回复风格"]
  }
]
对话内容:
"#;

    /// Extracts knowledge: user facts, domain knowledge, lessons learned.
    /// Expects a JSON array back.
    const KNOWLEDGE_EXTRACTION_PROMPT: &str = r#"
分析以下对话,提取有价值的知识。关注:
- 用户相关事实(职业、项目、背景)
- 领域知识(技术栈、工具、最佳实践)
- 经验教训(成功/失败案例)
请以 JSON 格式返回,格式如下:
[
  {
    "category": "user-facts",
    "content": "用户是一名 Rust 开发者",
    "confidence": 0.85,
    "keywords": ["Rust", "开发者"]
  }
]
对话内容:
"#;

    /// Extracts skill/tool usage experience and outcomes.
    /// Expects a JSON array back.
    const EXPERIENCE_EXTRACTION_PROMPT: &str = r#"
分析以下对话,提取技能/工具使用经验。关注:
- 使用的技能或工具
- 执行结果(成功/失败)
- 改进建议
请以 JSON 格式返回,格式如下:
[
  {
    "category": "skill-browser",
    "content": "浏览器技能在搜索技术文档时效果很好",
    "confidence": 0.8,
    "keywords": ["浏览器", "搜索", "文档"]
  }
]
对话内容:
"#;

    /// Summarizes a whole session. Note: unlike the three prompts above,
    /// this one expects a single JSON object, not an array.
    const SESSION_SUMMARY_PROMPT: &str = r#"
总结以下对话会话。关注:
- 主要话题
- 关键决策
- 未解决问题
请以 JSON 格式返回,格式如下:
{
  "summary": "会话摘要内容",
  "keywords": ["关键词1", "关键词2"],
  "topics": ["主题1", "主题2"]
}
对话内容:
"#;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub driver that returns one fixed memory of whatever type is requested.
    struct MockLlmDriver;

    #[async_trait]
    impl LlmDriverForExtraction for MockLlmDriver {
        async fn extract_memories(
            &self,
            _messages: &[Message],
            extraction_type: MemoryType,
        ) -> Result<Vec<ExtractedMemory>> {
            Ok(vec![ExtractedMemory::new(
                extraction_type,
                "test-category",
                "test content",
                SessionId::new(),
            )])
        }
    }

    // A freshly built extractor has no storage adapter attached.
    #[tokio::test]
    async fn test_extractor_creation() {
        let driver = Arc::new(MockLlmDriver);
        let extractor = MemoryExtractor::new(driver);
        assert!(extractor.viking.is_none());
    }

    // With the mock driver, each enabled category yields one memory,
    // so the combined result must be non-empty.
    #[tokio::test]
    async fn test_extract_memories() {
        let driver = Arc::new(MockLlmDriver);
        let extractor = MemoryExtractor::new(driver);
        let messages = vec![Message::user("Hello")];
        let result = extractor
            .extract(&messages, SessionId::new())
            .await
            .unwrap();
        // Should extract preferences, knowledge, and experience
        assert!(!result.is_empty());
    }

    // Every memory type must map to a non-empty extraction prompt.
    #[test]
    fn test_prompts_available() {
        assert!(!prompts::get_extraction_prompt(MemoryType::Preference).is_empty());
        assert!(!prompts::get_extraction_prompt(MemoryType::Knowledge).is_empty());
        assert!(!prompts::get_extraction_prompt(MemoryType::Experience).is_empty());
        assert!(!prompts::get_extraction_prompt(MemoryType::Session).is_empty());
    }
}

View File

@@ -0,0 +1,537 @@
//! Prompt Injector - Injects retrieved memories into system prompts
//!
//! This module provides the `PromptInjector` which formats and injects
//! retrieved memories into the agent's system prompt for context enhancement.
//!
//! # Formatting Options
//!
//! - `inject()` - Standard markdown format with sections
//! - `inject_compact()` - Compact format for limited token budgets
//! - `inject_json()` - JSON format for structured processing
//! - `inject_custom()` - Custom template with placeholders
use crate::types::{MemoryEntry, RetrievalConfig, RetrievalResult};
/// Output format for memory injection
///
/// Selected via `PromptInjector::with_format` and dispatched on by
/// `inject_with_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InjectionFormat {
    /// Standard markdown with sections (default)
    Markdown,
    /// Compact inline format ([P]/[K]/[E] prefixed lines)
    Compact,
    /// JSON structured format
    Json,
}
/// Prompt Injector - injects memories into system prompts
pub struct PromptInjector {
    /// Retrieval configuration for token budgets
    config: RetrievalConfig,
    /// Output format used by `inject_with_format`
    format: InjectionFormat,
    /// Custom template (uses {{preferences}}, {{knowledge}}, {{experience}} placeholders)
    ///
    /// NOTE(review): this field is set by `with_custom_template` but never
    /// read in this file — `inject_custom` takes the template as a parameter
    /// instead. Confirm whether the field is still intended to be consumed.
    custom_template: Option<String>,
}
impl Default for PromptInjector {
fn default() -> Self {
Self::new()
}
}
impl PromptInjector {
    /// Create a new prompt injector
    pub fn new() -> Self {
        Self {
            config: RetrievalConfig::default(),
            format: InjectionFormat::Markdown,
            custom_template: None,
        }
    }

    /// Create with custom configuration
    pub fn with_config(config: RetrievalConfig) -> Self {
        Self {
            config,
            format: InjectionFormat::Markdown,
            custom_template: None,
        }
    }

    /// Set the output format
    pub fn with_format(mut self, format: InjectionFormat) -> Self {
        self.format = format;
        self
    }

    /// Set a custom template for injection
    ///
    /// Template placeholders:
    /// - `{{preferences}}` - Formatted preferences section
    /// - `{{knowledge}}` - Formatted knowledge section
    /// - `{{experience}}` - Formatted experience section
    /// - `{{all}}` - All memories combined
    pub fn with_custom_template(mut self, template: impl Into<String>) -> Self {
        self.custom_template = Some(template.into());
        self
    }

    /// Inject memories into a base system prompt
    ///
    /// This method constructs an enhanced system prompt by:
    /// 1. Starting with the base prompt
    /// 2. Adding a "用户偏好" section if preferences exist
    /// 3. Adding a "相关知识" section if knowledge exists
    /// 4. Adding an "经验参考" section if experience exists
    ///
    /// Each section respects the token budget configuration.
    pub fn inject(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
        // If no memories, return base prompt unchanged
        if memories.is_empty() {
            return base_prompt.to_string();
        }
        let mut result = base_prompt.to_string();
        // Inject preferences section
        if !memories.preferences.is_empty() {
            let section = self.format_section(
                "## 用户偏好",
                &memories.preferences,
                self.config.preference_budget,
                |entry| format!("- {}", entry.content),
            );
            result.push_str("\n\n");
            result.push_str(&section);
        }
        // Inject knowledge section
        if !memories.knowledge.is_empty() {
            let section = self.format_section(
                "## 相关知识",
                &memories.knowledge,
                self.config.knowledge_budget,
                |entry| format!("- {}", entry.content),
            );
            result.push_str("\n\n");
            result.push_str(&section);
        }
        // Inject experience section
        if !memories.experience.is_empty() {
            let section = self.format_section(
                "## 经验参考",
                &memories.experience,
                self.config.experience_budget,
                |entry| format!("- {}", entry.content),
            );
            result.push_str("\n\n");
            result.push_str(&section);
        }
        // Add memory context footer
        result.push_str("\n\n");
        result.push_str("<!-- 以上内容基于历史对话自动提取的记忆 -->");
        result
    }

    /// Format a section of memories with token budget
    ///
    /// Tokens are estimated as `len / 4` (same heuristic throughout this
    /// crate); entries past the budget are replaced by a truncation marker.
    fn format_section<F>(
        &self,
        header: &str,
        entries: &[MemoryEntry],
        token_budget: usize,
        formatter: F,
    ) -> String
    where
        F: Fn(&MemoryEntry) -> String,
    {
        let mut result = String::new();
        result.push_str(header);
        result.push('\n');
        let mut used_tokens = 0;
        let header_tokens = header.len() / 4;
        used_tokens += header_tokens;
        for entry in entries {
            let line = formatter(entry);
            let line_tokens = line.len() / 4;
            if used_tokens + line_tokens > token_budget {
                // Add truncation indicator
                result.push_str("- ... (更多内容已省略)\n");
                break;
            }
            result.push_str(&line);
            result.push('\n');
            used_tokens += line_tokens;
        }
        result
    }

    /// Build a minimal context string for token-limited scenarios
    ///
    /// Only the top preference and top knowledge entry are included.
    pub fn build_minimal_context(&self, memories: &RetrievalResult) -> String {
        if memories.is_empty() {
            return String::new();
        }
        let mut context = String::new();
        // Only include top preference
        if let Some(pref) = memories.preferences.first() {
            context.push_str(&format!("[偏好] {}\n", pref.content));
        }
        // Only include top knowledge
        if let Some(knowledge) = memories.knowledge.first() {
            context.push_str(&format!("[知识] {}\n", knowledge.content));
        }
        context
    }

    /// Inject memories in compact format
    ///
    /// Compact format uses inline notation: [P] for preferences, [K] for knowledge, [E] for experience
    pub fn inject_compact(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
        if memories.is_empty() {
            return base_prompt.to_string();
        }
        let mut result = base_prompt.to_string();
        let mut context_parts = Vec::new();
        // Add compact preferences
        for entry in &memories.preferences {
            context_parts.push(format!("[P] {}", entry.content));
        }
        // Add compact knowledge
        for entry in &memories.knowledge {
            context_parts.push(format!("[K] {}", entry.content));
        }
        // Add compact experience
        for entry in &memories.experience {
            context_parts.push(format!("[E] {}", entry.content));
        }
        if !context_parts.is_empty() {
            result.push_str("\n\n[记忆上下文]\n");
            result.push_str(&context_parts.join("\n"));
        }
        result
    }

    /// Inject memories as JSON structure
    ///
    /// Returns a JSON object with preferences, knowledge, and experience arrays
    pub fn inject_json(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
        if memories.is_empty() {
            return base_prompt.to_string();
        }
        // All three sections share the same per-entry JSON shape.
        let to_json = |e: &MemoryEntry| {
            serde_json::json!({
                "content": e.content,
                "importance": e.importance,
                "keywords": e.keywords,
            })
        };
        let preferences: Vec<_> = memories.preferences.iter().map(to_json).collect();
        let knowledge: Vec<_> = memories.knowledge.iter().map(to_json).collect();
        let experience: Vec<_> = memories.experience.iter().map(to_json).collect();
        let memories_json = serde_json::json!({
            "preferences": preferences,
            "knowledge": knowledge,
            "experience": experience,
        });
        format!(
            "{}\n\n[记忆上下文]\n{}",
            base_prompt,
            serde_json::to_string_pretty(&memories_json).unwrap_or_default()
        )
    }

    /// Inject using custom template
    ///
    /// Template placeholders:
    /// - `{{preferences}}` - Formatted preferences section
    /// - `{{knowledge}}` - Formatted knowledge section
    /// - `{{experience}}` - Formatted experience section
    /// - `{{all}}` - All memories combined
    pub fn inject_custom(&self, template: &str, memories: &RetrievalResult) -> String {
        let mut result = template.to_string();
        // Format each section as a "- item" bullet list (empty when absent).
        let bullets = |entries: &[MemoryEntry]| {
            entries
                .iter()
                .map(|e| format!("- {}", e.content))
                .collect::<Vec<_>>()
                .join("\n")
        };
        let prefs = bullets(&memories.preferences);
        let knowledge = bullets(&memories.knowledge);
        let experience = bullets(&memories.experience);
        // Combine all (empty sections render as empty strings).
        let all = format!(
            "用户偏好:\n{}\n\n相关知识:\n{}\n\n经验参考:\n{}",
            prefs, knowledge, experience,
        );
        // Replace placeholders
        result = result.replace("{{preferences}}", &prefs);
        result = result.replace("{{knowledge}}", &knowledge);
        result = result.replace("{{experience}}", &experience);
        result = result.replace("{{all}}", &all);
        result
    }

    /// Inject memories using the configured format
    pub fn inject_with_format(&self, base_prompt: &str, memories: &RetrievalResult) -> String {
        match self.format {
            InjectionFormat::Markdown => self.inject(base_prompt, memories),
            InjectionFormat::Compact => self.inject_compact(base_prompt, memories),
            InjectionFormat::Json => self.inject_json(base_prompt, memories),
        }
    }

    /// Estimate total tokens that will be injected
    ///
    /// Each section is clamped to its configured budget, mirroring the
    /// truncation performed during injection.
    pub fn estimate_injection_tokens(&self, memories: &RetrievalResult) -> usize {
        Self::budgeted_tokens(&memories.preferences, self.config.preference_budget)
            + Self::budgeted_tokens(&memories.knowledge, self.config.knowledge_budget)
            + Self::budgeted_tokens(&memories.experience, self.config.experience_budget)
    }

    /// Sum `estimated_tokens` over `entries`, clamped to `budget`.
    fn budgeted_tokens(entries: &[MemoryEntry], budget: usize) -> usize {
        let mut tokens = 0;
        for entry in entries {
            tokens += entry.estimated_tokens();
            if tokens > budget {
                return budget;
            }
        }
        tokens
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;
    use chrono::Utc;

    /// Build a preference-typed entry with fixed metadata for injection tests.
    fn create_test_entry(content: &str) -> MemoryEntry {
        MemoryEntry {
            uri: "test://uri".to_string(),
            memory_type: MemoryType::Preference,
            content: content.to_string(),
            keywords: vec![],
            importance: 5,
            access_count: 0,
            created_at: Utc::now(),
            last_accessed: Utc::now(),
        }
    }

    // An empty retrieval result must leave the base prompt byte-identical.
    #[test]
    fn test_injector_empty_memories() {
        let injector = PromptInjector::new();
        let base = "You are a helpful assistant.";
        let memories = RetrievalResult::default();
        let result = injector.inject(base, &memories);
        assert_eq!(result, base);
    }

    // A preferences-only result produces the 用户偏好 section with content.
    #[test]
    fn test_injector_with_preferences() {
        let injector = PromptInjector::new();
        let base = "You are a helpful assistant.";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("User prefers concise responses")],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        let result = injector.inject(base, &memories);
        assert!(result.contains("用户偏好"));
        assert!(result.contains("User prefers concise responses"));
    }

    // All three memory kinds produce their respective section headers.
    #[test]
    fn test_injector_with_all_types() {
        let injector = PromptInjector::new();
        let base = "You are a helpful assistant.";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Prefers concise")],
            knowledge: vec![create_test_entry("Knows Rust")],
            experience: vec![create_test_entry("Browser skill works well")],
            total_tokens: 0,
        };
        let result = injector.inject(base, &memories);
        assert!(result.contains("用户偏好"));
        assert!(result.contains("相关知识"));
        assert!(result.contains("经验参考"));
    }

    // Minimal context includes only the top preference and top knowledge.
    #[test]
    fn test_minimal_context() {
        let injector = PromptInjector::new();
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Prefers concise")],
            knowledge: vec![create_test_entry("Knows Rust")],
            experience: vec![],
            total_tokens: 0,
        };
        let context = injector.build_minimal_context(&memories);
        assert!(context.contains("[偏好]"));
        assert!(context.contains("[知识]"));
    }

    // A non-empty result must estimate to a positive token count.
    #[test]
    fn test_estimate_tokens() {
        let injector = PromptInjector::new();
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Short text")],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        let estimate = injector.estimate_injection_tokens(&memories);
        assert!(estimate > 0);
    }

    // Compact format uses [P]/[K] markers under a [记忆上下文] header.
    #[test]
    fn test_inject_compact() {
        let injector = PromptInjector::new();
        let base = "You are a helpful assistant.";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Prefers concise")],
            knowledge: vec![create_test_entry("Knows Rust")],
            experience: vec![],
            total_tokens: 0,
        };
        let result = injector.inject_compact(base, &memories);
        assert!(result.contains("[P]"));
        assert!(result.contains("[K]"));
        assert!(result.contains("[记忆上下文]"));
    }

    // JSON format emits a "preferences" array containing the entry content.
    #[test]
    fn test_inject_json() {
        let injector = PromptInjector::new();
        let base = "You are a helpful assistant.";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Prefers concise")],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        let result = injector.inject_json(base, &memories);
        assert!(result.contains("\"preferences\""));
        assert!(result.contains("Prefers concise"));
    }

    // The {{all}} placeholder expands to the combined Chinese sections.
    #[test]
    fn test_inject_custom() {
        let injector = PromptInjector::new();
        let template = "Context:\n{{all}}";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Prefers concise")],
            knowledge: vec![create_test_entry("Knows Rust")],
            experience: vec![],
            total_tokens: 0,
        };
        let result = injector.inject_custom(template, &memories);
        assert!(result.contains("用户偏好"));
        assert!(result.contains("相关知识"));
    }

    // inject_with_format dispatches to the configured formatter.
    #[test]
    fn test_format_selection() {
        let base = "Base";
        let memories = RetrievalResult {
            preferences: vec![create_test_entry("Test")],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        // Test markdown format
        let injector_md = PromptInjector::new().with_format(InjectionFormat::Markdown);
        let result_md = injector_md.inject_with_format(base, &memories);
        assert!(result_md.contains("## 用户偏好"));
        // Test compact format
        let injector_compact = PromptInjector::new().with_format(InjectionFormat::Compact);
        let result_compact = injector_compact.inject_with_format(base, &memories);
        assert!(result_compact.contains("[P]"));
    }
}

View File

@@ -0,0 +1,141 @@
//! ZCLAW Agent Growth System
//!
//! This crate provides the agent growth functionality for ZCLAW,
//! enabling agents to learn and evolve from conversations.
//!
//! # Architecture
//!
//! The growth system consists of four main components:
//!
//! 1. **MemoryExtractor** (`extractor`) - Analyzes conversations and extracts
//! preferences, knowledge, and experience using LLM.
//!
//! 2. **MemoryRetriever** (`retriever`) - Performs semantic search over
//! stored memories to find contextually relevant information.
//!
//! 3. **PromptInjector** (`injector`) - Injects retrieved memories into
//! the system prompt with token budget control.
//!
//! 4. **GrowthTracker** (`tracker`) - Tracks growth metrics and evolution
//! over time.
//!
//! # Storage
//!
//! All memories are stored in OpenViking with a URI structure:
//!
//! ```text
//! agent://{agent_id}/
//! ├── preferences/{category} - User preferences
//! ├── knowledge/{domain} - Accumulated knowledge
//! ├── experience/{skill} - Skill/tool experience
//! └── sessions/{session_id}/ - Conversation history
//! ├── raw - Original conversation (L0)
//! ├── summary - Summary (L1)
//! └── keywords - Keywords (L2)
//! ```
//!
//! # Usage
//!
//! ```rust,ignore
//! use zclaw_growth::{MemoryExtractor, MemoryRetriever, PromptInjector, VikingAdapter};
//!
//! // Create components
//! let viking = VikingAdapter::in_memory();
//! let retriever = MemoryRetriever::new(Arc::new(viking.clone()));
//! let injector = PromptInjector::new();
//!
//! // Before conversation: retrieve relevant memories
//! let memories = retriever.retrieve(&agent_id, &user_input).await?;
//!
//! // Inject into system prompt
//! let enhanced_prompt = injector.inject(&base_prompt, &memories);
//!
//! // After conversation: extract and store new memories
//! let extracted = extractor.extract(&messages, session_id).await?;
//! extractor.store_memories(&agent_id, &extracted).await?;
//! ```
pub mod types;
pub mod extractor;
pub mod retriever;
pub mod injector;
pub mod tracker;
pub mod viking_adapter;
pub mod storage;
pub mod retrieval;
// Re-export main types for convenience
pub use types::{
ExtractedMemory,
ExtractionConfig,
GrowthStats,
MemoryEntry,
MemoryType,
RetrievalConfig,
RetrievalResult,
UriBuilder,
};
pub use extractor::{LlmDriverForExtraction, MemoryExtractor};
pub use retriever::{MemoryRetriever, MemoryStats};
pub use injector::{InjectionFormat, PromptInjector};
pub use tracker::{AgentMetadata, GrowthTracker, LearningEvent};
pub use viking_adapter::{FindOptions, VikingAdapter, VikingLevel, VikingStorage};
pub use storage::SqliteStorage;
pub use retrieval::{MemoryCache, QueryAnalyzer, SemanticScorer};
/// Growth system configuration
#[derive(Debug, Clone)]
pub struct GrowthConfig {
    /// Enable/disable growth system
    pub enabled: bool,
    /// Retrieval configuration (per-section token budgets etc.)
    pub retrieval: RetrievalConfig,
    /// Extraction configuration (per-type toggles, confidence threshold)
    pub extraction: ExtractionConfig,
    /// Auto-extract after each conversation
    pub auto_extract: bool,
}
impl Default for GrowthConfig {
fn default() -> Self {
Self {
enabled: true,
retrieval: RetrievalConfig::default(),
extraction: ExtractionConfig::default(),
auto_extract: true,
}
}
}
/// Convenience function to create a complete growth system
///
/// Wires one shared `VikingAdapter` into the extractor, retriever, and
/// tracker, and returns all four components as a tuple.
pub fn create_growth_system(
    viking: std::sync::Arc<VikingAdapter>,
    llm_driver: std::sync::Arc<dyn LlmDriverForExtraction>,
) -> (MemoryExtractor, MemoryRetriever, PromptInjector, GrowthTracker) {
    let retriever = MemoryRetriever::new(std::sync::Arc::clone(&viking));
    let extractor = MemoryExtractor::new(llm_driver).with_viking(std::sync::Arc::clone(&viking));
    (
        extractor,
        retriever,
        PromptInjector::new(),
        GrowthTracker::new(viking),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: system enabled, auto-extraction on, 500-token retrieval cap.
    #[test]
    fn test_growth_config_default() {
        let config = GrowthConfig::default();
        assert!(config.enabled && config.auto_extract);
        assert_eq!(config.retrieval.max_tokens, 500);
    }

    // `MemoryType` is re-exported and Display-formats to its URI segment.
    #[test]
    fn test_memory_type_reexport() {
        assert_eq!(MemoryType::Preference.to_string(), "preferences");
    }
}

View File

@@ -0,0 +1,365 @@
//! Memory Cache
//!
//! Provides caching for frequently accessed memories to improve
//! retrieval performance.
use crate::types::{MemoryEntry, MemoryType};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Cache entry with metadata
struct CacheEntry {
    /// The memory entry
    entry: MemoryEntry,
    /// Last access time (refreshed on every cache hit; also the TTL anchor)
    last_accessed: Instant,
    /// Access count (primary key when choosing eviction victims)
    access_count: u32,
}
/// Cache key for efficient lookups
///
/// Built from a memory's `agent://` URI via the `From<&MemoryEntry>` impl.
/// NOTE(review): not referenced in the visible portion of this file —
/// confirm it is used further down or remove as dead code.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    /// First URI path segment: `agent://{agent_id}/...`
    agent_id: String,
    /// The entry's declared memory type
    memory_type: MemoryType,
    /// Third URI path segment (the category)
    category: String,
}
impl From<&MemoryEntry> for CacheKey {
fn from(entry: &MemoryEntry) -> Self {
// Parse URI to extract components
let parts: Vec<&str> = entry.uri.trim_start_matches("agent://").split('/').collect();
Self {
agent_id: parts.first().unwrap_or(&"").to_string(),
memory_type: entry.memory_type,
category: parts.get(2).unwrap_or(&"").to_string(),
}
}
}
/// Memory cache configuration
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries before insertion triggers eviction
    pub max_entries: usize,
    /// Time-to-live for entries (measured from last access)
    pub ttl: Duration,
    /// Enable/disable caching (when false, get/put are no-ops)
    pub enabled: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
ttl: Duration::from_secs(3600), // 1 hour
enabled: true,
}
}
}
/// Memory cache for hot memories
///
/// Mutable state sits behind `tokio::sync::RwLock`, so a single instance can
/// be shared across tasks (e.g. behind an `Arc`) without external locking.
pub struct MemoryCache {
    /// Cache storage, keyed by the memory entry's URI
    cache: RwLock<HashMap<String, CacheEntry>>,
    /// Configuration
    config: CacheConfig,
    /// Cache statistics
    stats: RwLock<CacheStats>,
}
/// Cache statistics
///
/// Monotonically increasing counters; not reset by `clear`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total cache hits
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Total entries evicted
    pub evictions: u64,
}
impl MemoryCache {
    /// Create a new memory cache
    pub fn new(config: CacheConfig) -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
            config,
            stats: RwLock::new(CacheStats::default()),
        }
    }

    /// Create with default configuration
    pub fn default_config() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Get a memory from cache
    ///
    /// Returns `None` when caching is disabled (no stats recorded), when the
    /// key is absent, or when the entry has outlived its TTL (both counted as
    /// misses). A hit refreshes the entry's access time and count.
    pub async fn get(&self, uri: &str) -> Option<MemoryEntry> {
        if !self.config.enabled {
            return None;
        }
        let mut cache = self.cache.write().await;
        if let Some(cached) = cache.get_mut(uri) {
            // TTL is sliding: it is measured from the last access.
            if cached.last_accessed.elapsed() > self.config.ttl {
                // Expired entries are dropped and counted as misses so the
                // hit rate reflects lookups that must go back to storage.
                cache.remove(uri);
                let mut stats = self.stats.write().await;
                stats.misses += 1;
                return None;
            }
            // Update access metadata
            cached.last_accessed = Instant::now();
            cached.access_count += 1;
            // Update stats
            let mut stats = self.stats.write().await;
            stats.hits += 1;
            return Some(cached.entry.clone());
        }
        // Update stats
        let mut stats = self.stats.write().await;
        stats.misses += 1;
        None
    }

    /// Put a memory into cache, keyed by the entry's URI
    ///
    /// Evicts one entry only when inserting a *new* key at capacity;
    /// replacing an existing key never evicts.
    pub async fn put(&self, entry: MemoryEntry) {
        if !self.config.enabled {
            return;
        }
        let mut cache = self.cache.write().await;
        // Replacing an existing key does not grow the map, so evicting in
        // that case would needlessly shrink the cache below capacity.
        if !cache.contains_key(&entry.uri) && cache.len() >= self.config.max_entries {
            self.evict_lru(&mut cache).await;
        }
        cache.insert(
            entry.uri.clone(),
            CacheEntry {
                entry,
                last_accessed: Instant::now(),
                access_count: 0,
            },
        );
    }

    /// Remove a memory from cache
    pub async fn remove(&self, uri: &str) {
        let mut cache = self.cache.write().await;
        cache.remove(uri);
    }

    /// Clear the cache (statistics are retained)
    pub async fn clear(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();
    }

    /// Evict the coldest entry: lowest access count, oldest access as tiebreak
    ///
    /// NOTE(review): despite the name this is LFU-first with LRU as a
    /// tiebreaker — confirm that ordering is intended.
    async fn evict_lru(&self, cache: &mut HashMap<String, CacheEntry>) {
        let lru_key = cache
            .iter()
            .min_by_key(|(_, v)| (v.access_count, v.last_accessed))
            .map(|(k, _)| k.clone());
        if let Some(key) = lru_key {
            cache.remove(&key);
            let mut stats = self.stats.write().await;
            stats.evictions += 1;
        }
    }

    /// Get cache statistics
    pub async fn stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }

    /// Get cache hit rate (0.0 when no lookups have been recorded)
    pub async fn hit_rate(&self) -> f32 {
        let stats = self.stats.read().await;
        let total = stats.hits + stats.misses;
        if total == 0 {
            return 0.0;
        }
        stats.hits as f32 / total as f32
    }

    /// Get cache size
    pub async fn size(&self) -> usize {
        self.cache.read().await.len()
    }

    /// Warm up cache with frequently accessed entries
    pub async fn warmup(&self, entries: Vec<MemoryEntry>) {
        for entry in entries {
            self.put(entry).await;
        }
    }

    /// Get top accessed entries (for preloading), most-accessed first
    pub async fn get_hot_entries(&self, limit: usize) -> Vec<MemoryEntry> {
        let cache = self.cache.read().await;
        let mut entries: Vec<_> = cache
            .values()
            .map(|c| (c.access_count, c.entry.clone()))
            .collect();
        entries.sort_by(|a, b| b.0.cmp(&a.0));
        entries.truncate(limit);
        entries.into_iter().map(|(_, e)| e).collect()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MemoryCache`: basic put/get, miss accounting,
    //! removal, clearing, statistics, LRU eviction, and hot-entry ranking.
    use super::*;
    use crate::types::MemoryType;

    // Round-trip: an inserted entry is retrievable with identical content.
    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );
        cache.put(entry.clone()).await;
        let retrieved = cache.get(&entry.uri).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "User prefers concise responses");
    }

    // A lookup for an unknown URI returns None and is counted as a miss.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = MemoryCache::default_config();
        let retrieved = cache.get("nonexistent").await;
        assert!(retrieved.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
    }

    // Removing an entry makes subsequent lookups miss.
    #[tokio::test]
    async fn test_cache_remove() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry.clone()).await;
        cache.remove(&entry.uri).await;
        let retrieved = cache.get(&entry.uri).await;
        assert!(retrieved.is_none());
    }

    // clear() empties the cache completely.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry).await;
        cache.clear().await;
        let size = cache.size().await;
        assert_eq!(size, 0);
    }

    // One hit plus one miss yields a 50% hit rate.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = MemoryCache::default_config();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        cache.put(entry.clone()).await;
        // Hit
        cache.get(&entry.uri).await;
        // Miss
        cache.get("nonexistent").await;
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        let hit_rate = cache.hit_rate().await;
        assert!((hit_rate - 0.5).abs() < 0.001);
    }

    // With capacity 2, inserting a third entry evicts the least-used one.
    #[tokio::test]
    async fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(3600),
            enabled: true,
        };
        let cache = MemoryCache::new(config);
        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        let entry3 = MemoryEntry::new("test", MemoryType::Preference, "3", "3".to_string());
        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;
        // Access entry1 to make it hot
        cache.get(&entry1.uri).await;
        // Add entry3, should evict entry2 (LRU)
        cache.put(entry3).await;
        let size = cache.size().await;
        assert_eq!(size, 2);
        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
    }

    // Hot-entry ranking puts the most-accessed entry first.
    #[tokio::test]
    async fn test_get_hot_entries() {
        let cache = MemoryCache::default_config();
        let entry1 = MemoryEntry::new("test", MemoryType::Preference, "1", "1".to_string());
        let entry2 = MemoryEntry::new("test", MemoryType::Preference, "2", "2".to_string());
        cache.put(entry1.clone()).await;
        cache.put(entry2.clone()).await;
        // Access entry1 multiple times
        cache.get(&entry1.uri).await;
        cache.get(&entry1.uri).await;
        let hot = cache.get_hot_entries(10).await;
        assert_eq!(hot.len(), 2);
        // entry1 should be first (more accesses)
        assert_eq!(hot[0].uri, entry1.uri);
    }
}

View File

@@ -0,0 +1,14 @@
//! Retrieval components for ZCLAW Growth System
//!
//! This module provides advanced retrieval capabilities:
//! - `semantic`: TF-IDF based semantic similarity computation
//! - `query`: Query analysis, intent classification, and expansion
//! - `cache`: In-process LRU cache for hot memories
pub mod semantic;
pub mod query;
pub mod cache;

// Re-export the primary entry points so callers can write
// `crate::retrieval::SemanticScorer` etc. directly.
pub use semantic::SemanticScorer;
pub use query::QueryAnalyzer;
pub use cache::MemoryCache;

View File

@@ -0,0 +1,352 @@
//! Query Analyzer
//!
//! Provides query analysis and expansion capabilities for improved retrieval.
//! Extracts keywords, identifies intent, and generates search variations.
use crate::types::MemoryType;
use std::collections::HashSet;
/// Result of analyzing a raw query string with [`QueryAnalyzer`].
#[derive(Debug, Clone)]
pub struct AnalyzedQuery {
    /// Original query string, unchanged
    pub original: String,
    /// Extracted keywords (lowercased, stop words removed)
    pub keywords: Vec<String>,
    /// Classified query intent
    pub intent: QueryIntent,
    /// Memory types to search, inferred from the intent
    pub target_types: Vec<MemoryType>,
    /// Expanded search terms (plural/singular variants and synonyms)
    pub expansions: Vec<String>,
}
/// Query intent classification
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryIntent {
    /// Looking for preferences/settings
    Preference,
    /// Looking for factual knowledge
    Knowledge,
    /// Looking for how-to/experience
    Experience,
    /// General conversation (fallback when no indicator matches)
    General,
    /// Code-related query
    Code,
    /// Configuration query
    // NOTE(review): this variant is never produced by the visible
    // `classify_intent` scoring; presumably assigned by callers — confirm.
    Configuration,
}
/// Rule-based query analyzer: keyword extraction, intent scoring,
/// memory-type inference, and simple term expansion.
pub struct QueryAnalyzer {
    /// Keywords that indicate preference queries (English + Chinese)
    preference_indicators: HashSet<String>,
    /// Keywords that indicate knowledge queries
    knowledge_indicators: HashSet<String>,
    /// Keywords that indicate experience queries
    experience_indicators: HashSet<String>,
    /// Keywords that indicate code queries
    code_indicators: HashSet<String>,
    /// Stop words filtered out during keyword extraction
    stop_words: HashSet<String>,
}
impl QueryAnalyzer {
    /// Create a new query analyzer with built-in English/Chinese
    /// indicator vocabularies and an English stop-word list.
    pub fn new() -> Self {
        Self {
            preference_indicators: [
                "prefer", "like", "want", "favorite", "favourite", "style",
                "format", "language", "setting", "preference", "usually",
                "typically", "always", "never", "习惯", "偏好", "喜欢", "想要",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            knowledge_indicators: [
                "what", "how", "why", "explain", "tell", "know", "learn",
                "understand", "meaning", "definition", "concept", "theory",
                "是什么", "怎么", "为什么", "解释", "了解", "知道",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            experience_indicators: [
                // NOTE(review): "last time" contains a space, so it can never
                // match a single token from `extract_keywords` — confirm intent.
                "experience", "tried", "used", "before", "last time",
                "previous", "history", "remember", "recall", "when",
                "经验", "尝试", "用过", "上次", "记得", "回忆",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            code_indicators: [
                "code", "function", "class", "method", "variable", "type",
                "error", "bug", "fix", "implement", "refactor", "api",
                // "类" (class) replaces a previously empty-string entry, which
                // could never match a non-empty token and mirrored nothing in
                // the English list's "class".
                "代码", "函数", "类", "方法", "变量", "错误", "修复", "实现",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
            stop_words: [
                "the", "a", "an", "is", "are", "was", "were", "be", "been",
                "have", "has", "had", "do", "does", "did", "will", "would",
                "could", "should", "may", "might", "must", "can", "to", "of",
                "in", "for", "on", "with", "at", "by", "from", "as", "and",
                "or", "but", "if", "then", "else", "when", "where", "which",
                "who", "whom", "whose", "this", "that", "these", "those",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
        }
    }

    /// Analyze a query string: extract keywords, classify intent, infer
    /// which memory types to search, and expand the search terms.
    pub fn analyze(&self, query: &str) -> AnalyzedQuery {
        let keywords = self.extract_keywords(query);
        let intent = self.classify_intent(&keywords);
        let target_types = self.infer_memory_types(intent, &keywords);
        let expansions = self.expand_query(&keywords);
        AnalyzedQuery {
            original: query.to_string(),
            keywords,
            intent,
            target_types,
            expansions,
        }
    }

    /// Extract lowercase keywords, splitting on characters that are
    /// neither alphanumeric nor CJK, and dropping stop words and
    /// one-byte tokens.
    fn extract_keywords(&self, query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split(|c: char| !c.is_alphanumeric() && !is_cjk(c))
            // `s.len() > 1` (bytes) already implies non-empty; the former
            // `!s.is_empty()` guard was redundant and has been removed.
            .filter(|s| s.len() > 1)
            .filter(|s| !self.stop_words.contains(*s))
            .map(|s| s.to_string())
            .collect()
    }

    /// Score the keywords against each indicator set and return the
    /// highest-scoring intent, or `General` when nothing matches.
    /// Note: `Configuration` is never produced here.
    fn classify_intent(&self, keywords: &[String]) -> QueryIntent {
        let mut scores = [
            (QueryIntent::Preference, 0),
            (QueryIntent::Knowledge, 0),
            (QueryIntent::Experience, 0),
            (QueryIntent::Code, 0),
        ];
        for keyword in keywords {
            if self.preference_indicators.contains(keyword) {
                scores[0].1 += 2;
            }
            if self.knowledge_indicators.contains(keyword) {
                scores[1].1 += 2;
            }
            if self.experience_indicators.contains(keyword) {
                scores[2].1 += 2;
            }
            if self.code_indicators.contains(keyword) {
                scores[3].1 += 2;
            }
        }
        // Find highest scoring intent (ties resolved by sort order)
        scores.sort_by(|a, b| b.1.cmp(&a.1));
        if scores[0].1 > 0 {
            scores[0].0
        } else {
            QueryIntent::General
        }
    }

    /// Map an intent to the memory types worth searching; `General`
    /// searches everything.
    fn infer_memory_types(&self, intent: QueryIntent, _keywords: &[String]) -> Vec<MemoryType> {
        let mut types = Vec::new();
        match intent {
            QueryIntent::Preference => {
                types.push(MemoryType::Preference);
            }
            QueryIntent::Knowledge | QueryIntent::Code => {
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Experience => {
                types.push(MemoryType::Experience);
                types.push(MemoryType::Knowledge);
            }
            QueryIntent::General => {
                // Search all types
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
                types.push(MemoryType::Experience);
            }
            QueryIntent::Configuration => {
                types.push(MemoryType::Preference);
                types.push(MemoryType::Knowledge);
            }
        }
        types
    }

    /// Expand keywords with naive singular/plural variants and a small
    /// built-in synonym table.
    fn expand_query(&self, keywords: &[String]) -> Vec<String> {
        let mut expansions = Vec::new();
        // Add stemmed variations (simplified)
        for keyword in keywords {
            // Add singular/plural variations (ASCII 's' only; safe byte slice)
            if keyword.ends_with('s') && keyword.len() > 3 {
                expansions.push(keyword[..keyword.len() - 1].to_string());
            } else {
                expansions.push(format!("{}s", keyword));
            }
            // Add common synonyms (simplified)
            if let Some(synonyms) = self.get_synonyms(keyword) {
                expansions.extend(synonyms);
            }
        }
        expansions
    }

    /// Look up hard-coded synonyms for a keyword; `None` when unknown.
    fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
        let synonyms: &[&str] = match keyword {
            "code" => &["program", "script", "source"],
            "error" => &["bug", "issue", "problem", "exception"],
            "fix" => &["solve", "resolve", "repair", "patch"],
            "fast" => &["quick", "speed", "performance", "efficient"],
            "slow" => &["performance", "optimize", "speed"],
            "help" => &["assist", "support", "guide", "aid"],
            "learn" => &["study", "understand", "know", "grasp"],
            _ => return None,
        };
        Some(synonyms.iter().map(|s| s.to_string()).collect())
    }

    /// Build the deduplicated list of search strings for an analyzed
    /// query: original, joined keywords, and each expansion.
    /// Note: the result is sorted, so the original query may not be first.
    pub fn generate_search_queries(&self, analyzed: &AnalyzedQuery) -> Vec<String> {
        let mut queries = vec![analyzed.original.clone()];
        // Add keyword-based query
        if !analyzed.keywords.is_empty() {
            queries.push(analyzed.keywords.join(" "));
        }
        // Add expanded terms
        for expansion in &analyzed.expansions {
            if !expansion.is_empty() {
                queries.push(expansion.clone());
            }
        }
        // Deduplicate (dedup requires adjacency, hence the sort)
        queries.sort();
        queries.dedup();
        queries
    }
}
// `Default` simply delegates to `new()` so the analyzer can be used in
// `..Default::default()` struct construction.
impl Default for QueryAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}
/// Return `true` when `c` is a CJK ideograph.
///
/// Covers the Unified Ideograph blocks (base + extensions A-E) and the
/// Compatibility Ideograph blocks. Kana and Hangul are intentionally
/// not included.
fn is_cjk(c: char) -> bool {
    let cp = c as u32;
    (0x4E00..=0x9FFF).contains(&cp)      // CJK Unified Ideographs
        || (0x3400..=0x4DBF).contains(&cp)   // Extension A
        || (0x20000..=0x2A6DF).contains(&cp) // Extension B
        || (0x2A700..=0x2B73F).contains(&cp) // Extension C
        || (0x2B740..=0x2B81F).contains(&cp) // Extension D
        || (0x2B820..=0x2CEAF).contains(&cp) // Extension E
        || (0xF900..=0xFAFF).contains(&cp)   // Compatibility Ideographs
        || (0x2F800..=0x2FA1F).contains(&cp) // Compatibility Supplement
}
#[cfg(test)]
mod tests {
    //! Unit tests for query analysis: keyword extraction, intent
    //! classification, expansion, search-query generation, and CJK
    //! handling.
    use super::*;

    // Stop words are removed and remaining words are lowercased.
    #[test]
    fn test_extract_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("What is the Rust programming language?");
        assert!(keywords.contains(&"rust".to_string()));
        assert!(keywords.contains(&"programming".to_string()));
        assert!(keywords.contains(&"language".to_string()));
        assert!(!keywords.contains(&"the".to_string())); // stop word
    }

    // "prefer" is a preference indicator, so intent and target types follow.
    #[test]
    fn test_classify_intent_preference() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("I prefer concise responses");
        assert_eq!(analyzed.intent, QueryIntent::Preference);
        assert!(analyzed.target_types.contains(&MemoryType::Preference));
    }

    // "explain" drives the Knowledge classification.
    #[test]
    fn test_classify_intent_knowledge() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Explain how async/await works in Rust");
        assert_eq!(analyzed.intent, QueryIntent::Knowledge);
    }

    // "fix", "error", and "function" are all code indicators.
    #[test]
    fn test_classify_intent_code() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Fix this error in my function");
        assert_eq!(analyzed.intent, QueryIntent::Code);
    }

    // Every keyword produces at least one plural/singular expansion.
    #[test]
    fn test_query_expansion() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("fix the error");
        assert!(!analyzed.expansions.is_empty());
    }

    // The generated list always includes at least the original query.
    #[test]
    fn test_generate_search_queries() {
        let analyzer = QueryAnalyzer::new();
        let analyzed = analyzer.analyze("Rust programming");
        let queries = analyzer.generate_search_queries(&analyzed);
        assert!(queries.len() >= 1);
    }

    // Ideographs are CJK; Latin letters and digits are not.
    #[test]
    fn test_cjk_detection() {
        assert!(is_cjk('中'));
        assert!(is_cjk('文'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('1'));
    }

    // CJK runs survive tokenization (they are not split away as separators).
    #[test]
    fn test_chinese_keywords() {
        let analyzer = QueryAnalyzer::new();
        let keywords = analyzer.extract_keywords("我喜欢简洁的回复");
        // Chinese characters should be extracted
        assert!(!keywords.is_empty());
    }
}

View File

@@ -0,0 +1,374 @@
//! Semantic Similarity Scorer
//!
//! Provides TF-IDF based semantic similarity computation for memory retrieval.
//! This is a lightweight, dependency-free implementation suitable for
//! medium-scale memory systems.
use std::collections::{HashMap, HashSet};
use crate::types::MemoryEntry;
/// Semantic similarity scorer using TF-IDF with cosine similarity.
///
/// Stateful: entries must be indexed via `index_entry` before scoring;
/// unindexed entries fall back to simple substring matching.
pub struct SemanticScorer {
    /// Number of indexed documents containing each term (for IDF)
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents indexed so far
    total_documents: usize,
    /// Precomputed TF-IDF vectors, keyed by entry URI
    entry_vectors: HashMap<String, HashMap<String, f32>>,
    /// Stop words ignored during tokenization
    stop_words: HashSet<String>,
}
impl SemanticScorer {
    /// Create a new, empty semantic scorer.
    pub fn new() -> Self {
        Self {
            document_frequencies: HashMap::new(),
            total_documents: 0,
            entry_vectors: HashMap::new(),
            stop_words: Self::default_stop_words(),
        }
    }

    /// Get default stop words.
    // NOTE(review): the list contains duplicates ("a", "i", "were",
    // "after", "before", "when") — harmless in a HashSet — and the
    // single-letter entries are dead data, since `tokenize` drops
    // tokens of byte length <= 1.
    fn default_stop_words() -> HashSet<String> {
        [
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
            "have", "has", "had", "do", "does", "did", "will", "would", "could",
            "should", "may", "might", "must", "shall", "can", "need", "dare",
            "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
            "from", "as", "into", "through", "during", "before", "after",
            "above", "below", "between", "under", "again", "further", "then",
            "once", "here", "there", "when", "where", "why", "how", "all",
            "each", "few", "more", "most", "other", "some", "such", "no", "nor",
            "not", "only", "own", "same", "so", "than", "too", "very", "just",
            "and", "but", "if", "or", "because", "until", "while", "although",
            "though", "after", "before", "when", "whenever", "i", "you", "he",
            "she", "it", "we", "they", "what", "which", "who", "whom", "this",
            "that", "these", "those", "am", "im", "youre", "hes", "shes",
            "its", "were", "theyre", "ive", "youve", "weve", "theyve", "id",
            "youd", "hed", "shed", "wed", "theyd", "ill", "youll", "hell",
            "shell", "well", "theyll", "isnt", "arent", "wasnt", "werent",
            "hasnt", "havent", "hadnt", "doesnt", "dont", "didnt", "wont",
            "wouldnt", "shant", "shouldnt", "cant", "cannot", "couldnt",
            "mustnt", "lets", "thats", "whos", "whats", "heres", "theres",
            "whens", "wheres", "whys", "hows", "a", "b", "c", "d", "e", "f",
            "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s",
            "t", "u", "v", "w", "x", "y", "z",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// Tokenize text into lowercase alphanumeric words of byte length > 1.
    // NOTE(review): splits on every non-alphanumeric char, so CJK text is
    // kept (char::is_alphanumeric is true for ideographs).
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| !s.is_empty() && s.len() > 1)
            .map(|s| s.to_string())
            .collect()
    }

    /// Remove stop words from a token list.
    fn remove_stop_words(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|t| !self.stop_words.contains(*t))
            .cloned()
            .collect()
    }

    /// Compute normalized term frequency for a list of tokens.
    /// Safe on empty input: the map stays empty so no division occurs.
    fn compute_tf(tokens: &[String]) -> HashMap<String, f32> {
        let mut tf = HashMap::new();
        let total = tokens.len() as f32;
        for token in tokens {
            *tf.entry(token.clone()).or_insert(0.0) += 1.0;
        }
        // Normalize by total tokens
        for count in tf.values_mut() {
            *count /= total;
        }
        tf
    }

    /// Compute smoothed IDF for a term; 0.0 for unseen terms or an
    /// empty index.
    fn compute_idf(&self, term: &str) -> f32 {
        let df = self.document_frequencies.get(term).copied().unwrap_or(0);
        if df == 0 || self.total_documents == 0 {
            return 0.0;
        }
        ((self.total_documents as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
    }

    /// Index an entry: update document frequencies, then store its
    /// TF-IDF vector keyed by URI.
    // NOTE(review): vectors are frozen with the IDF values at index time,
    // so earlier vectors go stale as more documents are indexed; also,
    // re-indexing the same URI increments document frequencies again.
    pub fn index_entry(&mut self, entry: &MemoryEntry) {
        // Tokenize content and keywords
        let mut all_tokens = Self::tokenize(&entry.content);
        for keyword in &entry.keywords {
            all_tokens.extend(Self::tokenize(keyword));
        }
        all_tokens = self.remove_stop_words(&all_tokens);
        // Update document frequencies
        let unique_terms: HashSet<_> = all_tokens.iter().cloned().collect();
        for term in &unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
        }
        self.total_documents += 1;
        // Compute TF-IDF vector
        let tf = Self::compute_tf(&all_tokens);
        let mut tfidf = HashMap::new();
        for (term, tf_val) in tf {
            let idf = self.compute_idf(&term);
            tfidf.insert(term, tf_val * idf);
        }
        self.entry_vectors.insert(entry.uri.clone(), tfidf);
    }

    /// Remove an entry's vector from the index.
    // NOTE(review): does not decrement document_frequencies or
    // total_documents, so IDF statistics remain inflated after removal.
    pub fn remove_entry(&mut self, uri: &str) {
        self.entry_vectors.remove(uri);
    }

    /// Cosine similarity between two sparse vectors, clamped to [0, 1];
    /// 0.0 when either vector is empty.
    fn cosine_similarity(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
        if v1.is_empty() || v2.is_empty() {
            return 0.0;
        }
        // Find common keys
        let mut dot_product = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;
        for (k, v) in v1 {
            norm1 += v * v;
            if let Some(v2_val) = v2.get(k) {
                dot_product += v * v2_val;
            }
        }
        for v in v2.values() {
            norm2 += v * v;
        }
        let denom = (norm1 * norm2).sqrt();
        if denom == 0.0 {
            0.0
        } else {
            (dot_product / denom).clamp(0.0, 1.0)
        }
    }

    /// Score similarity between a query and an entry: 70% TF-IDF cosine
    /// plus 30% keyword-match boost. Unindexed entries use a substring
    /// fallback; an empty (all-stop-word) query scores a neutral 0.5.
    pub fn score_similarity(&self, query: &str, entry: &MemoryEntry) -> f32 {
        // Tokenize query
        let query_tokens = self.remove_stop_words(&Self::tokenize(query));
        if query_tokens.is_empty() {
            return 0.5; // Neutral score for empty query
        }
        // Compute query TF-IDF
        let query_tf = Self::compute_tf(&query_tokens);
        let mut query_vec = HashMap::new();
        for (term, tf_val) in query_tf {
            let idf = self.compute_idf(&term);
            query_vec.insert(term, tf_val * idf);
        }
        // Get entry vector
        let entry_vec = match self.entry_vectors.get(&entry.uri) {
            Some(v) => v,
            None => {
                // Fall back to simple matching if not indexed
                return self.fallback_similarity(&query_tokens, entry);
            }
        };
        // Compute cosine similarity
        let cosine = Self::cosine_similarity(&query_vec, entry_vec);
        // Combine with keyword matching for better results
        let keyword_boost = self.keyword_match_score(&query_tokens, entry);
        // Weighted combination
        cosine * 0.7 + keyword_boost * 0.3
    }

    /// Substring-match fallback used for unindexed entries; each token can
    /// score up to 2 (content match + keyword match), normalized to [0, 1].
    fn fallback_similarity(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        let content_lower = entry.content.to_lowercase();
        let mut matches = 0;
        for token in query_tokens {
            if content_lower.contains(token) {
                matches += 1;
            }
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(token) {
                    matches += 1;
                    break;
                }
            }
        }
        (matches as f32) / (query_tokens.len() * 2).max(1) as f32
    }

    /// Fraction of query tokens that substring-match at least one of the
    /// entry's keywords; 0.0 when the entry has no keywords.
    fn keyword_match_score(&self, query_tokens: &[String], entry: &MemoryEntry) -> f32 {
        if entry.keywords.is_empty() {
            return 0.0;
        }
        let mut matches = 0;
        for token in query_tokens {
            for keyword in &entry.keywords {
                if keyword.to_lowercase().contains(&token.to_lowercase()) {
                    matches += 1;
                    break;
                }
            }
        }
        (matches as f32) / query_tokens.len().max(1) as f32
    }

    /// Reset the index completely (frequencies, counts, and vectors).
    pub fn clear(&mut self) {
        self.document_frequencies.clear();
        self.total_documents = 0;
        self.entry_vectors.clear();
    }

    /// Get statistics about the index.
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            total_documents: self.total_documents,
            unique_terms: self.document_frequencies.len(),
            indexed_entries: self.entry_vectors.len(),
        }
    }
}
// `Default` delegates to `new()` for ergonomic construction.
impl Default for SemanticScorer {
    fn default() -> Self {
        Self::new()
    }
}
/// Snapshot of the semantic index's size, as returned by
/// [`SemanticScorer::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
    // Documents indexed so far (not decremented on removal)
    pub total_documents: usize,
    // Distinct terms seen across all indexed documents
    pub unique_terms: usize,
    // URIs with a stored TF-IDF vector
    pub indexed_entries: usize,
}
#[cfg(test)]
mod tests {
    //! Unit tests for the TF-IDF scorer: tokenization, stop words,
    //! term frequency, cosine similarity, end-to-end scoring, and stats.
    use super::*;
    use crate::types::MemoryType;

    // Punctuation is stripped; note "is" survives (stop words are a
    // separate pass) while single letters are dropped by length.
    #[test]
    fn test_tokenize() {
        let tokens = SemanticScorer::tokenize("Hello, World! This is a test.");
        assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
    }

    // "the" is filtered; content words pass through.
    #[test]
    fn test_stop_words_removal() {
        let scorer = SemanticScorer::new();
        let tokens = vec!["hello".to_string(), "the".to_string(), "world".to_string()];
        let filtered = scorer.remove_stop_words(&tokens);
        assert_eq!(filtered, vec!["hello", "world"]);
    }

    // TF values are normalized counts: 2/3 and 1/3 here.
    #[test]
    fn test_tf_computation() {
        let tokens = vec!["hello".to_string(), "hello".to_string(), "world".to_string()];
        let tf = SemanticScorer::compute_tf(&tokens);
        let hello_tf = tf.get("hello").unwrap();
        let world_tf = tf.get("world").unwrap();
        // Allow for floating point comparison
        assert!((hello_tf - (2.0 / 3.0)).abs() < 0.001);
        assert!((world_tf - (1.0 / 3.0)).abs() < 0.001);
    }

    // Identical vectors → 1.0; disjoint keys → 0.0.
    #[test]
    fn test_cosine_similarity() {
        let mut v1 = HashMap::new();
        v1.insert("a".to_string(), 1.0);
        v1.insert("b".to_string(), 2.0);
        let mut v2 = HashMap::new();
        v2.insert("a".to_string(), 1.0);
        v2.insert("b".to_string(), 2.0);
        // Identical vectors should have similarity 1.0
        let sim = SemanticScorer::cosine_similarity(&v1, &v2);
        assert!((sim - 1.0).abs() < 0.001);
        // Orthogonal vectors should have similarity 0.0
        let mut v3 = HashMap::new();
        v3.insert("c".to_string(), 1.0);
        let sim2 = SemanticScorer::cosine_similarity(&v1, &v3);
        assert!((sim2 - 0.0).abs() < 0.001);
    }

    // End-to-end: a Rust-flavored query ranks the Rust entry above Python.
    #[test]
    fn test_index_and_score() {
        let mut scorer = SemanticScorer::new();
        let entry1 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety and performance".to_string(),
        ).with_keywords(vec!["rust".to_string(), "programming".to_string(), "safety".to_string()]);
        let entry2 = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        ).with_keywords(vec!["python".to_string(), "programming".to_string()]);
        scorer.index_entry(&entry1);
        scorer.index_entry(&entry2);
        // Query for Rust should score higher on entry1
        let score1 = scorer.score_similarity("rust safety", &entry1);
        let score2 = scorer.score_similarity("rust safety", &entry2);
        assert!(score1 > score2, "Rust query should score higher on Rust entry");
    }

    // Indexing one entry is reflected in every stats counter.
    #[test]
    fn test_stats() {
        let mut scorer = SemanticScorer::new();
        let entry = MemoryEntry::new(
            "test",
            MemoryType::Knowledge,
            "test",
            "Hello world".to_string(),
        );
        scorer.index_entry(&entry);
        let stats = scorer.stats();
        assert_eq!(stats.total_documents, 1);
        assert_eq!(stats.indexed_entries, 1);
        assert!(stats.unique_terms > 0);
    }
}

View File

@@ -0,0 +1,348 @@
//! Memory Retriever - Retrieves relevant memories from OpenViking
//!
//! This module provides the `MemoryRetriever` which performs semantic search
//! over stored memories to find contextually relevant information.
//! Uses multiple retrieval strategies and intelligent reranking.
use crate::retrieval::{MemoryCache, QueryAnalyzer, SemanticScorer};
use crate::types::{MemoryEntry, MemoryType, RetrievalConfig, RetrievalResult};
use crate::viking_adapter::{FindOptions, VikingAdapter};
use std::sync::Arc;
use tokio::sync::RwLock;
use zclaw_types::{AgentId, Result};
/// Memory Retriever - retrieves relevant memories from OpenViking,
/// combining query analysis, multi-query search, TF-IDF reranking,
/// token budgeting, and a hot-entry cache.
pub struct MemoryRetriever {
    /// OpenViking adapter (shared storage backend)
    viking: Arc<VikingAdapter>,
    /// Retrieval configuration (budgets, limits, similarity floor)
    config: RetrievalConfig,
    /// Semantic scorer behind a write lock: reranking mutates the index
    scorer: RwLock<SemanticScorer>,
    /// Query analyzer (stateless, no lock needed)
    analyzer: QueryAnalyzer,
    /// Cache of recently retrieved memories
    cache: MemoryCache,
}
impl MemoryRetriever {
    /// Create a new memory retriever with default configuration.
    pub fn new(viking: Arc<VikingAdapter>) -> Self {
        Self {
            viking,
            config: RetrievalConfig::default(),
            scorer: RwLock::new(SemanticScorer::new()),
            analyzer: QueryAnalyzer::new(),
            cache: MemoryCache::default_config(),
        }
    }

    /// Create with custom configuration (builder-style, consumes self).
    pub fn with_config(mut self, config: RetrievalConfig) -> Self {
        self.config = config;
        self
    }

    /// Retrieve relevant memories for a query
    ///
    /// This method:
    /// 1. Analyzes the query to determine intent and keywords
    /// 2. Searches for preferences matching the query
    /// 3. Searches for relevant knowledge
    /// 4. Searches for applicable experience (at half the result limit)
    /// 5. Reranks results using semantic similarity
    /// 6. Applies per-type token budget constraints
    ///
    /// Retrieved entries are also written into the hot cache.
    pub async fn retrieve(
        &self,
        agent_id: &AgentId,
        query: &str,
    ) -> Result<RetrievalResult> {
        tracing::debug!("[MemoryRetriever] Retrieving memories for query: {}", query);
        // Analyze query
        let analyzed = self.analyzer.analyze(query);
        tracing::debug!(
            "[MemoryRetriever] Query analysis: intent={:?}, keywords={:?}",
            analyzed.intent,
            analyzed.keywords
        );
        // Retrieve each type with budget constraints and reranking
        let preferences = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Preference,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type,
                self.config.preference_budget,
            )
            .await?;
        let knowledge = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Knowledge,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type,
                self.config.knowledge_budget,
            )
            .await?;
        let experience = self
            .retrieve_and_rerank(
                &agent_id.to_string(),
                MemoryType::Experience,
                query,
                &analyzed.keywords,
                self.config.max_results_per_type / 2,
                self.config.experience_budget,
            )
            .await?;
        let total_tokens = preferences.iter()
            .chain(knowledge.iter())
            .chain(experience.iter())
            .map(|m| m.estimated_tokens())
            .sum();
        // Update cache with retrieved entries
        for entry in preferences.iter().chain(knowledge.iter()).chain(experience.iter()) {
            self.cache.put(entry.clone()).await;
        }
        tracing::info!(
            "[MemoryRetriever] Retrieved {} preferences, {} knowledge, {} experience ({} tokens)",
            preferences.len(),
            knowledge.len(),
            experience.len(),
            total_tokens
        );
        Ok(RetrievalResult {
            preferences,
            knowledge,
            experience,
            total_tokens,
        })
    }

    /// Retrieve one memory type: search with expanded queries,
    /// deduplicate by URI, rerank semantically, then greedily keep
    /// entries until either the token budget or `max_results` is hit.
    async fn retrieve_and_rerank(
        &self,
        agent_id: &str,
        memory_type: MemoryType,
        query: &str,
        keywords: &[String],
        max_results: usize,
        token_budget: usize,
    ) -> Result<Vec<MemoryEntry>> {
        // Build scope for OpenViking search
        let scope = format!("agent://{}/{}", agent_id, memory_type);
        // Generate search queries (original + expanded)
        // NOTE: a synthetic AnalyzedQuery is built here because the caller
        // already extracted keywords; expansions are left empty.
        let analyzed_for_search = crate::retrieval::query::AnalyzedQuery {
            original: query.to_string(),
            keywords: keywords.to_vec(),
            intent: crate::retrieval::query::QueryIntent::General,
            target_types: vec![],
            expansions: vec![],
        };
        let search_queries = self.analyzer.generate_search_queries(&analyzed_for_search);
        // Search with multiple queries and deduplicate
        let mut all_results = Vec::new();
        let mut seen_uris = std::collections::HashSet::new();
        for search_query in search_queries {
            // Over-fetch (2x) so reranking has candidates to choose from
            let options = FindOptions {
                scope: Some(scope.clone()),
                limit: Some(max_results * 2),
                min_similarity: Some(self.config.min_similarity),
            };
            let results = self.viking.find(&search_query, options).await?;
            for entry in results {
                if seen_uris.insert(entry.uri.clone()) {
                    all_results.push(entry);
                }
            }
        }
        // Rerank using semantic similarity
        let scored = self.rerank_entries(query, all_results).await;
        // Apply token budget (entries that don't fit are skipped, but
        // later smaller entries may still be admitted)
        let mut filtered = Vec::new();
        let mut used_tokens = 0;
        for entry in scored {
            let tokens = entry.estimated_tokens();
            if used_tokens + tokens <= token_budget {
                used_tokens += tokens;
                filtered.push(entry);
            }
            if filtered.len() >= max_results {
                break;
            }
        }
        Ok(filtered)
    }

    /// Rerank entries by TF-IDF similarity to the query, breaking ties by
    /// importance and then access count.
    // NOTE(review): entries are re-indexed into the shared scorer on every
    // call, which inflates its document-frequency counts over time until
    // `clear_index` is called — confirm this is acceptable.
    async fn rerank_entries(
        &self,
        query: &str,
        entries: Vec<MemoryEntry>,
    ) -> Vec<MemoryEntry> {
        if entries.is_empty() {
            return entries;
        }
        let mut scorer = self.scorer.write().await;
        // Index entries for semantic search
        for entry in &entries {
            scorer.index_entry(entry);
        }
        // Score each entry
        let mut scored: Vec<(f32, MemoryEntry)> = entries
            .into_iter()
            .map(|entry| {
                let score = scorer.score_similarity(query, &entry);
                (score, entry)
            })
            .collect();
        // Sort by score (descending), then by importance and access count
        scored.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| b.1.importance.cmp(&a.1.importance))
                .then_with(|| b.1.access_count.cmp(&a.1.access_count))
        });
        scored.into_iter().map(|(_, entry)| entry).collect()
    }

    /// Retrieve a specific memory by URI, consulting the cache first and
    /// populating it on a storage hit.
    pub async fn get_by_uri(&self, uri: &str) -> Result<Option<MemoryEntry>> {
        // Check cache first
        if let Some(cached) = self.cache.get(uri).await {
            return Ok(Some(cached));
        }
        // Fall back to storage
        let result = self.viking.get(uri).await?;
        // Update cache
        if let Some(ref entry) = result {
            self.cache.put(entry.clone()).await;
        }
        Ok(result)
    }

    /// Get all memories for an agent (for debugging/admin).
    // NOTE(review): passes an empty query with no limit; behavior depends
    // on how VikingAdapter::find treats "" — presumably "match all".
    pub async fn get_all_memories(&self, agent_id: &AgentId) -> Result<Vec<MemoryEntry>> {
        let scope = format!("agent://{}", agent_id);
        let options = FindOptions {
            scope: Some(scope),
            limit: None,
            min_similarity: None,
        };
        self.viking.find("", options).await
    }

    /// Compute per-type memory counts and the current cache hit rate.
    pub async fn get_stats(&self, agent_id: &AgentId) -> Result<MemoryStats> {
        let all = self.get_all_memories(agent_id).await?;
        let preference_count = all.iter().filter(|m| m.memory_type == MemoryType::Preference).count();
        let knowledge_count = all.iter().filter(|m| m.memory_type == MemoryType::Knowledge).count();
        let experience_count = all.iter().filter(|m| m.memory_type == MemoryType::Experience).count();
        Ok(MemoryStats {
            total_count: all.len(),
            preference_count,
            knowledge_count,
            experience_count,
            cache_hit_rate: self.cache.hit_rate().await,
        })
    }

    /// Reset the semantic index (document frequencies and vectors).
    pub async fn clear_index(&self) {
        let mut scorer = self.scorer.write().await;
        scorer.clear();
    }

    /// Return (cache size, cache hit rate).
    pub async fn cache_stats(&self) -> (usize, f32) {
        let size = self.cache.size().await;
        let hit_rate = self.cache.hit_rate().await;
        (size, hit_rate)
    }

    /// Warm up the cache with the agent's 50 most-accessed memories;
    /// returns how many entries were preloaded.
    pub async fn warmup_cache(&self, agent_id: &AgentId) -> Result<usize> {
        let all = self.get_all_memories(agent_id).await?;
        // Sort by access count to get hot entries
        let mut sorted = all;
        sorted.sort_by(|a, b| b.access_count.cmp(&a.access_count));
        // Take top 50 hot entries
        let hot: Vec<_> = sorted.into_iter().take(50).collect();
        let count = hot.len();
        self.cache.warmup(hot).await;
        Ok(count)
    }
}
/// Per-agent memory statistics, as returned by
/// [`MemoryRetriever::get_stats`].
#[derive(Debug, Clone)]
pub struct MemoryStats {
    // Total stored memories across all types
    pub total_count: usize,
    // Count of Preference-type memories
    pub preference_count: usize,
    // Count of Knowledge-type memories
    pub knowledge_count: usize,
    // Count of Experience-type memories
    pub experience_count: usize,
    // Current hit rate of the retriever's hot cache
    pub cache_hit_rate: f32,
}
#[cfg(test)]
mod tests {
    //! Unit tests for retriever construction, default config values,
    //! and scope-string formatting.
    use super::*;

    // Pin the default budget values the retriever relies on.
    #[test]
    fn test_retrieval_config_default() {
        let config = RetrievalConfig::default();
        assert_eq!(config.max_tokens, 500);
        assert_eq!(config.preference_budget, 200);
        assert_eq!(config.knowledge_budget, 200);
    }

    // MemoryType's Display renders Preference as "preferences" in scopes.
    #[test]
    fn test_memory_type_scope() {
        let scope = format!("agent://test-agent/{}", MemoryType::Preference);
        assert!(scope.contains("test-agent"));
        assert!(scope.contains("preferences"));
    }

    // A fresh retriever starts with an empty cache.
    #[tokio::test]
    async fn test_retriever_creation() {
        let viking = Arc::new(VikingAdapter::in_memory());
        let retriever = MemoryRetriever::new(viking);
        let stats = retriever.cache_stats().await;
        assert_eq!(stats.0, 0); // Cache size should be 0
    }
}

View File

@@ -0,0 +1,9 @@
//! Storage backends for ZCLAW Growth System
//!
//! This module provides storage backend implementations:
//! - `SqliteStorage`: Persistent SQLite storage for production use
//
// NOTE(review): the module doc also mentions an `InMemoryStorage`
// backend, but only `sqlite` is declared here — confirm whether the
// in-memory backend lives elsewhere (e.g. in the viking adapter).
mod sqlite;
pub use sqlite::SqliteStorage;

View File

@@ -0,0 +1,563 @@
//! SQLite Storage Backend
//!
//! Persistent storage backend using SQLite for production use.
//! Provides efficient querying and full-text search capabilities.
use crate::retrieval::semantic::SemanticScorer;
use crate::types::MemoryEntry;
use crate::viking_adapter::{FindOptions, VikingStorage};
use async_trait::async_trait;
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
use sqlx::Row;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use zclaw_types::Result;
use zclaw_types::ZclawError;
/// SQLite storage backend with TF-IDF semantic scoring
pub struct SqliteStorage {
    /// Database connection pool (sqlx, max 5 connections)
    pool: SqlitePool,
    /// Semantic scorer for similarity computation, shared behind a lock
    /// because indexing mutates it
    scorer: Arc<RwLock<SemanticScorer>>,
    /// Database path (for reference only; not used after connecting)
    #[allow(dead_code)]
    path: PathBuf,
}
/// Database row structure for a memory entry, mirroring the columns of
/// the `memories` table.
// NOTE(review): keywords is presumably a JSON-encoded string array
// (the schema defaults it to '[]') and the timestamps RFC 3339 text —
// confirm against the row-mapping code further down this file.
struct MemoryRow {
    uri: String,
    memory_type: String,
    content: String,
    keywords: String,
    importance: i32,
    access_count: i32,
    created_at: String,
    last_accessed: String,
}
impl SqliteStorage {
    /// Create a new SQLite storage at the given path.
    ///
    /// Pass `":memory:"` for a transient in-memory database; any other path
    /// is opened on disk in read/write/create mode, creating the parent
    /// directory first if needed.
    ///
    /// # Errors
    /// Returns a `StorageError` if the directory cannot be created, the
    /// connection fails, or schema initialization fails.
    pub async fn new(path: impl Into<PathBuf>) -> Result<Self> {
        let path = path.into();
        // Decide once whether this is the special in-memory database.
        // Bug fix: the previous code compared the *parent* of the path
        // against ":memory:", which can never match — the marker is the
        // path itself.
        let is_in_memory = path.to_str() == Some(":memory:");
        // Ensure the parent directory exists for on-disk databases only.
        if !is_in_memory {
            if let Some(parent) = path.parent() {
                tokio::fs::create_dir_all(parent).await.map_err(|e| {
                    ZclawError::StorageError(format!("Failed to create storage directory: {}", e))
                })?;
            }
        }
        // Build connection string ("mode=rwc" = read/write/create).
        let db_url = if is_in_memory {
            "sqlite::memory:".to_string()
        } else {
            format!("sqlite:{}?mode=rwc", path.to_string_lossy())
        };
        // Create connection pool.
        // NOTE(review): with ":memory:" each newly opened pooled connection
        // would normally see its own empty database — confirm that
        // max_connections(5) is safe for the in-memory case or reduce it
        // to 1 there.
        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect(&db_url)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to connect to database: {}", e)))?;
        let storage = Self {
            pool,
            scorer: Arc::new(RwLock::new(SemanticScorer::new())),
            path,
        };
        storage.initialize_schema().await?;
        storage.warmup_scorer().await?;
        Ok(storage)
    }

    /// Create an in-memory SQLite database (for testing).
    ///
    /// # Panics
    /// Panics if the database cannot be created — acceptable in tests only.
    pub async fn in_memory() -> Self {
        Self::new(":memory:").await.expect("Failed to create in-memory database")
    }

    /// Initialize the database schema: main table, FTS5 mirror, secondary
    /// indexes, and the metadata key/value table. All statements are
    /// idempotent (`IF NOT EXISTS`), so this is safe on every startup.
    async fn initialize_schema(&self) -> Result<()> {
        // Create main memories table
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS memories (
                uri TEXT PRIMARY KEY,
                memory_type TEXT NOT NULL,
                content TEXT NOT NULL,
                keywords TEXT NOT NULL DEFAULT '[]',
                importance INTEGER NOT NULL DEFAULT 5,
                access_count INTEGER NOT NULL DEFAULT 0,
                created_at TEXT NOT NULL,
                last_accessed TEXT NOT NULL
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create memories table: {}", e)))?;
        // Create FTS5 virtual table for full-text search
        sqlx::query(
            r#"
            CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
                uri,
                content,
                keywords,
                tokenize='unicode61'
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create FTS5 table: {}", e)))?;
        // Create index on memory_type for filtering
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_memory_type ON memories(memory_type)")
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to create index: {}", e)))?;
        // Create index on importance for sorting
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_importance ON memories(importance DESC)")
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to create importance index: {}", e)))?;
        // Create metadata table
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                json TEXT NOT NULL
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to create metadata table: {}", e)))?;
        tracing::info!("[SqliteStorage] Database schema initialized");
        Ok(())
    }

    /// Warm up the semantic scorer with all existing entries so TF-IDF
    /// statistics are available immediately after startup.
    async fn warmup_scorer(&self) -> Result<()> {
        let rows = sqlx::query_as::<_, MemoryRow>(
            "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories"
        )
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to load memories for warmup: {}", e)))?;
        let mut scorer = self.scorer.write().await;
        for row in rows {
            let entry = self.row_to_entry(&row);
            scorer.index_entry(&entry);
        }
        let stats = scorer.stats();
        tracing::info!(
            "[SqliteStorage] Warmed up scorer with {} entries, {} terms",
            stats.indexed_entries,
            stats.unique_terms
        );
        Ok(())
    }

    /// Convert a database row to a `MemoryEntry`, tolerating malformed data:
    /// unknown types fall back to `Knowledge`, unparsable keywords to an
    /// empty list, bad timestamps to `now`, and out-of-range counters to
    /// safe defaults.
    fn row_to_entry(&self, row: &MemoryRow) -> MemoryEntry {
        let memory_type = crate::types::MemoryType::parse(&row.memory_type);
        let keywords: Vec<String> = serde_json::from_str(&row.keywords).unwrap_or_default();
        let created_at = chrono::DateTime::parse_from_rfc3339(&row.created_at)
            .map(|dt| dt.with_timezone(&chrono::Utc))
            .unwrap_or_else(|_| chrono::Utc::now());
        let last_accessed = chrono::DateTime::parse_from_rfc3339(&row.last_accessed)
            .map(|dt| dt.with_timezone(&chrono::Utc))
            .unwrap_or_else(|_| chrono::Utc::now());
        MemoryEntry {
            uri: row.uri.clone(),
            memory_type,
            content: row.content.clone(),
            keywords,
            // `try_from` instead of a raw `as` cast: a manually edited or
            // corrupt column would otherwise silently truncate/wrap.
            importance: u8::try_from(row.importance).unwrap_or(5),
            access_count: u32::try_from(row.access_count).unwrap_or(0),
            created_at,
            last_accessed,
        }
    }

    /// Bump the access counter and refresh the last-accessed timestamp for
    /// the entry at `uri`. A no-op when the URI does not exist.
    async fn touch_entry(&self, uri: &str) -> Result<()> {
        let now = chrono::Utc::now().to_rfc3339();
        sqlx::query(
            "UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE uri = ?"
        )
        .bind(&now)
        .bind(uri)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to update access count: {}", e)))?;
        Ok(())
    }
}
impl sqlx::FromRow<'_, SqliteRow> for MemoryRow {
    /// Map a raw SQLite row onto the intermediate `MemoryRow` struct,
    /// pulling each column out by name.
    fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
        let uri: String = row.try_get("uri")?;
        let memory_type: String = row.try_get("memory_type")?;
        let content: String = row.try_get("content")?;
        let keywords: String = row.try_get("keywords")?;
        let importance: i32 = row.try_get("importance")?;
        let access_count: i32 = row.try_get("access_count")?;
        let created_at: String = row.try_get("created_at")?;
        let last_accessed: String = row.try_get("last_accessed")?;
        Ok(MemoryRow {
            uri,
            memory_type,
            content,
            keywords,
            importance,
            access_count,
            created_at,
            last_accessed,
        })
    }
}
#[async_trait]
impl VikingStorage for SqliteStorage {
    /// Insert or replace a memory entry, keeping the FTS5 mirror and the
    /// in-process TF-IDF scorer in sync with the main table.
    async fn store(&self, entry: &MemoryEntry) -> Result<()> {
        let keywords_json = serde_json::to_string(&entry.keywords)
            .map_err(|e| ZclawError::StorageError(format!("Failed to serialize keywords: {}", e)))?;
        // Timestamps are stored as RFC 3339 text columns.
        let created_at = entry.created_at.to_rfc3339();
        let last_accessed = entry.last_accessed.to_rfc3339();
        let memory_type = entry.memory_type.to_string();
        // Insert into main table
        sqlx::query(
            r#"
            INSERT OR REPLACE INTO memories
            (uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind(&entry.uri)
        .bind(&memory_type)
        .bind(&entry.content)
        .bind(&keywords_json)
        .bind(entry.importance as i32)
        .bind(entry.access_count as i32)
        .bind(&created_at)
        .bind(&last_accessed)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to store memory: {}", e)))?;
        // Update FTS index - delete old and insert new.
        // Errors here are intentionally ignored (`let _`): the FTS mirror is
        // best-effort and must not fail the primary write.
        let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
            .bind(&entry.uri)
            .execute(&self.pool)
            .await;
        let keywords_text = entry.keywords.join(" ");
        let _ = sqlx::query(
            r#"
            INSERT INTO memories_fts (uri, content, keywords)
            VALUES (?, ?, ?)
            "#,
        )
        .bind(&entry.uri)
        .bind(&entry.content)
        .bind(&keywords_text)
        .execute(&self.pool)
        .await;
        // Update semantic scorer
        let mut scorer = self.scorer.write().await;
        scorer.index_entry(entry);
        tracing::debug!("[SqliteStorage] Stored memory: {}", entry.uri);
        Ok(())
    }

    /// Fetch a single entry by URI, bumping its access counter as a side
    /// effect. The returned entry carries the pre-bump counter value.
    async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
        let row = sqlx::query_as::<_, MemoryRow>(
            "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri = ?"
        )
        .bind(uri)
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to get memory: {}", e)))?;
        if let Some(row) = row {
            let entry = self.row_to_entry(&row);
            // Update access count
            self.touch_entry(&entry.uri).await?;
            Ok(Some(entry))
        } else {
            Ok(None)
        }
    }

    /// Search entries by semantic similarity to `query`.
    ///
    /// Candidate rows are loaded (optionally restricted to a URI-prefix
    /// scope), scored in-process with TF-IDF, filtered by `min_similarity`,
    /// sorted by score / importance / access count, and truncated to `limit`.
    /// NOTE(review): the FTS5 table maintained by `store` is not consulted
    /// here — confirm whether it was meant to pre-filter candidates.
    async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
        // Get all matching entries
        let rows = if let Some(ref scope) = options.scope {
            sqlx::query_as::<_, MemoryRow>(
                "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri LIKE ?"
            )
            .bind(format!("{}%", scope))
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
        } else {
            sqlx::query_as::<_, MemoryRow>(
                "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories"
            )
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
        };
        // Convert to entries and compute semantic scores
        let scorer = self.scorer.read().await;
        let mut scored_entries: Vec<(f32, MemoryEntry)> = Vec::new();
        for row in rows {
            let entry = self.row_to_entry(&row);
            // Compute semantic score using TF-IDF
            let semantic_score = scorer.score_similarity(query, &entry);
            // Apply similarity threshold
            if let Some(min_similarity) = options.min_similarity {
                if semantic_score < min_similarity {
                    continue;
                }
            }
            scored_entries.push((semantic_score, entry));
        }
        // Sort by score (descending), then by importance and access count
        scored_entries.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| b.1.importance.cmp(&a.1.importance))
                .then_with(|| b.1.access_count.cmp(&a.1.access_count))
        });
        // Apply limit
        if let Some(limit) = options.limit {
            scored_entries.truncate(limit);
        }
        Ok(scored_entries.into_iter().map(|(_, entry)| entry).collect())
    }

    /// List all entries whose URI starts with `prefix`, in database order.
    /// Does not bump access counters.
    async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
        let rows = sqlx::query_as::<_, MemoryRow>(
            "SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed FROM memories WHERE uri LIKE ?"
        )
        .bind(format!("{}%", prefix))
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to find by prefix: {}", e)))?;
        let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
        Ok(entries)
    }

    /// Delete an entry from the main table, the FTS mirror (best-effort),
    /// and the semantic scorer.
    async fn delete(&self, uri: &str) -> Result<()> {
        sqlx::query("DELETE FROM memories WHERE uri = ?")
            .bind(uri)
            .execute(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;
        // Remove from FTS (best-effort, errors ignored)
        let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
            .bind(uri)
            .execute(&self.pool)
            .await;
        // Remove from scorer
        let mut scorer = self.scorer.write().await;
        scorer.remove_entry(uri);
        tracing::debug!("[SqliteStorage] Deleted memory: {}", uri);
        Ok(())
    }

    /// Insert or replace a metadata record (opaque JSON keyed by string).
    async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()> {
        sqlx::query(
            r#"
            INSERT OR REPLACE INTO metadata (key, json)
            VALUES (?, ?)
            "#,
        )
        .bind(key)
        .bind(json)
        .execute(&self.pool)
        .await
        .map_err(|e| ZclawError::StorageError(format!("Failed to store metadata: {}", e)))?;
        Ok(())
    }

    /// Fetch a metadata record by key; `None` when the key does not exist.
    async fn get_metadata_json(&self, key: &str) -> Result<Option<String>> {
        let result = sqlx::query_scalar::<_, String>("SELECT json FROM metadata WHERE key = ?")
            .bind(key)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| ZclawError::StorageError(format!("Failed to get metadata: {}", e)))?;
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    #[tokio::test]
    async fn test_sqlite_storage_store_and_get() {
        // Store one entry, then read it back by URI.
        let db = SqliteStorage::in_memory().await;
        let memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "User prefers concise responses".to_string(),
        );
        db.store(&memory).await.unwrap();
        let fetched = db.get(&memory.uri).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "User prefers concise responses");
    }

    #[tokio::test]
    async fn test_sqlite_storage_semantic_search() {
        let db = SqliteStorage::in_memory().await;
        // Two entries in different domains; a query mentioning "rust safety"
        // must rank the Rust entry first.
        let rust_memory = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust",
            "Rust is a systems programming language focused on safety".to_string(),
        )
        .with_keywords(vec![
            "rust".to_string(),
            "programming".to_string(),
            "safety".to_string(),
        ]);
        let python_memory = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python",
            "Python is a high-level programming language".to_string(),
        )
        .with_keywords(vec!["python".to_string(), "programming".to_string()]);
        db.store(&rust_memory).await.unwrap();
        db.store(&python_memory).await.unwrap();
        let hits = db
            .find(
                "rust safety",
                FindOptions {
                    scope: Some("agent://agent-1".to_string()),
                    limit: Some(10),
                    min_similarity: Some(0.1),
                },
            )
            .await
            .unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_sqlite_storage_delete() {
        let db = SqliteStorage::in_memory().await;
        let memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        db.store(&memory).await.unwrap();
        db.delete(&memory.uri).await.unwrap();
        let fetched = db.get(&memory.uri).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_persistence() {
        let db_path = std::env::temp_dir().join("zclaw_test_memories.db");
        // Clean up any leftovers from a previous run.
        let _ = std::fs::remove_file(&db_path);
        // First connection: write an entry, then drop the storage.
        {
            let db = SqliteStorage::new(&db_path).await.unwrap();
            let memory = MemoryEntry::new(
                "persist-test",
                MemoryType::Knowledge,
                "test",
                "This should persist".to_string(),
            );
            db.store(&memory).await.unwrap();
        }
        // Second connection: the entry must still be there.
        {
            let db = SqliteStorage::new(&db_path).await.unwrap();
            let hits = db.find_by_prefix("agent://persist-test").await.unwrap();
            assert!(!hits.is_empty());
            assert_eq!(hits[0].content, "This should persist");
        }
        // Clean up
        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_metadata_storage() {
        let db = SqliteStorage::in_memory().await;
        let json = r#"{"test": "value"}"#;
        db.store_metadata_json("test-key", json).await.unwrap();
        let fetched = db.get_metadata_json("test-key").await.unwrap();
        assert_eq!(fetched, Some(json.to_string()));
    }

    #[tokio::test]
    async fn test_access_count() {
        let db = SqliteStorage::in_memory().await;
        let memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Knowledge,
            "test",
            "test content".to_string(),
        );
        db.store(&memory).await.unwrap();
        // Each get() bumps the access counter.
        for _ in 0..3 {
            let _ = db.get(&memory.uri).await.unwrap();
        }
        let fetched = db.get(&memory.uri).await.unwrap().unwrap();
        assert!(fetched.access_count >= 3);
    }
}

View File

@@ -0,0 +1,212 @@
//! Growth Tracker - Tracks agent growth metrics and evolution
//!
//! This module provides the `GrowthTracker` which monitors and records
//! the evolution of an agent's capabilities and knowledge over time.
use crate::types::{GrowthStats, MemoryType};
use crate::viking_adapter::VikingAdapter;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use zclaw_types::{AgentId, Result};
/// Growth Tracker - tracks agent growth metrics.
///
/// Aggregates the memories and learning events an agent has accumulated in
/// OpenViking and records new learning events as they happen.
pub struct GrowthTracker {
    /// OpenViking adapter used for all reads and writes.
    viking: Arc<VikingAdapter>,
}
impl GrowthTracker {
/// Create a new growth tracker
pub fn new(viking: Arc<VikingAdapter>) -> Self {
Self { viking }
}
/// Get current growth statistics for an agent
pub async fn get_stats(&self, agent_id: &AgentId) -> Result<GrowthStats> {
// Query all memories for the agent
let memories = self.viking.find_by_prefix(&format!("agent://{}", agent_id)).await?;
let mut stats = GrowthStats::default();
stats.total_memories = memories.len();
for memory in &memories {
match memory.memory_type {
MemoryType::Preference => stats.preference_count += 1,
MemoryType::Knowledge => stats.knowledge_count += 1,
MemoryType::Experience => stats.experience_count += 1,
MemoryType::Session => stats.sessions_processed += 1,
}
}
// Get last learning time from metadata
let meta: Option<AgentMetadata> = self.viking
.get_metadata(&format!("agent://{}", agent_id))
.await?;
if let Some(meta) = meta {
stats.last_learning_time = meta.last_learning_time;
}
Ok(stats)
}
/// Record a learning event
pub async fn record_learning(
&self,
agent_id: &AgentId,
session_id: &str,
memories_extracted: usize,
) -> Result<()> {
let event = LearningEvent {
agent_id: agent_id.to_string(),
session_id: session_id.to_string(),
memories_extracted,
timestamp: Utc::now(),
};
// Store learning event
self.viking
.store_metadata(
&format!("agent://{}/events/{}", agent_id, session_id),
&event,
)
.await?;
// Update last learning time
self.viking
.store_metadata(
&format!("agent://{}", agent_id),
&AgentMetadata {
last_learning_time: Some(Utc::now()),
total_learning_events: None, // Will be computed
},
)
.await?;
tracing::info!(
"[GrowthTracker] Recorded learning event: agent={}, session={}, memories={}",
agent_id,
session_id,
memories_extracted
);
Ok(())
}
/// Get growth timeline for an agent
pub async fn get_timeline(&self, agent_id: &AgentId) -> Result<Vec<LearningEvent>> {
let memories = self
.viking
.find_by_prefix(&format!("agent://{}/events/", agent_id))
.await?;
// Parse events from stored memory content
let mut timeline = Vec::new();
for memory in memories {
if let Ok(event) = serde_json::from_str::<LearningEvent>(&memory.content) {
timeline.push(event);
}
}
// Sort by timestamp descending
timeline.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
Ok(timeline)
}
/// Calculate growth velocity (memories per day)
pub async fn get_growth_velocity(&self, agent_id: &AgentId) -> Result<f64> {
let timeline = self.get_timeline(agent_id).await?;
if timeline.is_empty() {
return Ok(0.0);
}
// Get first and last event
let first = timeline.iter().min_by_key(|e| e.timestamp);
let last = timeline.iter().max_by_key(|e| e.timestamp);
match (first, last) {
(Some(first), Some(last)) => {
let days = (last.timestamp - first.timestamp).num_days().max(1) as f64;
let total_memories: usize = timeline.iter().map(|e| e.memories_extracted).sum();
Ok(total_memories as f64 / days)
}
_ => Ok(0.0),
}
}
/// Get memory distribution by category
pub async fn get_memory_distribution(
&self,
agent_id: &AgentId,
) -> Result<HashMap<String, usize>> {
let memories = self.viking.find_by_prefix(&format!("agent://{}", agent_id)).await?;
let mut distribution = HashMap::new();
for memory in memories {
*distribution.entry(memory.memory_type.to_string()).or_insert(0) += 1;
}
Ok(distribution)
}
}
/// Learning event record.
///
/// Serialized to JSON when stored (via serde) and read back when building
/// an agent's growth timeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningEvent {
    /// Agent ID
    pub agent_id: String,
    /// Session ID where learning occurred
    pub session_id: String,
    /// Number of memories extracted from that session
    pub memories_extracted: usize,
    /// Event timestamp (UTC)
    pub timestamp: DateTime<Utc>,
}
/// Agent metadata stored in OpenViking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMetadata {
    /// Last learning time (UTC)
    pub last_learning_time: Option<DateTime<Utc>>,
    /// Total learning events (computed on read; `record_learning` writes
    /// this field as `None`)
    pub total_learning_events: Option<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_event_serialization() {
        // Round-trip a learning event through JSON.
        let original = LearningEvent {
            agent_id: "test-agent".to_string(),
            session_id: "test-session".to_string(),
            memories_extracted: 5,
            timestamp: Utc::now(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LearningEvent = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.agent_id, original.agent_id);
        assert_eq!(decoded.memories_extracted, original.memories_extracted);
    }

    #[test]
    fn test_agent_metadata_serialization() {
        // Round-trip agent metadata through JSON.
        let original = AgentMetadata {
            last_learning_time: Some(Utc::now()),
            total_learning_events: Some(10),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AgentMetadata = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.last_learning_time.is_some());
        assert_eq!(decoded.total_learning_events, Some(10));
    }
}

View File

@@ -0,0 +1,486 @@
//! Core type definitions for the ZCLAW Growth System
//!
//! This module defines the fundamental types used for memory management,
//! extraction, retrieval, and prompt injection.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use zclaw_types::SessionId;
/// Memory type classification.
///
/// Serialized with serde in `snake_case`. Note that the `Display` form used
/// in OpenViking URIs differs: it uses plural "preferences"/"sessions" path
/// segments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// User preferences (communication style, format, language, etc.)
    Preference,
    /// Accumulated knowledge (user facts, domain knowledge, lessons learned)
    Knowledge,
    /// Skill/tool usage experience
    Experience,
    /// Conversation session history
    Session,
}
impl std::fmt::Display for MemoryType {
    /// Render the memory type as the path segment used in OpenViking URIs
    /// (plural forms for `Preference` and `Session`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let segment = match self {
            MemoryType::Preference => "preferences",
            MemoryType::Knowledge => "knowledge",
            MemoryType::Experience => "experience",
            MemoryType::Session => "sessions",
        };
        f.write_str(segment)
    }
}
impl std::str::FromStr for MemoryType {
    type Err = String;

    /// Parse a memory type from text. Matching is case-insensitive and both
    /// singular and plural spellings of "preference"/"session" are accepted.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "preference" | "preferences" => Ok(MemoryType::Preference),
            "knowledge" => Ok(MemoryType::Knowledge),
            "experience" => Ok(MemoryType::Experience),
            "session" | "sessions" => Ok(MemoryType::Session),
            _ => Err(format!("Unknown memory type: {}", s)),
        }
    }
}
impl MemoryType {
/// Parse memory type from string (returns Knowledge as default)
pub fn parse(s: &str) -> Self {
s.parse().unwrap_or(MemoryType::Knowledge)
}
}
/// Memory entry stored in OpenViking.
///
/// The canonical storage unit of the growth system: persisted by a
/// `VikingStorage` backend and scored for retrieval by content/keywords.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// URI in OpenViking format: `agent://{agent_id}/{type}/{category}`
    pub uri: String,
    /// Type of memory
    pub memory_type: MemoryType,
    /// Memory content
    pub content: String,
    /// Keywords for semantic search
    pub keywords: Vec<String>,
    /// Importance score (1-10; defaults to 5)
    pub importance: u8,
    /// Number of times accessed (bumped by `touch`)
    pub access_count: u32,
    /// Creation timestamp (UTC)
    pub created_at: DateTime<Utc>,
    /// Last access timestamp (UTC)
    pub last_accessed: DateTime<Utc>,
}
impl MemoryEntry {
    /// Create a new memory entry with default importance (5), a zero access
    /// count, and both timestamps set to now. The URI follows the
    /// `agent://{agent_id}/{type}/{category}` scheme.
    pub fn new(
        agent_id: &str,
        memory_type: MemoryType,
        category: &str,
        content: String,
    ) -> Self {
        let uri = format!("agent://{}/{}/{}", agent_id, memory_type, category);
        Self {
            uri,
            memory_type,
            content,
            keywords: Vec::new(),
            importance: 5,
            access_count: 0,
            created_at: Utc::now(),
            last_accessed: Utc::now(),
        }
    }

    /// Replace the keywords used for semantic search (builder style).
    pub fn with_keywords(mut self, keywords: Vec<String>) -> Self {
        self.keywords = keywords;
        self
    }

    /// Set the importance score, clamped to the valid 1-10 range.
    pub fn with_importance(mut self, importance: u8) -> Self {
        // `clamp` replaces the original `min(10).max(1)` chain — identical
        // result, clearer intent (clippy: manual_clamp).
        self.importance = importance.clamp(1, 10);
        self
    }

    /// Mark the entry as accessed: bump the counter and refresh the
    /// last-accessed timestamp.
    pub fn touch(&mut self) {
        // saturating_add avoids a debug-mode overflow panic on the (very
        // unlikely) u32::MAX-th access.
        self.access_count = self.access_count.saturating_add(1);
        self.last_accessed = Utc::now();
    }

    /// Estimate token count for the content.
    ///
    /// Heuristic: CJK characters cost ~1.5 tokens each, everything else
    /// ~0.25 tokens (roughly 4 characters per token for Latin text).
    pub fn estimated_tokens(&self) -> usize {
        let char_count = self.content.chars().count();
        let cjk_count = self.content.chars().filter(|c| is_cjk(*c)).count();
        let non_cjk_count = char_count - cjk_count;
        // CJK: ~1.5 tokens per char, non-CJK: ~0.25 tokens per char
        (cjk_count as f32 * 1.5 + non_cjk_count as f32 * 0.25).ceil() as usize
    }
}
/// Extracted memory from conversation analysis.
///
/// Intermediate result produced by the extractor before being converted to
/// a `MemoryEntry` for storage (see `to_memory_entry`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedMemory {
    /// Type of extracted memory
    pub memory_type: MemoryType,
    /// Category within the memory type (becomes the URI leaf segment)
    pub category: String,
    /// Memory content
    pub content: String,
    /// Extraction confidence (0.0 - 1.0; defaults to 0.8)
    pub confidence: f32,
    /// Source session ID
    pub source_session: SessionId,
    /// Keywords extracted
    pub keywords: Vec<String>,
}
impl ExtractedMemory {
    /// Create a new extracted memory with the default confidence of 0.8 and
    /// no keywords.
    pub fn new(
        memory_type: MemoryType,
        category: impl Into<String>,
        content: impl Into<String>,
        source_session: SessionId,
    ) -> Self {
        Self {
            memory_type,
            category: category.into(),
            content: content.into(),
            confidence: 0.8,
            source_session,
            keywords: Vec::new(),
        }
    }

    /// Set the confidence score, clamped into the `[0.0, 1.0]` range.
    pub fn with_confidence(mut self, confidence: f32) -> Self {
        self.confidence = confidence.clamp(0.0, 1.0);
        self
    }

    /// Replace the extracted keywords (builder style).
    pub fn with_keywords(mut self, keywords: Vec<String>) -> Self {
        self.keywords = keywords;
        self
    }

    /// Convert to a `MemoryEntry` for storage under the given agent,
    /// carrying the keywords over.
    pub fn to_memory_entry(&self, agent_id: &str) -> MemoryEntry {
        let entry = MemoryEntry::new(
            agent_id,
            self.memory_type,
            &self.category,
            self.content.clone(),
        );
        entry.with_keywords(self.keywords.clone())
    }
}
/// Retrieval configuration.
///
/// Controls how much context the retriever may inject into a prompt, both
/// globally (`max_tokens`) and per memory category.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Total token budget for retrieved memories
    pub max_tokens: usize,
    /// Token budget for preferences
    pub preference_budget: usize,
    /// Token budget for knowledge
    pub knowledge_budget: usize,
    /// Token budget for experience
    pub experience_budget: usize,
    /// Minimum similarity threshold (0.0 - 1.0)
    pub min_similarity: f32,
    /// Maximum number of results per memory type
    pub max_results_per_type: usize,
}
/// Check whether a character belongs to a CJK block (including Kana and
/// Hangul); used by token estimation, since these scripts cost more tokens
/// per character than Latin text.
fn is_cjk(c: char) -> bool {
    matches!(u32::from(c),
        0x4E00..=0x9FFF     // CJK Unified Ideographs
        | 0x3400..=0x4DBF   // CJK Unified Ideographs Extension A
        | 0x20000..=0x2A6DF // CJK Unified Ideographs Extension B
        | 0xF900..=0xFAFF   // CJK Compatibility Ideographs
        | 0x3040..=0x309F   // Hiragana
        | 0x30A0..=0x30FF   // Katakana
        | 0xAC00..=0xD7AF   // Hangul syllables
    )
}
impl Default for RetrievalConfig {
    fn default() -> Self {
        // Per-type budgets sum to the 500-token total: 200 + 200 + 100.
        Self {
            max_tokens: 500,
            preference_budget: 200,
            knowledge_budget: 200,
            experience_budget: 100,
            min_similarity: 0.7,
            max_results_per_type: 5,
        }
    }
}
impl RetrievalConfig {
    /// Create a config with a custom total token budget.
    ///
    /// 40% goes to preferences, 40% to knowledge, and the remainder
    /// (roughly 20%) to experience.
    pub fn with_budget(max_tokens: usize) -> Self {
        let preference_budget = (max_tokens as f32 * 0.4) as usize;
        let knowledge_budget = (max_tokens as f32 * 0.4) as usize;
        let experience_budget = max_tokens
            .saturating_sub(preference_budget)
            .saturating_sub(knowledge_budget);
        Self {
            max_tokens,
            preference_budget,
            knowledge_budget,
            experience_budget,
            min_similarity: 0.7,
            max_results_per_type: 5,
        }
    }
}
/// Retrieval result containing memories grouped by type.
#[derive(Debug, Clone, Default)]
pub struct RetrievalResult {
    /// Retrieved preferences
    pub preferences: Vec<MemoryEntry>,
    /// Retrieved knowledge
    pub knowledge: Vec<MemoryEntry>,
    /// Retrieved experience
    pub experience: Vec<MemoryEntry>,
    /// Total tokens used, as recorded by the retriever
    /// (`calculate_tokens` recomputes this from the entries themselves).
    pub total_tokens: usize,
}
impl RetrievalResult {
/// Check if result is empty
pub fn is_empty(&self) -> bool {
self.preferences.is_empty()
&& self.knowledge.is_empty()
&& self.experience.is_empty()
}
/// Get total memory count
pub fn total_count(&self) -> usize {
self.preferences.len() + self.knowledge.len() + self.experience.len()
}
/// Calculate total tokens from entries
pub fn calculate_tokens(&self) -> usize {
let tokens: usize = self.preferences.iter()
.chain(self.knowledge.iter())
.chain(self.experience.iter())
.map(|m| m.estimated_tokens())
.sum();
tokens
}
}
/// Extraction configuration.
///
/// Toggles which memory categories the extractor analyzes and the minimum
/// confidence an extraction must reach to be kept.
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Extract preferences from conversation
    pub extract_preferences: bool,
    /// Extract knowledge from conversation
    pub extract_knowledge: bool,
    /// Extract experience from conversation
    pub extract_experience: bool,
    /// Minimum confidence threshold for extraction (0.0 - 1.0)
    pub min_confidence: f32,
}
impl Default for ExtractionConfig {
    fn default() -> Self {
        // Extract every category by default; discard candidates whose
        // confidence falls below 0.6.
        Self {
            extract_preferences: true,
            extract_knowledge: true,
            extract_experience: true,
            min_confidence: 0.6,
        }
    }
}
/// Growth statistics for an agent.
///
/// Aggregated counts over an agent's stored memories; `Default` yields
/// all-zero stats with no last-learning time.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GrowthStats {
    /// Total number of memories
    pub total_memories: usize,
    /// Number of preferences
    pub preference_count: usize,
    /// Number of knowledge entries
    pub knowledge_count: usize,
    /// Number of experience entries
    pub experience_count: usize,
    /// Total sessions processed
    pub sessions_processed: usize,
    /// Last learning timestamp (None when the agent has never learned)
    pub last_learning_time: Option<DateTime<Utc>>,
    /// Average extraction confidence (0.0 by default)
    pub avg_confidence: f32,
}
/// OpenViking URI builder.
///
/// Centralizes the `agent://{agent_id}/{type}/{category}` URI scheme so
/// that construction and parsing stay in sync.
pub struct UriBuilder;

impl UriBuilder {
    /// Build a preference URI: `agent://{agent_id}/preferences/{category}`.
    pub fn preference(agent_id: &str, category: &str) -> String {
        Self::build(agent_id, MemoryType::Preference, category)
    }

    /// Build a knowledge URI: `agent://{agent_id}/knowledge/{domain}`.
    pub fn knowledge(agent_id: &str, domain: &str) -> String {
        Self::build(agent_id, MemoryType::Knowledge, domain)
    }

    /// Build an experience URI: `agent://{agent_id}/experience/{skill_id}`.
    pub fn experience(agent_id: &str, skill_id: &str) -> String {
        Self::build(agent_id, MemoryType::Experience, skill_id)
    }

    /// Build a session URI: `agent://{agent_id}/sessions/{session_id}`.
    pub fn session(agent_id: &str, session_id: &str) -> String {
        Self::build(agent_id, MemoryType::Session, session_id)
    }

    /// Shared builder: the type segment comes from `MemoryType`'s `Display`
    /// impl, which renders the plural path forms.
    fn build(agent_id: &str, memory_type: MemoryType, leaf: &str) -> String {
        format!("agent://{}/{}/{}", agent_id, memory_type, leaf)
    }

    /// Parse the agent ID from a URI; `None` for non-`agent://` URIs.
    pub fn parse_agent_id(uri: &str) -> Option<&str> {
        uri.strip_prefix("agent://")?.split('/').next()
    }

    /// Parse the memory type from a URI; `None` when the URI is not an
    /// `agent://` URI or its type segment is unknown. Matching is strict
    /// (exact plural/singular forms only), unlike `MemoryType::from_str`.
    pub fn parse_memory_type(uri: &str) -> Option<MemoryType> {
        let mut segments = uri.strip_prefix("agent://")?.split('/');
        segments.next()?; // Skip agent_id
        match segments.next()? {
            "preferences" => Some(MemoryType::Preference),
            "knowledge" => Some(MemoryType::Knowledge),
            "experience" => Some(MemoryType::Experience),
            "sessions" => Some(MemoryType::Session),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_type_display() {
        // Display forms are the URI path segments (note the plurals).
        let cases = [
            (MemoryType::Preference, "preferences"),
            (MemoryType::Knowledge, "knowledge"),
            (MemoryType::Experience, "experience"),
            (MemoryType::Session, "sessions"),
        ];
        for (memory_type, expected) in cases {
            assert_eq!(format!("{}", memory_type), expected);
        }
    }

    #[test]
    fn test_memory_entry_creation() {
        let memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "communication-style",
            "User prefers concise responses".to_string(),
        );
        assert_eq!(memory.uri, "agent://test-agent/preferences/communication-style");
        assert_eq!(memory.importance, 5);
        assert_eq!(memory.access_count, 0);
    }

    #[test]
    fn test_memory_entry_touch() {
        let mut memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Knowledge,
            "domain",
            "content".to_string(),
        );
        memory.touch();
        assert_eq!(memory.access_count, 1);
    }

    #[test]
    fn test_estimated_tokens() {
        let memory = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "test",
            "This is a test content that should be around 10 tokens".to_string(),
        );
        // ~55 ASCII chars at ~0.25 tokens/char lands inside the 5..20 window.
        let tokens = memory.estimated_tokens();
        assert!(tokens > 5);
        assert!(tokens < 20);
    }

    #[test]
    fn test_retrieval_config_default() {
        let cfg = RetrievalConfig::default();
        assert_eq!(cfg.max_tokens, 500);
        assert_eq!(cfg.preference_budget, 200);
        assert_eq!(cfg.knowledge_budget, 200);
        assert_eq!(cfg.experience_budget, 100);
    }

    #[test]
    fn test_retrieval_config_with_budget() {
        let cfg = RetrievalConfig::with_budget(1000);
        assert_eq!(cfg.max_tokens, 1000);
        assert!(cfg.preference_budget >= 350);
        assert!(cfg.knowledge_budget >= 350);
    }

    #[test]
    fn test_uri_builder() {
        assert_eq!(
            UriBuilder::preference("agent-1", "style"),
            "agent://agent-1/preferences/style"
        );
        assert_eq!(
            UriBuilder::knowledge("agent-1", "rust"),
            "agent://agent-1/knowledge/rust"
        );
        assert_eq!(
            UriBuilder::experience("agent-1", "browser"),
            "agent://agent-1/experience/browser"
        );
        assert_eq!(
            UriBuilder::session("agent-1", "session-123"),
            "agent://agent-1/sessions/session-123"
        );
    }

    #[test]
    fn test_uri_parser() {
        let uri = "agent://agent-1/preferences/style";
        assert_eq!(UriBuilder::parse_agent_id(uri), Some("agent-1"));
        assert_eq!(UriBuilder::parse_memory_type(uri), Some(MemoryType::Preference));
        // Anything without the agent:// scheme parses to None.
        let bogus = "invalid-uri";
        assert!(UriBuilder::parse_agent_id(bogus).is_none());
        assert!(UriBuilder::parse_memory_type(bogus).is_none());
    }

    #[test]
    fn test_retrieval_result() {
        let empty = RetrievalResult::default();
        assert!(empty.is_empty());
        assert_eq!(empty.total_count(), 0);
        let with_one = RetrievalResult {
            preferences: vec![MemoryEntry::new(
                "agent-1",
                MemoryType::Preference,
                "style",
                "test".to_string(),
            )],
            knowledge: vec![],
            experience: vec![],
            total_tokens: 0,
        };
        assert!(!with_one.is_empty());
        assert_eq!(with_one.total_count(), 1);
    }
}

View File

@@ -0,0 +1,362 @@
//! OpenViking Adapter - Interface to the OpenViking memory system
//!
//! This module provides the `VikingAdapter` which wraps the OpenViking
//! context database for storing and retrieving agent memories.
use crate::types::MemoryEntry;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use zclaw_types::Result;
/// Search options for find operations.
#[derive(Debug, Clone, Default)]
pub struct FindOptions {
    /// Scope to search within (URI prefix, e.g. `agent://{agent_id}`);
    /// `None` searches everything
    pub scope: Option<String>,
    /// Maximum results to return (`None` = unlimited)
    pub limit: Option<usize>,
    /// Minimum similarity threshold (`None` = no similarity filtering)
    pub min_similarity: Option<f32>,
}
/// VikingStorage trait - core storage operations (dyn-compatible).
///
/// Implemented by storage backends (in-memory, SQLite). Kept object-safe so
/// `VikingAdapter` can hold an `Arc<dyn VikingStorage>`; metadata is passed
/// as pre-serialized JSON strings for the same reason.
#[async_trait]
pub trait VikingStorage: Send + Sync {
    /// Store a memory entry (keyed by its URI)
    async fn store(&self, entry: &MemoryEntry) -> Result<()>;
    /// Get a memory entry by URI
    async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>>;
    /// Find memories matching a free-text query, subject to `options`
    async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>>;
    /// Find memories by URI prefix
    async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>>;
    /// Delete a memory by URI
    async fn delete(&self, uri: &str) -> Result<()>;
    /// Store metadata as a JSON string under `key`
    async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()>;
    /// Get metadata as a JSON string; `None` when the key is absent
    async fn get_metadata_json(&self, key: &str) -> Result<Option<String>>;
}
/// OpenViking adapter implementation
///
/// Thin facade over a [`VikingStorage`] backend: every storage method
/// delegates 1:1 to the backend, plus typed (de)serializing wrappers for
/// the JSON metadata API. Cloning the adapter clones only the `Arc`.
#[derive(Clone)]
pub struct VikingAdapter {
    /// Storage backend
    backend: Arc<dyn VikingStorage>,
}
impl VikingAdapter {
    /// Create a new Viking adapter with a storage backend
    pub fn new(backend: Arc<dyn VikingStorage>) -> Self {
        Self { backend }
    }
    /// Create with in-memory storage (for testing)
    pub fn in_memory() -> Self {
        Self {
            backend: Arc::new(InMemoryStorage::new()),
        }
    }
    /// Store a memory entry
    pub async fn store(&self, entry: &MemoryEntry) -> Result<()> {
        self.backend.store(entry).await
    }
    /// Get a memory entry by URI
    pub async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
        self.backend.get(uri).await
    }
    /// Find memories by query
    pub async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
        self.backend.find(query, options).await
    }
    /// Find memories by URI prefix
    pub async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
        self.backend.find_by_prefix(prefix).await
    }
    /// Delete a memory
    pub async fn delete(&self, uri: &str) -> Result<()> {
        self.backend.delete(uri).await
    }
    /// Store metadata (typed)
    ///
    /// Serializes `value` to JSON first; serialization failures propagate
    /// via `?` as crate errors.
    pub async fn store_metadata<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
        let json = serde_json::to_string(value)?;
        self.backend.store_metadata_json(key, &json).await
    }
    /// Get metadata (typed)
    ///
    /// Returns `Ok(None)` when the key is absent; a present-but-unparseable
    /// value is an error, not `None`.
    pub async fn get_metadata<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
        match self.backend.get_metadata_json(key).await? {
            Some(json) => {
                let value: T = serde_json::from_str(&json)?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }
}
/// In-memory storage backend (for testing and development)
///
/// Two independent maps guarded by `std::sync::RwLock`:
/// `memories` is keyed by entry URI, `metadata` by caller-chosen string key.
/// Lock poisoning (a panic while holding a guard) will panic on later
/// access, since the trait impl unwraps the guards.
pub struct InMemoryStorage {
    // entry.uri -> entry
    memories: std::sync::RwLock<HashMap<String, MemoryEntry>>,
    // metadata key -> raw JSON string
    metadata: std::sync::RwLock<HashMap<String, String>>,
}
impl InMemoryStorage {
    /// Create a new in-memory storage
    pub fn new() -> Self {
        Self {
            memories: std::sync::RwLock::new(HashMap::new()),
            metadata: std::sync::RwLock::new(HashMap::new()),
        }
    }
}
impl Default for InMemoryStorage {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl VikingStorage for InMemoryStorage {
    /// Insert or overwrite the entry, keyed by its URI.
    async fn store(&self, entry: &MemoryEntry) -> Result<()> {
        let mut memories = self.memories.write().unwrap();
        memories.insert(entry.uri.clone(), entry.clone());
        Ok(())
    }
    /// Look up a single entry by exact URI.
    async fn get(&self, uri: &str) -> Result<Option<MemoryEntry>> {
        let memories = self.memories.read().unwrap();
        Ok(memories.get(uri).cloned())
    }
    /// Filter entries by scope prefix and case-insensitive substring match
    /// on content/keywords, rank by importance then access count, and
    /// truncate to `options.limit`.
    ///
    /// NOTE: `options.min_similarity` is ignored — this backend performs
    /// plain text matching and has no similarity score to threshold on.
    async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
        let memories = self.memories.read().unwrap();
        // Hoist the query normalization out of the per-entry closure:
        // it is loop-invariant and was previously recomputed for every entry.
        let query_lower = query.to_lowercase();
        let mut results: Vec<MemoryEntry> = memories
            .values()
            .filter(|entry| {
                // Apply scope filter (URI prefix)
                if let Some(ref scope) = options.scope {
                    if !entry.uri.starts_with(scope) {
                        return false;
                    }
                }
                // Simple text matching (in real implementation, use semantic search)
                if query_lower.is_empty() {
                    true
                } else {
                    entry.content.to_lowercase().contains(&query_lower)
                        || entry
                            .keywords
                            .iter()
                            .any(|k| k.to_lowercase().contains(&query_lower))
                }
            })
            .cloned()
            .collect();
        // Sort by importance and access count
        results.sort_by(|a, b| {
            b.importance
                .cmp(&a.importance)
                .then_with(|| b.access_count.cmp(&a.access_count))
        });
        // Apply limit
        if let Some(limit) = options.limit {
            results.truncate(limit);
        }
        Ok(results)
    }
    /// Return every entry whose URI starts with `prefix`.
    /// Order is unspecified (HashMap iteration order).
    async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
        let memories = self.memories.read().unwrap();
        let results: Vec<MemoryEntry> = memories
            .values()
            .filter(|entry| entry.uri.starts_with(prefix))
            .cloned()
            .collect();
        Ok(results)
    }
    /// Remove the entry for `uri`; a missing URI is a no-op.
    async fn delete(&self, uri: &str) -> Result<()> {
        let mut memories = self.memories.write().unwrap();
        memories.remove(uri);
        Ok(())
    }
    /// Store raw JSON metadata under `key`, replacing any previous value.
    async fn store_metadata_json(&self, key: &str, json: &str) -> Result<()> {
        let mut metadata = self.metadata.write().unwrap();
        metadata.insert(key.to_string(), json.to_string());
        Ok(())
    }
    /// Fetch raw JSON metadata; `Ok(None)` when the key is absent.
    async fn get_metadata_json(&self, key: &str) -> Result<Option<String>> {
        let metadata = self.metadata.read().unwrap();
        Ok(metadata.get(key).cloned())
    }
}
/// OpenViking levels for storage
///
/// Progressive abstraction levels of a stored memory: raw data,
/// summary, then keywords/metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VikingLevel {
    /// L0: Raw data (original content)
    L0,
    /// L1: Summarized content
    L1,
    /// L2: Keywords and metadata
    L2,
}
impl std::fmt::Display for VikingLevel {
    /// Render the level as its short label ("L0" / "L1" / "L2").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map the variant to its label first, then emit it in one call.
        let label = match self {
            VikingLevel::L0 => "L0",
            VikingLevel::L1 => "L1",
            VikingLevel::L2 => "L2",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;
    // Round-trip: a stored entry must come back unchanged by URI.
    #[tokio::test]
    async fn test_in_memory_storage_store_and_get() {
        let storage = InMemoryStorage::new();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test content".to_string(),
        );
        storage.store(&entry).await.unwrap();
        let retrieved = storage.get(&entry.uri).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "test content");
    }
    // Text search must match only entries whose content/keywords contain the query.
    #[tokio::test]
    async fn test_in_memory_storage_find() {
        let storage = InMemoryStorage::new();
        let entry1 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust",
            "Rust programming tips".to_string(),
        );
        let entry2 = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python",
            "Python programming tips".to_string(),
        );
        storage.store(&entry1).await.unwrap();
        storage.store(&entry2).await.unwrap();
        let results = storage
            .find(
                "Rust",
                FindOptions {
                    scope: Some("agent://agent-1".to_string()),
                    limit: Some(10),
                    min_similarity: None,
                },
            )
            .await
            .unwrap();
        // Only the Rust entry matches the query.
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust"));
    }
    // Deleting an entry removes it from subsequent gets.
    #[tokio::test]
    async fn test_in_memory_storage_delete() {
        let storage = InMemoryStorage::new();
        let entry = MemoryEntry::new(
            "test-agent",
            MemoryType::Preference,
            "style",
            "test".to_string(),
        );
        storage.store(&entry).await.unwrap();
        storage.delete(&entry.uri).await.unwrap();
        let retrieved = storage.get(&entry.uri).await.unwrap();
        assert!(retrieved.is_none());
    }
    // Raw JSON metadata round-trips through the string API.
    #[tokio::test]
    async fn test_metadata_storage() {
        let storage = InMemoryStorage::new();
        #[derive(Serialize, serde::Deserialize)]
        struct TestData {
            value: String,
        }
        let data = TestData {
            value: "test".to_string(),
        };
        storage.store_metadata_json("test-key", &serde_json::to_string(&data).unwrap()).await.unwrap();
        let json = storage.get_metadata_json("test-key").await.unwrap();
        assert!(json.is_some());
        let retrieved: TestData = serde_json::from_str(&json.unwrap()).unwrap();
        assert_eq!(retrieved.value, "test");
    }
    // Typed metadata round-trips through the adapter's serde wrappers.
    #[tokio::test]
    async fn test_viking_adapter_typed_metadata() {
        let adapter = VikingAdapter::in_memory();
        #[derive(Serialize, serde::Deserialize)]
        struct TestData {
            value: String,
        }
        let data = TestData {
            value: "test".to_string(),
        };
        adapter.store_metadata("test-key", &data).await.unwrap();
        let retrieved: Option<TestData> = adapter.get_metadata("test-key").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().value, "test");
    }
    // Display impl emits the short level labels.
    #[test]
    fn test_viking_level_display() {
        assert_eq!(format!("{}", VikingLevel::L0), "L0");
        assert_eq!(format!("{}", VikingLevel::L1), "L1");
        assert_eq!(format!("{}", VikingLevel::L2), "L2");
    }
}

View File

@@ -0,0 +1,412 @@
//! Integration tests for ZCLAW Growth System
//!
//! Tests the complete flow: store → find → inject
use std::sync::Arc;
use zclaw_growth::{
FindOptions, MemoryEntry, MemoryRetriever, MemoryType, PromptInjector,
RetrievalConfig, RetrievalResult, SqliteStorage, VikingAdapter,
};
use zclaw_types::AgentId;
/// Test complete memory lifecycle
///
/// Exercises the full flow: store (preference / knowledge / experience)
/// → retrieve via adapter and retriever → inject into a prompt.
#[tokio::test]
async fn test_memory_lifecycle() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    // Create agent ID and use its string form for storage
    let agent_id = AgentId::new();
    let agent_str = agent_id.to_string();
    // 1. Store a preference
    let pref = MemoryEntry::new(
        &agent_str,
        MemoryType::Preference,
        "communication-style",
        "用户偏好简洁的回复,不喜欢冗长的解释".to_string(),
    )
    .with_keywords(vec!["简洁".to_string(), "沟通风格".to_string()])
    .with_importance(8);
    adapter.store(&pref).await.unwrap();
    // 2. Store knowledge
    let knowledge = MemoryEntry::new(
        &agent_str,
        MemoryType::Knowledge,
        "rust-expertise",
        "用户是 Rust 开发者,熟悉 async/await 和 trait 系统".to_string(),
    )
    .with_keywords(vec!["Rust".to_string(), "开发者".to_string()]);
    adapter.store(&knowledge).await.unwrap();
    // 3. Store experience
    let experience = MemoryEntry::new(
        &agent_str,
        MemoryType::Experience,
        "browser-skill",
        "浏览器技能在搜索技术文档时效果很好".to_string(),
    )
    .with_keywords(vec!["浏览器".to_string(), "技能".to_string()]);
    adapter.store(&experience).await.unwrap();
    // 4. Retrieve memories - directly from adapter first
    // (sanity check against the raw backend before going through the retriever)
    let direct_results = adapter
        .find(
            "Rust",
            FindOptions {
                scope: Some(format!("agent://{}", agent_str)),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();
    println!("Direct find results: {:?}", direct_results.len());
    let retriever = MemoryRetriever::new(adapter.clone());
    // Use lower similarity threshold for testing
    let config = RetrievalConfig {
        min_similarity: 0.1,
        ..RetrievalConfig::default()
    };
    let retriever = retriever.with_config(config);
    let result = retriever
        .retrieve(&agent_id, "Rust 编程")
        .await
        .unwrap();
    println!("Knowledge results: {:?}", result.knowledge.len());
    println!("Preferences results: {:?}", result.preferences.len());
    println!("Experience results: {:?}", result.experience.len());
    // Should find the knowledge entry
    assert!(!result.knowledge.is_empty(), "Expected to find knowledge entries but found none. Direct results: {}", direct_results.len());
    assert!(result.knowledge[0].content.contains("Rust"));
    // 5. Inject into prompt
    let injector = PromptInjector::new();
    let base_prompt = "你是一个有帮助的 AI 助手。";
    let enhanced = injector.inject_with_format(base_prompt, &result);
    // Enhanced prompt should contain memory context
    assert!(enhanced.len() > base_prompt.len());
}
/// Test semantic search ranking
///
/// Stores three knowledge entries of varying relevance and checks that
/// the most relevant one ranks first for a combined-topic query.
#[tokio::test]
async fn test_semantic_search_ranking() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage.clone()));
    // Store multiple entries with different relevance
    let entries = vec![
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust-basics",
            "Rust 是一门系统编程语言,注重安全性和性能".to_string(),
        )
        .with_keywords(vec!["Rust".to_string(), "系统编程".to_string()]),
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "python-basics",
            "Python 是一门高级编程语言,易于学习".to_string(),
        )
        .with_keywords(vec!["Python".to_string(), "高级语言".to_string()]),
        MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            "rust-async",
            "Rust 的 async/await 语法用于异步编程".to_string(),
        )
        .with_keywords(vec!["Rust".to_string(), "async".to_string(), "异步".to_string()]),
    ];
    for entry in &entries {
        adapter.store(entry).await.unwrap();
    }
    // Search for "Rust 异步编程"
    let results = adapter
        .find(
            "Rust 异步编程",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();
    // Rust async entry should rank highest
    // (loose assertion: top hit must at least be a Rust-related entry)
    assert!(!results.is_empty());
    assert!(results[0].content.contains("async") || results[0].content.contains("Rust"));
}
/// Test memory importance and access count
///
/// Stores a high- and a low-importance preference, touches the low one
/// repeatedly, and verifies both still come back from a shared query.
#[tokio::test]
async fn test_importance_and_access() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage.clone()));
    // Create entries with different importance
    let high_importance = MemoryEntry::new(
        "agent-1",
        MemoryType::Preference,
        "critical",
        "这是非常重要的偏好".to_string(),
    )
    .with_importance(10);
    let low_importance = MemoryEntry::new(
        "agent-1",
        MemoryType::Preference,
        "minor",
        "这是不太重要的偏好".to_string(),
    )
    .with_importance(2);
    adapter.store(&high_importance).await.unwrap();
    adapter.store(&low_importance).await.unwrap();
    // Access the low importance one multiple times
    for _ in 0..5 {
        let _ = adapter.get(&low_importance.uri).await;
    }
    // Search should consider both importance and access count
    // (only the result count is asserted; ordering is backend-defined)
    let results = adapter
        .find(
            "偏好",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: None,
            },
        )
        .await
        .unwrap();
    assert_eq!(results.len(), 2);
}
/// Test prompt injection with token budget
///
/// Builds a retrieval result far larger than the default budget and
/// checks the injector still produces an enhanced prompt.
#[tokio::test]
async fn test_prompt_injection_token_budget() {
    let mut result = RetrievalResult::default();
    // Add memories that exceed budget
    for i in 0..10 {
        result.preferences.push(
            MemoryEntry::new(
                "agent-1",
                MemoryType::Preference,
                &format!("pref-{}", i),
                "这是一个很长的偏好描述,用于测试 token 预算控制功能。".repeat(5),
            ),
        );
    }
    result.total_tokens = result.calculate_tokens();
    // Budget is 500 tokens by default
    let injector = PromptInjector::new();
    let base = "Base prompt";
    let enhanced = injector.inject_with_format(base, &result);
    // Should include memory context
    assert!(enhanced.len() > base.len());
}
/// Test metadata storage
///
/// Round-trips a typed config struct through the adapter's serde-backed
/// metadata API and verifies every field survives.
#[tokio::test]
async fn test_metadata_operations() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    // Store metadata using typed API
    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct Config {
        version: String,
        auto_extract: bool,
    }
    let config = Config {
        version: "1.0.0".to_string(),
        auto_extract: true,
    };
    adapter.store_metadata("agent-config", &config).await.unwrap();
    // Retrieve metadata
    let retrieved: Option<Config> = adapter.get_metadata("agent-config").await.unwrap();
    assert!(retrieved.is_some());
    let parsed = retrieved.unwrap();
    assert_eq!(parsed.version, "1.0.0");
    // Idiomatic boolean assertion (was `assert_eq!(parsed.auto_extract, true)`;
    // clippy::bool_assert_comparison).
    assert!(parsed.auto_extract);
}
/// Test memory deletion and cleanup
///
/// Verifies a deleted entry disappears from both direct gets and search.
#[tokio::test]
async fn test_memory_deletion() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let entry = MemoryEntry::new(
        "agent-1",
        MemoryType::Knowledge,
        "temp",
        "Temporary knowledge".to_string(),
    );
    adapter.store(&entry).await.unwrap();
    // Verify stored
    let retrieved = adapter.get(&entry.uri).await.unwrap();
    assert!(retrieved.is_some());
    // Delete
    adapter.delete(&entry.uri).await.unwrap();
    // Verify deleted
    let retrieved = adapter.get(&entry.uri).await.unwrap();
    assert!(retrieved.is_none());
    // Verify not in search results
    let results = adapter
        .find(
            "Temporary",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: None,
            },
        )
        .await
        .unwrap();
    assert!(results.is_empty());
}
/// Test cross-agent isolation
///
/// Two agents store same-named memories; a scoped find for each agent
/// must return only that agent's entry.
#[tokio::test]
async fn test_agent_isolation() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    // Store memories for different agents
    let agent1_memory = MemoryEntry::new(
        "agent-1",
        MemoryType::Knowledge,
        "secret",
        "Agent 1 的秘密信息".to_string(),
    );
    let agent2_memory = MemoryEntry::new(
        "agent-2",
        MemoryType::Knowledge,
        "secret",
        "Agent 2 的秘密信息".to_string(),
    );
    adapter.store(&agent1_memory).await.unwrap();
    adapter.store(&agent2_memory).await.unwrap();
    // Agent 1 should only see its own memories
    let results = adapter
        .find(
            "秘密",
            FindOptions {
                scope: Some("agent://agent-1".to_string()),
                limit: Some(10),
                min_similarity: None,
            },
        )
        .await
        .unwrap();
    assert_eq!(results.len(), 1);
    assert!(results[0].content.contains("Agent 1"));
    // Agent 2 should only see its own memories
    let results = adapter
        .find(
            "秘密",
            FindOptions {
                scope: Some("agent://agent-2".to_string()),
                limit: Some(10),
                min_similarity: None,
            },
        )
        .await
        .unwrap();
    assert_eq!(results.len(), 1);
    assert!(results[0].content.contains("Agent 2"));
}
/// Test Chinese text handling
///
/// Verifies CJK content and keywords survive storage and match a
/// Chinese-language query (agent id itself is also non-ASCII).
#[tokio::test]
async fn test_chinese_text_handling() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let entry = MemoryEntry::new(
        "中文测试",
        MemoryType::Knowledge,
        "中文知识",
        "这是一个中文测试,包含关键词:人工智能、机器学习、深度学习。".to_string(),
    )
    .with_keywords(vec!["人工智能".to_string(), "机器学习".to_string()]);
    adapter.store(&entry).await.unwrap();
    // Search with Chinese query
    let results = adapter
        .find(
            "人工智能",
            FindOptions {
                scope: Some("agent://中文测试".to_string()),
                limit: Some(10),
                min_similarity: Some(0.1),
            },
        )
        .await
        .unwrap();
    assert!(!results.is_empty());
    assert!(results[0].content.contains("人工智能"));
}
/// Test find by prefix
///
/// A prefix scan over the agent URI should return every entry stored
/// under that agent, regardless of sub-path.
#[tokio::test]
async fn test_find_by_prefix() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    // Populate five knowledge entries for the same agent.
    for idx in 0..5 {
        let memory = MemoryEntry::new(
            "agent-1",
            MemoryType::Knowledge,
            &format!("topic-{}", idx),
            format!("Content for topic {}", idx),
        );
        adapter.store(&memory).await.unwrap();
    }
    // Every stored entry shares the agent prefix, so all five come back.
    let results = adapter.find_by_prefix("agent://agent-1").await.unwrap();
    assert_eq!(results.len(), 5);
}

View File

@@ -375,6 +375,11 @@ impl Kernel {
&self.config
}
/// Get the LLM driver
pub fn driver(&self) -> Arc<dyn LlmDriver> {
self.driver.clone()
}
/// Get the skills registry
pub fn skills(&self) -> &Arc<SkillRegistry> {
&self.skills

View File

@@ -134,6 +134,12 @@ impl ActionRegistry {
max_tokens: Option<u32>,
json_mode: bool,
) -> Result<Value, ActionError> {
println!("[DEBUG execute_llm] Called with template length: {}", template.len());
println!("[DEBUG execute_llm] Input HashMap contents:");
for (k, v) in &input {
println!(" {} => {:?}", k, v);
}
if let Some(driver) = &self.llm_driver {
// Load template if it's a file path
let prompt = if template.ends_with(".md") || template.contains('/') {
@@ -142,6 +148,8 @@ impl ActionRegistry {
template.to_string()
};
println!("[DEBUG execute_llm] Calling driver.generate with prompt length: {}", prompt.len());
driver.generate(prompt, input, model, temperature, max_tokens, json_mode)
.await
.map_err(ActionError::Llm)

View File

@@ -0,0 +1,547 @@
//! Pipeline v2 Execution Context
//!
//! Enhanced context for v2 pipeline execution with:
//! - Parameter storage
//! - Stage outputs accumulation
//! - Loop context for parallel execution
//! - Variable storage
//! - Expression evaluation
use std::collections::HashMap;
use serde_json::Value;
use regex::Regex;
/// Execution context for Pipeline v2
///
/// Holds the three value namespaces (`params`, `stages`, `vars`), the
/// optional loop context for parallel/each iteration, and a precompiled
/// regex used for `${...}` interpolation.
#[derive(Debug, Clone)]
pub struct ExecutionContextV2 {
    /// Pipeline input parameters (from user)
    params: HashMap<String, Value>,
    /// Stage outputs (stage_id -> output)
    stages: HashMap<String, Value>,
    /// Custom variables (set by set_var)
    vars: HashMap<String, Value>,
    /// Loop context for parallel execution (None outside a loop)
    loop_context: Option<LoopContext>,
    /// Expression regex for variable interpolation (matches `${path}`)
    expr_regex: Regex,
}
/// Loop context for parallel/each iterations
///
/// Forms a stack via `parent`, so nested loops restore the outer
/// iteration state when the inner one is cleared.
#[derive(Debug, Clone)]
pub struct LoopContext {
    /// Current item
    pub item: Value,
    /// Current index (0-based)
    pub index: usize,
    /// Total items count
    pub total: usize,
    /// Parent loop context (for nested loops)
    pub parent: Option<Box<LoopContext>>,
}
impl ExecutionContextV2 {
    /// Create a new execution context with parameters
    pub fn new(params: HashMap<String, Value>) -> Self {
        Self {
            params,
            stages: HashMap::new(),
            vars: HashMap::new(),
            loop_context: None,
            expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
        }
    }
    /// Create from JSON value
    ///
    /// A non-object value yields an empty parameter map.
    pub fn from_value(params: Value) -> Self {
        let params_map = if let Value::Object(obj) = params {
            obj.into_iter().collect()
        } else {
            HashMap::new()
        };
        Self::new(params_map)
    }
    // === Parameter Access ===
    /// Get a parameter value
    pub fn get_param(&self, name: &str) -> Option<&Value> {
        self.params.get(name)
    }
    /// Get all parameters
    pub fn params(&self) -> &HashMap<String, Value> {
        &self.params
    }
    // === Stage Output ===
    /// Set a stage output (replaces any previous output for the id)
    pub fn set_stage_output(&mut self, stage_id: &str, value: Value) {
        self.stages.insert(stage_id.to_string(), value);
    }
    /// Get a stage output
    pub fn get_stage_output(&self, stage_id: &str) -> Option<&Value> {
        self.stages.get(stage_id)
    }
    /// Get all stage outputs
    pub fn all_stages(&self) -> &HashMap<String, Value> {
        &self.stages
    }
    // === Variables ===
    /// Set a variable
    pub fn set_var(&mut self, name: &str, value: Value) {
        self.vars.insert(name.to_string(), value);
    }
    /// Get a variable
    pub fn get_var(&self, name: &str) -> Option<&Value> {
        self.vars.get(name)
    }
    // === Loop Context ===
    /// Set loop context, pushing any existing context as the parent
    pub fn set_loop_context(&mut self, item: Value, index: usize, total: usize) {
        self.loop_context = Some(LoopContext {
            item,
            index,
            total,
            parent: self.loop_context.take().map(Box::new),
        });
    }
    /// Clear current loop context, restoring the parent (if any)
    pub fn clear_loop_context(&mut self) {
        if let Some(ctx) = self.loop_context.take() {
            self.loop_context = ctx.parent.map(|b| *b);
        }
    }
    /// Get current loop item
    pub fn loop_item(&self) -> Option<&Value> {
        self.loop_context.as_ref().map(|c| &c.item)
    }
    /// Get current loop index
    pub fn loop_index(&self) -> Option<usize> {
        self.loop_context.as_ref().map(|c| c.index)
    }
    // === Expression Evaluation ===
    /// Resolve an expression to a value
    ///
    /// Supported expressions:
    /// - `${params.topic}` - Parameter
    /// - `${stages.outline}` - Stage output
    /// - `${stages.outline.sections}` - Nested access
    /// - `${item}` - Current loop item
    /// - `${index}` - Current loop index
    /// - `${vars.customVar}` - Variable
    /// - `'literal'` or `"literal"` - Quoted string literal
    ///
    /// A string that is exactly one `${...}` returns the underlying value;
    /// a mixed string interpolates each expression and returns a string.
    pub fn resolve(&self, expr: &str) -> Result<Value, ContextError> {
        // Handle quoted string literals
        let trimmed = expr.trim();
        if (trimmed.starts_with('\'') && trimmed.ends_with('\'')) ||
           (trimmed.starts_with('"') && trimmed.ends_with('"')) {
            let inner = &trimmed[1..trimmed.len()-1];
            return Ok(Value::String(inner.to_string()));
        }
        // If not an expression, return as string
        if !expr.contains("${") {
            return Ok(Value::String(expr.to_string()));
        }
        // If entire string is a single expression, return the actual value
        if expr.starts_with("${") && expr.ends_with('}') && expr.matches("${").count() == 1 {
            let path = &expr[2..expr.len()-1];
            return self.resolve_path(path);
        }
        // Replace all expressions in string; unresolvable ones are left verbatim
        let result = self.expr_regex.replace_all(expr, |caps: &regex::Captures| {
            let path = &caps[1];
            match self.resolve_path(path) {
                Ok(value) => value_to_string(&value),
                Err(_) => caps[0].to_string(),
            }
        });
        Ok(Value::String(result.to_string()))
    }
    /// Resolve a path like "params.topic" or "stages.outline.sections.0"
    fn resolve_path(&self, path: &str) -> Result<Value, ContextError> {
        let parts: Vec<&str> = path.split('.').collect();
        if parts.is_empty() {
            return Err(ContextError::InvalidPath(path.to_string()));
        }
        let first = parts[0];
        let rest = &parts[1..];
        match first {
            "params" => self.resolve_from_map(&self.params, rest, path),
            "stages" => self.resolve_from_map(&self.stages, rest, path),
            "vars" | "var" => self.resolve_from_map(&self.vars, rest, path),
            "item" => {
                if let Some(ctx) = &self.loop_context {
                    if rest.is_empty() {
                        Ok(ctx.item.clone())
                    } else {
                        self.resolve_from_value(&ctx.item, rest, path)
                    }
                } else {
                    Err(ContextError::VariableNotFound("item".to_string()))
                }
            }
            "index" => {
                if let Some(ctx) = &self.loop_context {
                    Ok(Value::Number(ctx.index.into()))
                } else {
                    Err(ContextError::VariableNotFound("index".to_string()))
                }
            }
            "total" => {
                if let Some(ctx) = &self.loop_context {
                    Ok(Value::Number(ctx.total.into()))
                } else {
                    Err(ContextError::VariableNotFound("total".to_string()))
                }
            }
            _ => Err(ContextError::InvalidPath(path.to_string())),
        }
    }
    /// Resolve from a map (first path segment selects the map key)
    fn resolve_from_map(
        &self,
        map: &HashMap<String, Value>,
        path_parts: &[&str],
        full_path: &str,
    ) -> Result<Value, ContextError> {
        if path_parts.is_empty() {
            return Err(ContextError::InvalidPath(full_path.to_string()));
        }
        let key = path_parts[0];
        let value = map.get(key)
            .ok_or_else(|| ContextError::VariableNotFound(key.to_string()))?;
        if path_parts.len() == 1 {
            Ok(value.clone())
        } else {
            self.resolve_from_value(value, &path_parts[1..], full_path)
        }
    }
    /// Resolve from a value (nested access: object keys and array indices)
    fn resolve_from_value(
        &self,
        value: &Value,
        path_parts: &[&str],
        full_path: &str,
    ) -> Result<Value, ContextError> {
        let mut current = value;
        for part in path_parts {
            current = match current {
                Value::Object(map) => map.get(*part)
                    .ok_or_else(|| ContextError::FieldNotFound(part.to_string()))?,
                Value::Array(arr) => {
                    if let Ok(idx) = part.parse::<usize>() {
                        arr.get(idx)
                            .ok_or_else(|| ContextError::IndexOutOfBounds(idx))?
                    } else {
                        return Err(ContextError::InvalidPath(full_path.to_string()));
                    }
                }
                _ => return Err(ContextError::InvalidPath(full_path.to_string())),
            };
        }
        Ok(current.clone())
    }
    /// Resolve expression and expect array result
    ///
    /// Also accepts a string that serializes a JSON array (starts with `[`).
    pub fn resolve_array(&self, expr: &str) -> Result<Vec<Value>, ContextError> {
        let value = self.resolve(expr)?;
        match value {
            Value::Array(arr) => Ok(arr),
            Value::String(s) if s.starts_with('[') => {
                serde_json::from_str(&s)
                    .map_err(|e| ContextError::TypeError(format!("Expected array: {}", e)))
            }
            _ => Err(ContextError::TypeError("Expected array".to_string())),
        }
    }
    /// Resolve expression and expect string result
    pub fn resolve_string(&self, expr: &str) -> Result<String, ContextError> {
        let value = self.resolve(expr)?;
        Ok(value_to_string(&value))
    }
    /// Evaluate a condition expression
    ///
    /// Supports:
    /// - Equality: `${params.level} == 'advanced'`
    /// - Inequality: `${params.level} != 'beginner'`
    /// - Comparison: `${params.count} > 5` (also `>=`, `<`, `<=`)
    /// - Contains: `'python' in ${params.tags}`
    /// - Boolean: `${params.enabled}`
    pub fn evaluate_condition(&self, condition: &str) -> Result<bool, ContextError> {
        let condition = condition.trim();
        // Handle equality
        if let Some(eq_pos) = condition.find("==") {
            let left = condition[..eq_pos].trim();
            let right = condition[eq_pos + 2..].trim();
            return self.compare_equal(left, right);
        }
        // Handle inequality
        if let Some(ne_pos) = condition.find("!=") {
            let left = condition[..ne_pos].trim();
            let right = condition[ne_pos + 2..].trim();
            return Ok(!self.compare_equal(left, right)?);
        }
        // Handle greater-or-equal. Must be matched before bare '>':
        // otherwise "a >= b" splits at '>' and leaves "= b" as the operand,
        // which previously produced a TypeError.
        if let Some(ge_pos) = condition.find(">=") {
            let left = condition[..ge_pos].trim();
            let right = condition[ge_pos + 2..].trim();
            return self.compare_numeric(left, right, |l, r| l >= r);
        }
        // Handle less-or-equal (before bare '<', same reason)
        if let Some(le_pos) = condition.find("<=") {
            let left = condition[..le_pos].trim();
            let right = condition[le_pos + 2..].trim();
            return self.compare_numeric(left, right, |l, r| l <= r);
        }
        // Handle greater than
        if let Some(gt_pos) = condition.find('>') {
            let left = condition[..gt_pos].trim();
            let right = condition[gt_pos + 1..].trim();
            return self.compare_numeric(left, right, |l, r| l > r);
        }
        // Handle less than
        if let Some(lt_pos) = condition.find('<') {
            let left = condition[..lt_pos].trim();
            let right = condition[lt_pos + 1..].trim();
            return self.compare_numeric(left, right, |l, r| l < r);
        }
        // Handle 'in' operator
        if let Some(in_pos) = condition.find(" in ") {
            let needle = condition[..in_pos].trim();
            let haystack = condition[in_pos + 4..].trim();
            return self.check_contains(haystack, needle);
        }
        // Simple boolean evaluation (JS-like truthiness)
        let value = self.resolve(condition)?;
        match value {
            Value::Bool(b) => Ok(b),
            Value::String(s) => Ok(!s.is_empty() && s != "false" && s != "0"),
            Value::Number(n) => Ok(n.as_f64().map(|f| f != 0.0).unwrap_or(false)),
            Value::Null => Ok(false),
            Value::Array(arr) => Ok(!arr.is_empty()),
            Value::Object(obj) => Ok(!obj.is_empty()),
        }
    }
    /// Resolve both sides and compare for structural equality.
    fn compare_equal(&self, left: &str, right: &str) -> Result<bool, ContextError> {
        let left_val = self.resolve(left)?;
        let right_val = self.resolve(right)?;
        Ok(left_val == right_val)
    }
    /// Resolve both sides, coerce each to f64, and apply `op`.
    ///
    /// Shared by all ordered comparisons (`>`, `>=`, `<`, `<=`); errors
    /// when either side is not numeric (or numeric-looking string).
    fn compare_numeric(
        &self,
        left: &str,
        right: &str,
        op: impl Fn(f64, f64) -> bool,
    ) -> Result<bool, ContextError> {
        let left_val = self.resolve(left)?;
        let right_val = self.resolve(right)?;
        match (value_to_f64(&left_val), value_to_f64(&right_val)) {
            (Some(l), Some(r)) => Ok(op(l, r)),
            _ => Err(ContextError::TypeError("Cannot compare non-numeric values".to_string())),
        }
    }
    /// `needle in haystack`: array membership, substring, or object key.
    fn check_contains(&self, haystack: &str, needle: &str) -> Result<bool, ContextError> {
        let haystack_val = self.resolve(haystack)?;
        let needle_val = self.resolve(needle)?;
        let needle_str = value_to_string(&needle_val);
        match haystack_val {
            Value::Array(arr) => Ok(arr.iter().any(|v| value_to_string(v) == needle_str)),
            Value::String(s) => Ok(s.contains(&needle_str)),
            Value::Object(obj) => Ok(obj.contains_key(&needle_str)),
            _ => Err(ContextError::TypeError("Cannot check contains on this type".to_string())),
        }
    }
    /// Create a child context for parallel execution
    ///
    /// Copies params/stages/vars and installs a fresh loop context, so the
    /// child cannot mutate the parent's state.
    pub fn child_context(&self, item: Value, index: usize, total: usize) -> Self {
        let mut child = Self {
            params: self.params.clone(),
            stages: self.stages.clone(),
            vars: self.vars.clone(),
            loop_context: None,
            expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
        };
        child.set_loop_context(item, index, total);
        child
    }
}
/// Convert value to string for template replacement
///
/// Scalars render bare (no quotes); `Null` becomes the empty string;
/// arrays and objects are serialized as JSON.
fn value_to_string(value: &Value) -> String {
    match value {
        Value::Null => String::new(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(num) => num.to_string(),
        Value::String(text) => text.clone(),
        // Arrays and objects both serialize as compact JSON.
        composite => serde_json::to_string(composite).unwrap_or_default(),
    }
}
/// Convert value to f64 for comparison
///
/// Numbers convert directly; strings are parsed; everything else
/// (bool, null, array, object) yields `None`.
fn value_to_f64(value: &Value) -> Option<f64> {
    match value {
        Value::Number(n) => n.as_f64(),
        Value::String(s) => s.parse().ok(),
        _ => None,
    }
}
/// Public version for use in stage.rs
/// (thin re-export so the private helper stays module-local)
pub fn value_to_f64_public(value: &Value) -> Option<f64> {
    value_to_f64(value)
}
/// Context errors
///
/// Raised by expression resolution and condition evaluation; `Display`
/// messages come from the `thiserror` attributes.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// Path root is unknown or the path is malformed
    #[error("Invalid path: {0}")]
    InvalidPath(String),
    /// Named param/stage/var (or loop binding) does not exist
    #[error("Variable not found: {0}")]
    VariableNotFound(String),
    /// Object key missing during nested access
    #[error("Field not found: {0}")]
    FieldNotFound(String),
    /// Array index past the end during nested access
    #[error("Index out of bounds: {0}")]
    IndexOutOfBounds(usize),
    /// Value had the wrong type for the requested operation
    #[error("Type error: {0}")]
    TypeError(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // A single `${params.x}` expression returns the raw value, not a string.
    #[test]
    fn test_resolve_param() {
        let ctx = ExecutionContextV2::new(
            vec![("topic".to_string(), json!("Python"))]
                .into_iter()
                .collect()
        );
        let result = ctx.resolve("${params.topic}").unwrap();
        assert_eq!(result, json!("Python"));
    }
    // Nested access into a stage output object.
    #[test]
    fn test_resolve_stage_output() {
        let mut ctx = ExecutionContextV2::new(HashMap::new());
        ctx.set_stage_output("outline", json!({"sections": ["s1", "s2"]}));
        let result = ctx.resolve("${stages.outline.sections}").unwrap();
        assert_eq!(result, json!(["s1", "s2"]));
    }
    // `${item}`, `${item.field}` and `${index}` resolve from the loop context.
    #[test]
    fn test_resolve_loop_context() {
        let mut ctx = ExecutionContextV2::new(HashMap::new());
        ctx.set_loop_context(json!({"title": "Chapter 1"}), 0, 5);
        let item = ctx.resolve("${item}").unwrap();
        assert_eq!(item, json!({"title": "Chapter 1"}));
        let title = ctx.resolve("${item.title}").unwrap();
        assert_eq!(title, json!("Chapter 1"));
        let index = ctx.resolve("${index}").unwrap();
        assert_eq!(index, json!(0));
    }
    // Interpolation inside a larger string yields a string result.
    #[test]
    fn test_resolve_mixed_string() {
        let ctx = ExecutionContextV2::new(
            vec![("name".to_string(), json!("World"))]
                .into_iter()
                .collect()
        );
        let result = ctx.resolve("Hello, ${params.name}!").unwrap();
        assert_eq!(result, json!("Hello, World!"));
    }
    // Equality against quoted literals.
    #[test]
    fn test_evaluate_condition_equal() {
        let ctx = ExecutionContextV2::new(
            vec![("level".to_string(), json!("advanced"))]
                .into_iter()
                .collect()
        );
        assert!(ctx.evaluate_condition("${params.level} == 'advanced'").unwrap());
        assert!(!ctx.evaluate_condition("${params.level} == 'beginner'").unwrap());
    }
    // Numeric greater-than comparison.
    #[test]
    fn test_evaluate_condition_gt() {
        let ctx = ExecutionContextV2::new(
            vec![("count".to_string(), json!(10))]
                .into_iter()
                .collect()
        );
        assert!(ctx.evaluate_condition("${params.count} > 5").unwrap());
        assert!(!ctx.evaluate_condition("${params.count} > 20").unwrap());
    }
    // Child contexts inherit params and carry their own loop state.
    #[test]
    fn test_child_context() {
        let ctx = ExecutionContextV2::new(
            vec![("topic".to_string(), json!("Python"))]
                .into_iter()
                .collect()
        );
        let child = ctx.child_context(json!("item1"), 0, 3);
        assert_eq!(child.loop_item().unwrap(), &json!("item1"));
        assert_eq!(child.loop_index().unwrap(), 0);
        assert_eq!(child.get_param("topic").unwrap(), &json!("Python"));
    }
}

View File

@@ -0,0 +1,11 @@
//! Pipeline Engine Module
//!
//! Contains the v2 execution engine components:
//! - `stage`: `StageEngine`, which executes individual stages
//! - `context`: `ExecutionContextV2`, the enhanced execution context
//!
//! NOTE(review): both submodules are re-exported with glob `pub use`, so their
//! public items are reachable directly from this module's path.
pub mod stage;
pub mod context;
pub use stage::*;
pub use context::*;
View File

@@ -0,0 +1,623 @@
//! Stage Execution Engine
//!
//! Executes Pipeline v2 stages with support for:
//! - LLM generation
//! - Parallel execution
//! - Conditional branching
//! - Result composition
//! - Skill/Hand/HTTP integration
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::join_all;
use serde_json::{Value, json};
use tokio::sync::RwLock;
use crate::types_v2::{Stage, ConditionalBranch, PresentationType};
use crate::engine::context::{ExecutionContextV2, ContextError};
/// Stage execution result
///
/// Produced by `StageEngine` for every executed stage, on both the success
/// and the failure path.
#[derive(Debug, Clone)]
pub struct StageResult {
    /// Stage ID
    pub stage_id: String,
    /// Output value (`Value::Null` when the stage failed)
    pub output: Value,
    /// Execution status
    pub status: StageStatus,
    /// Error message (if failed)
    pub error: Option<String>,
    /// Execution duration in ms
    pub duration_ms: u64,
}
/// Stage execution status
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StageStatus {
    /// Stage ran to completion and produced an output.
    Success,
    /// Stage raised an error; see `StageResult::error` for the message.
    Failed,
    /// Stage was not executed. NOTE(review): no code path in this file sets
    /// `Skipped` — presumably reserved for future use; confirm before relying on it.
    Skipped,
}
/// Stage execution event for progress tracking
///
/// Delivered through the engine's optional event callback; consumers can use
/// these to render live progress for a pipeline run.
#[derive(Debug, Clone)]
pub enum StageEvent {
    /// Stage started
    Started { stage_id: String },
    /// Stage progress update (free-form human-readable message)
    Progress { stage_id: String, message: String },
    /// Stage completed (carries the full `StageResult`)
    Completed { stage_id: String, result: StageResult },
    /// Parallel progress (`completed` items out of `total`)
    ParallelProgress { stage_id: String, completed: usize, total: usize },
    /// Error occurred
    Error { stage_id: String, error: String },
}
/// LLM driver trait for stage execution
///
/// Implementations bridge the engine to a concrete LLM backend. All model /
/// sampling parameters are optional; `None` means "use the driver's default".
#[async_trait]
pub trait StageLlmDriver: Send + Sync {
    /// Generate completion
    ///
    /// # Errors
    /// Returns `StageError` when the backend call fails.
    async fn generate(
        &self,
        prompt: String,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Result<Value, StageError>;
    /// Generate with JSON schema
    ///
    /// The returned `Value` is expected to conform to `schema`.
    async fn generate_with_schema(
        &self,
        prompt: String,
        schema: Value,
        model: Option<String>,
        temperature: Option<f32>,
    ) -> Result<Value, StageError>;
}
/// Skill driver trait
///
/// Executes a named skill with already-resolved input values
/// (expression resolution happens in the engine before this is called).
#[async_trait]
pub trait StageSkillDriver: Send + Sync {
    /// Execute a skill
    async fn execute(
        &self,
        skill_id: &str,
        input: HashMap<String, Value>,
    ) -> Result<Value, StageError>;
}
/// Hand driver trait
///
/// Executes an action on a named "hand" with already-resolved parameters
/// (expression resolution happens in the engine before this is called).
#[async_trait]
pub trait StageHandDriver: Send + Sync {
    /// Execute a hand action
    async fn execute(
        &self,
        hand_id: &str,
        action: &str,
        params: HashMap<String, Value>,
    ) -> Result<Value, StageError>;
}
/// Stage execution engine
///
/// Configured builder-style via the `with_*` methods; all drivers are
/// optional, and executing a stage kind whose driver is missing yields
/// `StageError::DriverNotAvailable`.
pub struct StageEngine {
    /// LLM driver
    llm_driver: Option<Arc<dyn StageLlmDriver>>,
    /// Skill driver
    skill_driver: Option<Arc<dyn StageSkillDriver>>,
    /// Hand driver
    hand_driver: Option<Arc<dyn StageHandDriver>>,
    /// Event callback
    event_callback: Option<Arc<dyn Fn(StageEvent) + Send + Sync>>,
    /// Maximum parallel workers
    /// NOTE(review): not consulted by `execute_parallel` in this file
    /// (parallel stages currently run sequentially) — confirm intended use.
    max_workers: usize,
}
impl StageEngine {
    /// Create a new stage engine with no drivers and a default worker limit of 3.
    pub fn new() -> Self {
        Self {
            llm_driver: None,
            skill_driver: None,
            hand_driver: None,
            event_callback: None,
            max_workers: 3,
        }
    }
    /// Set LLM driver
    pub fn with_llm_driver(mut self, driver: Arc<dyn StageLlmDriver>) -> Self {
        self.llm_driver = Some(driver);
        self
    }
    /// Set skill driver
    pub fn with_skill_driver(mut self, driver: Arc<dyn StageSkillDriver>) -> Self {
        self.skill_driver = Some(driver);
        self
    }
    /// Set hand driver
    pub fn with_hand_driver(mut self, driver: Arc<dyn StageHandDriver>) -> Self {
        self.hand_driver = Some(driver);
        self
    }
    /// Set event callback (receives `StageEvent`s for progress tracking)
    pub fn with_event_callback(mut self, callback: Arc<dyn Fn(StageEvent) + Send + Sync>) -> Self {
        self.event_callback = Some(callback);
        self
    }
    /// Set max workers
    pub fn with_max_workers(mut self, max: usize) -> Self {
        self.max_workers = max;
        self
    }
    /// Execute a stage (boxed to support recursion)
    ///
    /// Recursive async fns need a boxed future; `Parallel`, `Sequential` and
    /// `Conditional` stages all re-enter `execute` for their children.
    pub fn execute<'a>(
        &'a self,
        stage: &'a Stage,
        context: &'a mut ExecutionContextV2,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<StageResult, StageError>> + 'a>> {
        Box::pin(async move {
            self.execute_inner(stage, context).await
        })
    }
    /// Inner execute implementation
    ///
    /// Dispatches on the stage kind, times the run, stores the output in the
    /// context under the stage ID on success, and emits lifecycle events.
    async fn execute_inner(
        &self,
        stage: &Stage,
        context: &mut ExecutionContextV2,
    ) -> Result<StageResult, StageError> {
        let start = std::time::Instant::now();
        let stage_id = stage.id().to_string();
        // Emit started event
        self.emit_event(StageEvent::Started {
            stage_id: stage_id.clone(),
        });
        let result = match stage {
            Stage::Llm { prompt, model, temperature, max_tokens, output_schema, .. } => {
                self.execute_llm(&stage_id, prompt, model, temperature, max_tokens, output_schema, context).await
            }
            Stage::Parallel { each, stage, max_workers, .. } => {
                self.execute_parallel(&stage_id, each, stage, *max_workers, context).await
            }
            Stage::Sequential { stages, .. } => {
                self.execute_sequential(&stage_id, stages, context).await
            }
            Stage::Conditional { condition, branches, default, .. } => {
                self.execute_conditional(&stage_id, condition, branches, default.as_deref(), context).await
            }
            Stage::Compose { template, .. } => {
                self.execute_compose(&stage_id, template, context).await
            }
            Stage::Skill { skill_id, input, .. } => {
                self.execute_skill(&stage_id, skill_id, input, context).await
            }
            Stage::Hand { hand_id, action, params, .. } => {
                self.execute_hand(&stage_id, hand_id, action, params, context).await
            }
            Stage::Http { url, method, headers, body, .. } => {
                self.execute_http(&stage_id, url, method, headers, body, context).await
            }
            Stage::SetVar { name, value, .. } => {
                self.execute_set_var(&stage_id, name, value, context).await
            }
        };
        let duration_ms = start.elapsed().as_millis() as u64;
        match result {
            Ok(output) => {
                // Store output in context so later stages can reference
                // ${stages.<id>.*} expressions.
                context.set_stage_output(&stage_id, output.clone());
                let result = StageResult {
                    stage_id: stage_id.clone(),
                    output,
                    status: StageStatus::Success,
                    error: None,
                    duration_ms,
                };
                self.emit_event(StageEvent::Completed {
                    stage_id,
                    result: result.clone(),
                });
                Ok(result)
            }
            Err(e) => {
                // Failed stages do not store an output; the error is surfaced
                // both via the event and the returned Err.
                let result = StageResult {
                    stage_id: stage_id.clone(),
                    output: Value::Null,
                    status: StageStatus::Failed,
                    error: Some(e.to_string()),
                    duration_ms,
                };
                self.emit_event(StageEvent::Error {
                    stage_id,
                    error: e.to_string(),
                });
                Err(e)
            }
        }
    }
    /// Execute LLM stage
    ///
    /// Resolves the prompt template against the context, then generates either
    /// schema-constrained or free-form output depending on `output_schema`.
    async fn execute_llm(
        &self,
        stage_id: &str,
        prompt: &str,
        model: &Option<String>,
        temperature: &Option<f32>,
        max_tokens: &Option<u32>,
        output_schema: &Option<Value>,
        context: &ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let driver = self.llm_driver.as_ref()
            .ok_or_else(|| StageError::DriverNotAvailable("LLM".to_string()))?;
        // Resolve prompt template
        let resolved_prompt = context.resolve(prompt)?;
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: "Calling LLM...".to_string(),
        });
        let prompt_str = resolved_prompt.as_str()
            .ok_or_else(|| StageError::TypeError("Prompt must be a string".to_string()))?
            .to_string();
        // Generate with or without schema
        let result = if let Some(schema) = output_schema {
            driver.generate_with_schema(
                prompt_str,
                schema.clone(),
                model.clone(),
                *temperature,
            ).await
        } else {
            driver.generate(
                prompt_str,
                model.clone(),
                *temperature,
                *max_tokens,
            ).await
        };
        result.map_err(|e| StageError::ExecutionFailed(format!("LLM error: {}", e)))
    }
    /// Execute parallel stage
    ///
    /// Despite the name, items are currently processed sequentially (true
    /// parallelism would require `Send`-safe drivers); `_max_workers` is
    /// accepted for API symmetry with `Stage::Parallel` but intentionally
    /// unused until that lands. Per-item failures are recorded as error
    /// objects in the output array instead of aborting the whole stage.
    async fn execute_parallel(
        &self,
        stage_id: &str,
        each: &str,
        stage_template: &Stage,
        _max_workers: usize,
        context: &mut ExecutionContextV2,
    ) -> Result<Value, StageError> {
        // Resolve the array to iterate over
        let items = context.resolve_array(each)?;
        let total = items.len();
        if total == 0 {
            return Ok(Value::Array(vec![]));
        }
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: format!("Processing {} items", total),
        });
        // Sequential execution with progress tracking
        let mut outputs = Vec::with_capacity(total);
        for (index, item) in items.into_iter().enumerate() {
            let mut child_context = context.child_context(item.clone(), index, total);
            self.emit_event(StageEvent::ParallelProgress {
                stage_id: stage_id.to_string(),
                completed: index,
                total,
            });
            match self.execute(stage_template, &mut child_context).await {
                Ok(result) => outputs.push(result.output),
                // Best-effort: embed the error in the result array so other
                // items still run.
                Err(e) => outputs.push(json!({ "error": e.to_string(), "index": index })),
            }
        }
        Ok(Value::Array(outputs))
    }
    /// Execute sequential stages
    ///
    /// Runs children in order, aborting on the first failure; returns the
    /// array of child outputs.
    async fn execute_sequential(
        &self,
        stage_id: &str,
        stages: &[Stage],
        context: &mut ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let mut outputs = Vec::new();
        for stage in stages {
            self.emit_event(StageEvent::Progress {
                stage_id: stage_id.to_string(),
                message: format!("Executing stage: {}", stage.id()),
            });
            let result = self.execute(stage, context).await?;
            outputs.push(result.output);
        }
        Ok(Value::Array(outputs))
    }
    /// Execute conditional stage
    ///
    /// The outer `condition` gates the whole stage; when true, the first
    /// branch whose `when` evaluates true is executed, falling back to
    /// `default` (if any). Returns `Value::Null` when nothing runs.
    async fn execute_conditional(
        &self,
        stage_id: &str,
        condition: &str,
        branches: &[ConditionalBranch],
        default: Option<&Stage>,
        context: &mut ExecutionContextV2,
    ) -> Result<Value, StageError> {
        // Evaluate main condition
        let condition_result = context.evaluate_condition(condition)?;
        if condition_result {
            // Check each branch
            for branch in branches {
                if context.evaluate_condition(&branch.when)? {
                    self.emit_event(StageEvent::Progress {
                        stage_id: stage_id.to_string(),
                        message: format!("Branch matched: {}", branch.when),
                    });
                    return self.execute(&branch.then, context).await
                        .map(|r| r.output);
                }
            }
            // No branch matched, use default
            if let Some(default_stage) = default {
                self.emit_event(StageEvent::Progress {
                    stage_id: stage_id.to_string(),
                    message: "Using default branch".to_string(),
                });
                return self.execute(default_stage, context).await
                    .map(|r| r.output);
            }
            Ok(Value::Null)
        } else {
            // Main condition false, return null
            Ok(Value::Null)
        }
    }
    /// Execute compose stage
    ///
    /// Resolves the template; if the result is a string that looks like JSON,
    /// it is parsed into a structured value. `_stage_id` is unused (compose
    /// emits no progress events) but kept so all `execute_*` helpers share
    /// the same calling convention.
    async fn execute_compose(
        &self,
        _stage_id: &str,
        template: &str,
        context: &ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let resolved = context.resolve(template)?;
        // Try to parse as JSON
        if let Value::String(s) = &resolved {
            if s.starts_with('{') || s.starts_with('[') {
                if let Ok(json) = serde_json::from_str::<Value>(s) {
                    return Ok(json);
                }
            }
        }
        Ok(resolved)
    }
    /// Execute skill stage
    ///
    /// Resolves each input expression against the context before delegating
    /// to the skill driver.
    async fn execute_skill(
        &self,
        stage_id: &str,
        skill_id: &str,
        input: &HashMap<String, String>,
        context: &ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let driver = self.skill_driver.as_ref()
            .ok_or_else(|| StageError::DriverNotAvailable("Skill".to_string()))?;
        // Resolve input expressions
        let mut resolved_input = HashMap::new();
        for (key, expr) in input {
            let value = context.resolve(expr)?;
            resolved_input.insert(key.clone(), value);
        }
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: format!("Executing skill: {}", skill_id),
        });
        driver.execute(skill_id, resolved_input).await
            .map_err(|e| StageError::ExecutionFailed(format!("Skill error: {}", e)))
    }
    /// Execute hand stage
    ///
    /// Resolves each parameter expression against the context before
    /// delegating to the hand driver.
    async fn execute_hand(
        &self,
        stage_id: &str,
        hand_id: &str,
        action: &str,
        params: &HashMap<String, String>,
        context: &ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let driver = self.hand_driver.as_ref()
            .ok_or_else(|| StageError::DriverNotAvailable("Hand".to_string()))?;
        // Resolve parameter expressions
        let mut resolved_params = HashMap::new();
        for (key, expr) in params {
            let value = context.resolve(expr)?;
            resolved_params.insert(key.clone(), value);
        }
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: format!("Executing hand: {} / {}", hand_id, action),
        });
        driver.execute(hand_id, action, resolved_params).await
            .map_err(|e| StageError::ExecutionFailed(format!("Hand error: {}", e)))
    }
    /// Execute HTTP stage
    ///
    /// URL, header values and the optional body all go through context
    /// resolution. Non-2xx responses are errors; the response body must
    /// parse as JSON.
    async fn execute_http(
        &self,
        stage_id: &str,
        url: &str,
        method: &str,
        headers: &HashMap<String, String>,
        body: &Option<String>,
        context: &ExecutionContextV2,
    ) -> Result<Value, StageError> {
        // Resolve URL
        let resolved_url = context.resolve_string(url)?;
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: format!("HTTP {} {}", method, resolved_url),
        });
        // Build request
        let client = reqwest::Client::new();
        let mut request = match method.to_uppercase().as_str() {
            "GET" => client.get(&resolved_url),
            "POST" => client.post(&resolved_url),
            "PUT" => client.put(&resolved_url),
            "DELETE" => client.delete(&resolved_url),
            "PATCH" => client.patch(&resolved_url),
            _ => return Err(StageError::ExecutionFailed(format!("Unsupported HTTP method: {}", method))),
        };
        // Add headers
        for (key, value) in headers {
            let resolved_value = context.resolve_string(value)?;
            request = request.header(key, resolved_value);
        }
        // Add body
        if let Some(body_expr) = body {
            let resolved_body = context.resolve(body_expr)?;
            request = request.json(&resolved_body);
        }
        // Execute request
        let response = request.send().await
            .map_err(|e| StageError::ExecutionFailed(format!("HTTP request failed: {}", e)))?;
        // Parse response
        let status = response.status();
        if !status.is_success() {
            return Err(StageError::ExecutionFailed(format!("HTTP error: {}", status)));
        }
        let json = response.json::<Value>().await
            .map_err(|e| StageError::ExecutionFailed(format!("Failed to parse response: {}", e)))?;
        Ok(json)
    }
    /// Execute set_var stage
    ///
    /// Resolves `value` and stores it under `name`; also returns the
    /// resolved value as the stage output.
    async fn execute_set_var(
        &self,
        stage_id: &str,
        name: &str,
        value: &str,
        context: &mut ExecutionContextV2,
    ) -> Result<Value, StageError> {
        let resolved_value = context.resolve(value)?;
        context.set_var(name, resolved_value.clone());
        self.emit_event(StageEvent::Progress {
            stage_id: stage_id.to_string(),
            message: format!("Set variable: {} = {:?}", name, resolved_value),
        });
        Ok(resolved_value)
    }
    /// Clone with drivers
    ///
    /// Currently has no callers; kept for future true-parallel execution
    /// where each worker needs its own engine handle. The `allow` silences
    /// the dead-code lint until then.
    #[allow(dead_code)]
    fn clone_with_drivers(&self) -> Self {
        Self {
            llm_driver: self.llm_driver.clone(),
            skill_driver: self.skill_driver.clone(),
            hand_driver: self.hand_driver.clone(),
            event_callback: self.event_callback.clone(),
            max_workers: self.max_workers,
        }
    }
    /// Emit event through the optional callback (no-op when unset)
    fn emit_event(&self, event: StageEvent) {
        if let Some(callback) = &self.event_callback {
            callback(event);
        }
    }
}
impl Default for StageEngine {
    /// Equivalent to [`StageEngine::new`]: no drivers, `max_workers = 3`.
    fn default() -> Self {
        Self::new()
    }
}
/// Stage execution error
#[derive(Debug, thiserror::Error)]
pub enum StageError {
    /// A stage kind was executed without its required driver being configured.
    #[error("Driver not available: {0}")]
    DriverNotAvailable(String),
    /// The underlying driver / HTTP call failed; the message carries details.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),
    /// A resolved value had the wrong shape (e.g. non-string prompt).
    #[error("Type error: {0}")]
    TypeError(String),
    /// Expression resolution or condition evaluation failed (auto-converted
    /// from `ContextError` via `#[from]`, so `?` works across the boundary).
    #[error("Context error: {0}")]
    ContextError(#[from] ContextError),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-style configuration must persist the worker limit.
    #[test]
    fn test_stage_engine_creation() {
        let configured = StageEngine::new().with_max_workers(5);
        assert_eq!(configured.max_workers, 5);
    }
}

View File

@@ -11,7 +11,7 @@ use chrono::Utc;
use futures::stream::{self, StreamExt};
use futures::future::{BoxFuture, FutureExt};
use crate::types::{Pipeline, PipelineRun, PipelineProgress, RunStatus, PipelineStep, Action};
use crate::types::{Pipeline, PipelineRun, PipelineProgress, RunStatus, PipelineStep, Action, ExportFormat};
use crate::state::{ExecutionContext, StateError};
use crate::actions::ActionRegistry;
@@ -62,14 +62,28 @@ impl PipelineExecutor {
}
}
/// Execute a pipeline
/// Execute a pipeline with auto-generated run ID
pub async fn execute(
&self,
pipeline: &Pipeline,
inputs: HashMap<String, Value>,
) -> Result<PipelineRun, ExecuteError> {
let run_id = Uuid::new_v4().to_string();
self.execute_with_id(pipeline, inputs, &run_id).await
}
/// Execute a pipeline with a specific run ID
///
/// Use this when you need to know the run_id before execution starts,
/// e.g., for async spawning where the caller needs to track progress.
pub async fn execute_with_id(
&self,
pipeline: &Pipeline,
inputs: HashMap<String, Value>,
run_id: &str,
) -> Result<PipelineRun, ExecuteError> {
let pipeline_id = pipeline.metadata.name.clone();
let run_id = run_id.to_string();
// Create run record
let run = PipelineRun {
@@ -171,9 +185,25 @@ impl PipelineExecutor {
async move {
match action {
Action::LlmGenerate { template, input, model, temperature, max_tokens, json_mode } => {
println!("[DEBUG executor] LlmGenerate action called");
println!("[DEBUG executor] Raw input map:");
for (k, v) in input {
println!(" {} => {}", k, v);
}
// First resolve the template itself (handles ${inputs.xxx}, ${item.xxx}, etc.)
let resolved_template = context.resolve(template)?;
let resolved_template_str = resolved_template.as_str().unwrap_or(template).to_string();
println!("[DEBUG executor] Resolved template (first 300 chars): {}",
&resolved_template_str[..resolved_template_str.len().min(300)]);
let resolved_input = context.resolve_map(input)?;
println!("[DEBUG executor] Resolved input map:");
for (k, v) in &resolved_input {
println!(" {} => {:?}", k, v);
}
self.action_registry.execute_llm(
template,
&resolved_template_str,
resolved_input,
model.clone(),
*temperature,
@@ -188,7 +218,7 @@ impl PipelineExecutor {
.ok_or_else(|| ExecuteError::Action("Parallel 'each' must resolve to an array".to_string()))?;
let workers = max_workers.unwrap_or(4);
let results = self.execute_parallel(step, items_array.clone(), workers).await?;
let results = self.execute_parallel(step, items_array.clone(), workers, context).await?;
Ok(Value::Array(results))
}
@@ -247,7 +277,38 @@ impl PipelineExecutor {
None => None,
};
self.action_registry.export_files(formats, &data, dir.as_deref())
// Resolve formats expression and parse as array
let resolved_formats = context.resolve(formats)?;
let format_strings: Vec<String> = if resolved_formats.is_array() {
resolved_formats.as_array()
.ok_or_else(|| ExecuteError::Action("formats must be an array".to_string()))?
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
} else if resolved_formats.is_string() {
// Try to parse as JSON array string
let s = resolved_formats.as_str()
.ok_or_else(|| ExecuteError::Action("formats must be a string or array".to_string()))?;
serde_json::from_str(s)
.unwrap_or_else(|_| vec![s.to_string()])
} else {
return Err(ExecuteError::Action("formats must be a string or array".to_string()));
};
// Convert strings to ExportFormat
let export_formats: Vec<ExportFormat> = format_strings
.iter()
.filter_map(|s| match s.to_lowercase().as_str() {
"pptx" => Some(ExportFormat::Pptx),
"html" => Some(ExportFormat::Html),
"pdf" => Some(ExportFormat::Pdf),
"markdown" | "md" => Some(ExportFormat::Markdown),
"json" => Some(ExportFormat::Json),
_ => None,
})
.collect();
self.action_registry.export_files(&export_formats, &data, dir.as_deref())
.await
.map_err(|e| ExecuteError::Action(e.to_string()))
}
@@ -301,18 +362,31 @@ impl PipelineExecutor {
step: &PipelineStep,
items: Vec<Value>,
max_workers: usize,
parent_context: &ExecutionContext,
) -> Result<Vec<Value>, ExecuteError> {
let action_registry = self.action_registry.clone();
let action = step.action.clone();
// Clone parent context data for child contexts
let parent_inputs = parent_context.inputs().clone();
let parent_outputs = parent_context.all_outputs().clone();
let parent_vars = parent_context.all_vars().clone();
let results: Vec<Result<Value, ExecuteError>> = stream::iter(items.into_iter().enumerate())
.map(|(index, item)| {
let action_registry = action_registry.clone();
let action = action.clone();
let parent_inputs = parent_inputs.clone();
let parent_outputs = parent_outputs.clone();
let parent_vars = parent_vars.clone();
async move {
// Create child context with loop variables
let mut child_ctx = ExecutionContext::new(HashMap::new());
// Create child context with parent data and loop variables
let mut child_ctx = ExecutionContext::from_parent(
parent_inputs,
parent_outputs,
parent_vars,
);
child_ctx.set_loop_context(item, index);
// Execute the step's action

View File

@@ -0,0 +1,666 @@
//! Intent Router System
//!
//! Routes user input to the appropriate pipeline using:
//! 1. Quick matching (keywords + patterns, < 10ms)
//! 2. Semantic matching (LLM-based, ~200ms)
//!
//! # Flow
//!
//! ```text
//! User Input
//! ↓
//! Quick Match (keywords/patterns)
//! ├─→ Match found → Prepare execution
//! └─→ No match → Semantic Match (LLM)
//! ├─→ Match found → Prepare execution
//! └─→ No match → Return suggestions
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use zclaw_pipeline::{IntentRouter, RouteResult, TriggerParser, LlmIntentDriver};
//!
//! async fn example() {
//! let router = IntentRouter::new(trigger_parser, llm_driver);
//! let result = router.route("帮我做一个Python入门课程").await.unwrap();
//!
//! match result {
//! RouteResult::Matched { pipeline_id, params, mode } => {
//! // Start pipeline execution
//! }
//! RouteResult::Suggestions { pipelines } => {
//! // Show user available options
//! }
//! RouteResult::NeedMoreInfo { prompt } => {
//! // Ask user for clarification
//! }
//! }
//! }
//! ```
use crate::trigger::{CompiledTrigger, MatchType, TriggerMatch, TriggerParser, TriggerParam};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Intent router - main entry point for user input
///
/// Built with [`IntentRouter::new`] and configured builder-style; without an
/// LLM driver only quick (keyword/pattern) matching is available.
pub struct IntentRouter {
    /// Trigger parser for quick matching
    trigger_parser: TriggerParser,
    /// LLM driver for semantic matching (optional; `None` disables step 2)
    llm_driver: Option<Box<dyn LlmIntentDriver>>,
    /// Configuration
    config: RouterConfig,
}
/// Router configuration
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Minimum confidence threshold for auto-matching
    /// NOTE(review): not read anywhere in this file — presumably consumed by
    /// the trigger parser or a caller; confirm before removing.
    pub confidence_threshold: f32,
    /// Number of suggestions to return when no clear match
    pub suggestion_count: usize,
    /// Enable semantic matching via LLM
    pub enable_semantic_matching: bool,
}
impl Default for RouterConfig {
    /// Defaults: 0.7 confidence threshold, 3 suggestions, semantic matching on.
    fn default() -> Self {
        Self {
            confidence_threshold: 0.7,
            suggestion_count: 3,
            enable_semantic_matching: true,
        }
    }
}
/// Route result
///
/// Serialized with an internal `"type"` tag in snake_case, e.g.
/// `{"type": "matched", ...}` / `{"type": "no_match", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RouteResult {
    /// Successfully matched a pipeline
    Matched {
        /// Matched pipeline ID
        pipeline_id: String,
        /// Pipeline display name
        display_name: Option<String>,
        /// Input mode (conversation, form, hybrid)
        mode: InputMode,
        /// Extracted parameters
        params: HashMap<String, serde_json::Value>,
        /// Match confidence
        confidence: f32,
        /// Missing required parameters
        missing_params: Vec<MissingParam>,
    },
    /// Multiple possible matches, need user selection
    /// NOTE(review): not produced by `route()` in this file — confirm whether
    /// any caller constructs it.
    Ambiguous {
        /// Candidate pipelines
        candidates: Vec<PipelineCandidate>,
    },
    /// No match found, show suggestions
    NoMatch {
        /// Suggested pipelines based on category/tags
        suggestions: Vec<PipelineCandidate>,
    },
    /// Need more information from user
    NeedMoreInfo {
        /// Prompt to show user
        prompt: String,
        /// Related pipeline (if any)
        related_pipeline: Option<String>,
    },
}
/// Input mode for parameter collection
///
/// Serialized in lowercase (`"conversation"`, `"form"`, ...).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum InputMode {
    /// Simple conversation-based collection
    Conversation,
    /// Form-based collection
    Form,
    /// Hybrid - start with conversation, switch to form if needed
    Hybrid,
    /// Auto - system decides based on complexity
    Auto,
}
impl Default for InputMode {
    /// `Auto`: let the router pick a mode from parameter complexity.
    fn default() -> Self {
        Self::Auto
    }
}
/// Pipeline candidate for suggestions
///
/// Serialized in camelCase (e.g. `displayName`, `matchReason`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PipelineCandidate {
    /// Pipeline ID
    pub id: String,
    /// Display name
    pub display_name: Option<String>,
    /// Description
    pub description: Option<String>,
    /// Icon
    pub icon: Option<String>,
    /// Category
    pub category: Option<String>,
    /// Match reason (human-readable, shown to the user)
    pub match_reason: Option<String>,
}
/// Missing parameter info
///
/// Serialized in camelCase (e.g. `paramType`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MissingParam {
    /// Parameter name
    pub name: String,
    /// Parameter label
    pub label: Option<String>,
    /// Parameter type
    pub param_type: String,
    /// Is this required?
    pub required: bool,
    /// Default value if available
    pub default: Option<serde_json::Value>,
}
impl IntentRouter {
    /// Create a new intent router with default config and no LLM driver.
    pub fn new(trigger_parser: TriggerParser) -> Self {
        Self {
            trigger_parser,
            llm_driver: None,
            config: RouterConfig::default(),
        }
    }
    /// Set LLM driver for semantic matching
    pub fn with_llm_driver(mut self, driver: Box<dyn LlmIntentDriver>) -> Self {
        self.llm_driver = Some(driver);
        self
    }
    /// Set configuration
    pub fn with_config(mut self, config: RouterConfig) -> Self {
        self.config = config;
        self
    }
    /// Route user input to a pipeline
    ///
    /// Tries quick (keyword/pattern) matching first, then LLM semantic
    /// matching if enabled and a driver is set, and finally falls back to
    /// returning suggestions. Never returns `Ambiguous` or `NeedMoreInfo`
    /// itself — those variants are for callers to construct.
    pub async fn route(&self, user_input: &str) -> RouteResult {
        // Step 1: Quick match (local, < 10ms)
        if let Some(match_result) = self.trigger_parser.quick_match(user_input) {
            return self.prepare_from_match(match_result);
        }
        // Step 2: Semantic match (LLM, ~200ms)
        if self.config.enable_semantic_matching {
            if let Some(ref llm_driver) = self.llm_driver {
                if let Some(result) = llm_driver.semantic_match(user_input, self.trigger_parser.triggers()).await {
                    return self.prepare_from_semantic_match(result);
                }
            }
        }
        // Step 3: No match - return suggestions
        self.get_suggestions()
    }
    /// Prepare route result from a trigger match
    ///
    /// Defensive: if the matched pipeline's trigger has disappeared from the
    /// parser, returns `NoMatch` with no suggestions.
    fn prepare_from_match(&self, match_result: TriggerMatch) -> RouteResult {
        let trigger = match self.trigger_parser.get_trigger(&match_result.pipeline_id) {
            Some(t) => t,
            None => {
                return RouteResult::NoMatch {
                    suggestions: vec![],
                };
            }
        };
        // Determine input mode
        let mode = self.decide_mode(&trigger.param_defs);
        // Find missing parameters
        let missing_params = self.find_missing_params(&trigger.param_defs, &match_result.params);
        RouteResult::Matched {
            pipeline_id: match_result.pipeline_id,
            display_name: trigger.display_name.clone(),
            mode,
            params: match_result.params,
            confidence: match_result.confidence,
            missing_params,
        }
    }
    /// Prepare route result from semantic match
    ///
    /// Same shape as `prepare_from_match`, but sourced from the LLM result.
    fn prepare_from_semantic_match(&self, result: SemanticMatchResult) -> RouteResult {
        let trigger = match self.trigger_parser.get_trigger(&result.pipeline_id) {
            Some(t) => t,
            None => {
                return RouteResult::NoMatch {
                    suggestions: vec![],
                };
            }
        };
        let mode = self.decide_mode(&trigger.param_defs);
        let missing_params = self.find_missing_params(&trigger.param_defs, &result.params);
        RouteResult::Matched {
            pipeline_id: result.pipeline_id,
            display_name: trigger.display_name.clone(),
            mode,
            params: result.params,
            confidence: result.confidence,
            missing_params,
        }
    }
    /// Decide input mode based on parameter complexity
    ///
    /// Heuristic: no params -> conversation; > 3 required or > 5 total
    /// params -> form; otherwise conversation.
    fn decide_mode(&self, params: &[TriggerParam]) -> InputMode {
        if params.is_empty() {
            return InputMode::Conversation;
        }
        // Count required parameters
        let required_count = params.iter().filter(|p| p.required).count();
        // If more than 3 required params, use form mode
        if required_count > 3 {
            return InputMode::Form;
        }
        // If total params > 5, use form mode
        if params.len() > 5 {
            return InputMode::Form;
        }
        // Otherwise, use conversation mode
        InputMode::Conversation
    }
    /// Find missing required parameters
    ///
    /// A required parameter counts as satisfied when it was either provided
    /// or has a default value.
    fn find_missing_params(
        &self,
        param_defs: &[TriggerParam],
        provided: &HashMap<String, serde_json::Value>,
    ) -> Vec<MissingParam> {
        param_defs
            .iter()
            .filter(|p| {
                p.required && !provided.contains_key(&p.name) && p.default.is_none()
            })
            .map(|p| MissingParam {
                name: p.name.clone(),
                label: p.label.clone(),
                param_type: p.param_type.clone(),
                required: p.required,
                default: p.default.clone(),
            })
            .collect()
    }
    /// Get suggestions when no match found
    ///
    /// Takes the first `suggestion_count` registered triggers (registration
    /// order) as candidates.
    fn get_suggestions(&self) -> RouteResult {
        let suggestions: Vec<PipelineCandidate> = self
            .trigger_parser
            .triggers()
            .iter()
            .take(self.config.suggestion_count)
            .map(|t| PipelineCandidate {
                id: t.pipeline_id.clone(),
                display_name: t.display_name.clone(),
                description: t.description.clone(),
                icon: None,
                category: None,
                match_reason: Some("热门推荐".to_string()),
            })
            .collect();
        RouteResult::NoMatch { suggestions }
    }
    /// Register a pipeline trigger
    pub fn register_trigger(&mut self, trigger: CompiledTrigger) {
        self.trigger_parser.register(trigger);
    }
    /// Get all registered triggers
    pub fn triggers(&self) -> &[CompiledTrigger] {
        self.trigger_parser.triggers()
    }
}
/// Result from LLM semantic matching
///
/// Returned by [`LlmIntentDriver::semantic_match`] and converted into a
/// `RouteResult::Matched` by the router.
#[derive(Debug, Clone)]
pub struct SemanticMatchResult {
    /// Matched pipeline ID
    pub pipeline_id: String,
    /// Extracted parameters
    pub params: HashMap<String, serde_json::Value>,
    /// Match confidence
    pub confidence: f32,
    /// Match reason (human-readable explanation from the model)
    pub reason: String,
}
/// LLM driver trait for semantic matching
#[async_trait]
pub trait LlmIntentDriver: Send + Sync {
    /// Perform semantic matching on user input
    ///
    /// Returns `None` when the model finds no suitable pipeline (or the
    /// driver cannot decide), letting the router fall through to suggestions.
    async fn semantic_match(
        &self,
        user_input: &str,
        triggers: &[CompiledTrigger],
    ) -> Option<SemanticMatchResult>;
    /// Collect missing parameters via conversation
    ///
    /// Best-effort: parameters that cannot be extracted are simply absent
    /// from the returned map.
    async fn collect_params(
        &self,
        user_input: &str,
        missing_params: &[MissingParam],
        context: &HashMap<String, serde_json::Value>,
    ) -> HashMap<String, serde_json::Value>;
}
/// Default LLM driver implementation using prompt-based matching
pub struct DefaultLlmIntentDriver {
    /// Model ID to use
    /// NOTE(review): stored but not read by the current stub implementation —
    /// confirm it is wired up when the real LLM call lands.
    model_id: String,
}
impl DefaultLlmIntentDriver {
    /// Build a driver bound to the given model identifier.
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self { model_id }
    }
}
#[async_trait]
impl LlmIntentDriver for DefaultLlmIntentDriver {
    /// Stub: builds the matching prompt but never calls an LLM, so it always
    /// returns `None` (router falls through to suggestions).
    async fn semantic_match(
        &self,
        user_input: &str,
        triggers: &[CompiledTrigger],
    ) -> Option<SemanticMatchResult> {
        // Build prompt for LLM
        let trigger_descriptions: Vec<String> = triggers
            .iter()
            .map(|t| {
                format!(
                    "- {}: {}",
                    t.pipeline_id,
                    t.description.as_deref().unwrap_or("无描述")
                )
            })
            .collect();
        let prompt = format!(
            r#"分析用户输入,匹配合适的 Pipeline。
用户输入: {}
可选 Pipelines:
{}
返回 JSON 格式:
{{
  "pipeline_id": "匹配的 pipeline ID 或 null",
  "params": {{ "参数名": "值" }},
  "confidence": 0.0-1.0,
  "reason": "匹配原因"
}}
只返回 JSON不要其他内容。"#,
            user_input,
            trigger_descriptions.join("\n")
        );
        // In a real implementation, this would call the LLM
        // For now, we return None to indicate semantic matching is not available
        let _ = prompt; // Suppress unused warning
        None
    }
    /// Stub: builds the extraction prompt but never calls an LLM, so it
    /// always returns an empty map (no parameters extracted).
    async fn collect_params(
        &self,
        user_input: &str,
        missing_params: &[MissingParam],
        _context: &HashMap<String, serde_json::Value>,
    ) -> HashMap<String, serde_json::Value> {
        // Build prompt to extract parameters from user input
        let param_descriptions: Vec<String> = missing_params
            .iter()
            .map(|p| {
                format!(
                    "- {} ({}): {}",
                    p.name,
                    p.param_type,
                    p.label.as_deref().unwrap_or(&p.name)
                )
            })
            .collect();
        let prompt = format!(
            r#"从用户输入中提取参数值。
用户输入: {}
需要提取的参数:
{}
返回 JSON 格式:
{{
  "参数名": "提取的值"
}}
如果无法提取,该参数可以省略。只返回 JSON。"#,
            user_input,
            param_descriptions.join("\n")
        );
        // In a real implementation, this would call the LLM
        let _ = prompt;
        HashMap::new()
    }
}
/// Intent analysis result (for debugging/logging)
///
/// Serialized in camelCase. NOTE(review): defined but not constructed
/// anywhere in this file — presumably populated by callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IntentAnalysis {
    /// Original user input
    pub user_input: String,
    /// Matched pipeline (if any)
    pub matched_pipeline: Option<String>,
    /// Match type
    pub match_type: Option<MatchType>,
    /// Extracted parameters
    pub params: HashMap<String, serde_json::Value>,
    /// Confidence score
    pub confidence: f32,
    /// All candidates considered
    pub candidates: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: `compile_pattern` was imported but unused; removed to keep the
    // test module warning-free.
    use crate::trigger::{compile_trigger, Trigger};

    /// Build a router with one registered "course-generator" trigger:
    /// keywords 课程/教程, two patterns, and params `topic` (required, no
    /// default) and `level` (optional, defaults to "入门").
    fn create_test_router() -> IntentRouter {
        let mut parser = TriggerParser::new();
        let trigger = Trigger {
            keywords: vec!["课程".to_string(), "教程".to_string()],
            patterns: vec!["帮我做*课程".to_string(), "生成{level}级别的{topic}教程".to_string()],
            description: Some("根据用户主题生成完整的互动课程内容".to_string()),
            examples: vec!["帮我做一个 Python 入门课程".to_string()],
        };
        let compiled = compile_trigger(
            "course-generator".to_string(),
            Some("课程生成器".to_string()),
            &trigger,
            vec![
                TriggerParam {
                    name: "topic".to_string(),
                    param_type: "string".to_string(),
                    required: true,
                    label: Some("课程主题".to_string()),
                    default: None,
                },
                TriggerParam {
                    name: "level".to_string(),
                    param_type: "string".to_string(),
                    required: false,
                    label: Some("难度级别".to_string()),
                    default: Some(serde_json::Value::String("入门".to_string())),
                },
            ],
        ).unwrap();
        parser.register(compiled);
        IntentRouter::new(parser)
    }

    #[tokio::test]
    async fn test_route_keyword_match() {
        let router = create_test_router();
        let result = router.route("我想学习一个课程").await;
        match result {
            RouteResult::Matched { pipeline_id, confidence, .. } => {
                assert_eq!(pipeline_id, "course-generator");
                assert!(confidence >= 0.7);
            }
            _ => panic!("Expected Matched result"),
        }
    }

    #[tokio::test]
    async fn test_route_pattern_match() {
        let router = create_test_router();
        let result = router.route("帮我做一个Python课程").await;
        match result {
            RouteResult::Matched { pipeline_id, missing_params, .. } => {
                assert_eq!(pipeline_id, "course-generator");
                // Fixed: the previous assertion
                // `!missing_params.is_empty() || missing_params.is_empty()`
                // was a tautology. Whatever the pattern extracted, every
                // reported missing parameter must be required and have no
                // default — that is the contract of `find_missing_params`.
                for param in &missing_params {
                    assert!(param.required);
                    assert!(param.default.is_none());
                }
            }
            _ => panic!("Expected Matched result"),
        }
    }

    #[tokio::test]
    async fn test_route_no_match() {
        let router = create_test_router();
        let result = router.route("今天天气怎么样").await;
        match result {
            RouteResult::NoMatch { suggestions } => {
                // Fixed: the previous assertion was a tautology. Exactly one
                // trigger is registered and suggestion_count defaults to 3,
                // so exactly one suggestion must come back.
                assert_eq!(suggestions.len(), 1);
                assert_eq!(suggestions[0].id, "course-generator");
            }
            _ => panic!("Expected NoMatch result"),
        }
    }

    /// A single required parameter should stay in conversation mode.
    #[test]
    fn test_decide_mode_conversation() {
        let router = create_test_router();
        let params = vec![
            TriggerParam {
                name: "topic".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: None,
                default: None,
            },
        ];
        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Conversation);
    }

    /// More than three required parameters should switch to form mode.
    #[test]
    fn test_decide_mode_form() {
        let router = create_test_router();
        let params = vec![
            TriggerParam {
                name: "p1".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: None,
                default: None,
            },
            TriggerParam {
                name: "p2".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: None,
                default: None,
            },
            TriggerParam {
                name: "p3".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: None,
                default: None,
            },
            TriggerParam {
                name: "p4".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: None,
                default: None,
            },
        ];
        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Form);
    }
}

View File

@@ -6,51 +6,76 @@
//! # Architecture
//!
//! ```text
//! Pipeline YAML → Parser → Pipeline struct → Executor → Output
//! ↓
//! ExecutionContext (state)
//! User Input → Intent Router → Pipeline v2 → Executor → Presentation
//!
//! Trigger Matching ExecutionContext
//! ```
//!
//! # Example
//!
//! ```yaml
//! apiVersion: zclaw/v1
//! apiVersion: zclaw/v2
//! kind: Pipeline
//! metadata:
//! name: classroom-generator
//! displayName: 互动课堂生成器
//! name: course-generator
//! displayName: 课程生成器
//! category: education
//! spec:
//! inputs:
//! - name: topic
//! type: string
//! required: true
//! steps:
//! - id: parse
//! action: llm.generate
//! template: skills/classroom/parse.md
//! output: parsed
//! - id: render
//! action: classroom.render
//! input: ${steps.parse.output}
//! output: result
//! outputs:
//! classroom_id: ${steps.render.output.id}
//! trigger:
//! keywords: [课程, 教程, 学习]
//! patterns:
//! - "帮我做*课程"
//! - "生成{level}级别的{topic}教程"
//! params:
//! - name: topic
//! type: string
//! required: true
//! label: 课程主题
//! stages:
//! - id: outline
//! type: llm
//! prompt: "为{params.topic}创建课程大纲"
//! - id: content
//! type: parallel
//! each: "${stages.outline.sections}"
//! stage:
//! type: llm
//! prompt: "为章节${item.title}生成内容"
//! output:
//! type: dynamic
//! supported_types: [slideshow, quiz, document]
//! ```
pub mod types;
pub mod types_v2;
pub mod parser;
pub mod parser_v2;
pub mod state;
pub mod executor;
pub mod actions;
pub mod trigger;
pub mod intent;
pub mod engine;
pub mod presentation;
pub use types::*;
pub use types_v2::*;
pub use parser::*;
pub use parser_v2::*;
pub use state::*;
pub use executor::*;
pub use trigger::*;
pub use intent::*;
pub use engine::*;
pub use presentation::*;
pub use actions::ActionRegistry;
pub use actions::{LlmActionDriver, SkillActionDriver, HandActionDriver, OrchestrationActionDriver};
/// Convenience function to parse pipeline YAML (v1).
///
/// Thin wrapper over [`parser::PipelineParser::parse`].
pub fn parse_pipeline_yaml(yaml: &str) -> Result<Pipeline, parser::ParseError> {
    parser::PipelineParser::parse(yaml)
}
/// Convenience function to parse pipeline v2 YAML.
///
/// Thin wrapper over [`parser_v2::PipelineParserV2::parse`].
pub fn parse_pipeline_v2_yaml(yaml: &str) -> Result<PipelineV2, parser_v2::ParseErrorV2> {
    parser_v2::PipelineParserV2::parse(yaml)
}

View File

@@ -0,0 +1,442 @@
//! Pipeline v2 Parser
//!
//! Parses YAML pipeline definitions into PipelineV2 structs.
//!
//! # Example
//!
//! ```yaml
//! apiVersion: zclaw/v2
//! kind: Pipeline
//! metadata:
//! name: course-generator
//! displayName: 课程生成器
//! trigger:
//! keywords: [课程, 教程]
//! patterns:
//! - "帮我做*课程"
//! params:
//! - name: topic
//! type: string
//! required: true
//! stages:
//! - id: outline
//! type: llm
//! prompt: "为{params.topic}创建课程大纲"
//! ```
use std::collections::HashSet;
use std::path::Path;
use thiserror::Error;
use crate::types_v2::{PipelineV2, API_VERSION_V2, Stage};
/// Errors produced while parsing and validating a v2 pipeline document.
#[derive(Debug, Error)]
pub enum ParseErrorV2 {
    /// Underlying filesystem failure (reading a file or directory).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The document is not well-formed YAML or does not match the schema.
    #[error("YAML parse error: {0}")]
    Yaml(#[from] serde_yaml::Error),
    /// `apiVersion` is present but is not the v2 marker.
    #[error("Invalid API version: expected '{expected}', got '{actual}'")]
    InvalidVersion { expected: String, actual: String },
    /// `kind` is something other than `Pipeline`.
    #[error("Invalid kind: expected 'Pipeline', got '{0}'")]
    InvalidKind(String),
    /// A required field is empty or absent.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// Structural validation failed (no stages, duplicate IDs/params, ...).
    #[error("Validation error: {0}")]
    Validation(String),
}
/// Parser for `zclaw/v2` pipeline documents.
pub struct PipelineParserV2;

impl PipelineParserV2 {
    /// Parse a pipeline from a YAML string.
    ///
    /// Validation runs in a fixed order so callers get the most
    /// fundamental error first: API version, kind, required metadata,
    /// stage presence, stage-ID uniqueness, parameter-name uniqueness.
    pub fn parse(yaml: &str) -> Result<PipelineV2, ParseErrorV2> {
        let pipeline: PipelineV2 = serde_yaml::from_str(yaml)?;

        if pipeline.api_version != API_VERSION_V2 {
            return Err(ParseErrorV2::InvalidVersion {
                expected: API_VERSION_V2.to_string(),
                actual: pipeline.api_version.clone(),
            });
        }
        if pipeline.kind != "Pipeline" {
            return Err(ParseErrorV2::InvalidKind(pipeline.kind.clone()));
        }
        if pipeline.metadata.name.is_empty() {
            return Err(ParseErrorV2::MissingField("metadata.name".to_string()));
        }
        if pipeline.stages.is_empty() {
            return Err(ParseErrorV2::Validation(
                "Pipeline must have at least one stage".to_string(),
            ));
        }

        // Stage IDs must be unique across the entire (nested) stage tree.
        let mut stage_ids = HashSet::new();
        validate_stage_ids(&pipeline.stages, &mut stage_ids)?;

        // Parameter names must be unique as well.
        let mut param_names = HashSet::new();
        if let Some(dup) = pipeline
            .params
            .iter()
            .find(|p| !param_names.insert(p.name.as_str()))
        {
            return Err(ParseErrorV2::Validation(format!(
                "Duplicate parameter name: {}",
                dup.name
            )));
        }

        Ok(pipeline)
    }

    /// Parse a pipeline definition from a file on disk.
    pub fn parse_file(path: &Path) -> Result<PipelineV2, ParseErrorV2> {
        Self::parse(&std::fs::read_to_string(path)?)
    }

    /// Parse every `.yaml`/`.yml` file in `dir`.
    ///
    /// Files that fail to parse are logged and skipped; a missing
    /// directory yields an empty list. Returns `(file_stem, pipeline)`
    /// pairs.
    pub fn parse_directory(dir: &Path) -> Result<Vec<(String, PipelineV2)>, ParseErrorV2> {
        let mut found = Vec::new();
        if !dir.exists() {
            return Ok(found);
        }
        for entry in std::fs::read_dir(dir)? {
            let path = entry?.path();
            let is_yaml = path
                .extension()
                .map(|e| e == "yaml" || e == "yml")
                .unwrap_or(false);
            if !is_yaml {
                continue;
            }
            match Self::parse_file(&path) {
                Ok(pipeline) => {
                    let stem = path
                        .file_stem()
                        .map(|s| s.to_string_lossy().to_string())
                        .unwrap_or_default();
                    found.push((stem, pipeline));
                }
                Err(e) => tracing::warn!("Failed to parse pipeline {:?}: {}", path, e),
            }
        }
        Ok(found)
    }

    /// Cheap front-door check: returns `None` when the document does not
    /// even claim to be v2, otherwise the full parse result.
    pub fn try_parse(yaml: &str) -> Option<Result<PipelineV2, ParseErrorV2>> {
        let looks_v2 = yaml.contains("apiVersion: zclaw/v2")
            || yaml.contains("apiVersion: 'zclaw/v2'");
        looks_v2.then(|| Self::parse(yaml))
    }
}
/// Walks the stage tree depth-first, recording every stage ID in `seen`
/// and failing on the first duplicate encountered.
fn validate_stage_ids(stages: &[Stage], seen: &mut HashSet<String>) -> Result<(), ParseErrorV2> {
    for stage in stages {
        let id = stage.id().to_string();
        if !seen.insert(id.clone()) {
            return Err(ParseErrorV2::Validation(format!("Duplicate stage ID: {}", id)));
        }
        // Composite stages carry nested stages that must be checked too.
        match stage {
            Stage::Parallel { stage: inner, .. } => {
                validate_stage_ids(std::slice::from_ref(inner), seen)?;
            }
            Stage::Sequential { stages: nested, .. } => {
                validate_stage_ids(nested, seen)?;
            }
            Stage::Conditional { branches, default, .. } => {
                for branch in branches {
                    validate_stage_ids(std::slice::from_ref(&branch.then), seen)?;
                }
                if let Some(fallback) = default {
                    validate_stage_ids(std::slice::from_ref(fallback), seen)?;
                }
            }
            _ => {}
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal but complete v2 document should round-trip all metadata.
    #[test]
    fn test_parse_valid_pipeline_v2() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test-pipeline
  displayName: 测试流水线
trigger:
  keywords: [测试, pipeline]
  patterns:
    - "测试*流水线"
params:
  - name: topic
    type: string
    required: true
    label: 主题
stages:
  - id: step1
    type: llm
    prompt: "test"
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert_eq!(pipeline.metadata.name, "test-pipeline");
        assert_eq!(pipeline.metadata.display_name, Some("测试流水线".to_string()));
        assert_eq!(pipeline.stages.len(), 1);
    }

    /// A v1 apiVersion must be rejected with `InvalidVersion`.
    #[test]
    fn test_parse_invalid_version() {
        let yaml = r#"
apiVersion: zclaw/v1
kind: Pipeline
metadata:
  name: test
stages:
  - id: step1
    type: llm
    prompt: "test"
"#;
        let result = PipelineParserV2::parse(yaml);
        assert!(matches!(result, Err(ParseErrorV2::InvalidVersion { .. })));
    }

    /// A kind other than `Pipeline` must be rejected with `InvalidKind`.
    #[test]
    fn test_parse_invalid_kind() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: NotPipeline
metadata:
  name: test
stages:
  - id: step1
    type: llm
    prompt: "test"
"#;
        let result = PipelineParserV2::parse(yaml);
        assert!(matches!(result, Err(ParseErrorV2::InvalidKind(_))));
    }

    /// An empty `stages` list fails structural validation.
    #[test]
    fn test_parse_empty_stages() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages: []
"#;
        let result = PipelineParserV2::parse(yaml);
        assert!(matches!(result, Err(ParseErrorV2::Validation(_))));
    }

    /// Two stages sharing an ID fail structural validation.
    #[test]
    fn test_parse_duplicate_stage_ids() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: step1
    type: llm
    prompt: "test"
  - id: step1
    type: llm
    prompt: "test2"
"#;
        let result = PipelineParserV2::parse(yaml);
        assert!(matches!(result, Err(ParseErrorV2::Validation(_))));
    }

    /// A `parallel` stage with a nested inner stage parses cleanly.
    #[test]
    fn test_parse_parallel_stage() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: parallel1
    type: parallel
    each: "${params.items}"
    stage:
      id: inner
      type: llm
      prompt: "process ${item}"
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert_eq!(pipeline.metadata.name, "test");
        assert_eq!(pipeline.stages.len(), 1);
    }

    /// A `conditional` stage with a branch and a default parses cleanly.
    #[test]
    fn test_parse_conditional_stage() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: cond1
    type: conditional
    condition: "${params.level} == 'advanced'"
    branches:
      - when: "${params.level} == 'advanced'"
        then:
          id: advanced
          type: llm
          prompt: "advanced content"
    default:
      id: basic
      type: llm
      prompt: "basic content"
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert_eq!(pipeline.metadata.name, "test");
    }

    /// A `sequential` stage with nested sub-stages parses cleanly.
    #[test]
    fn test_parse_sequential_stage() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: seq1
    type: sequential
    stages:
      - id: sub1
        type: llm
        prompt: "step 1"
      - id: sub2
        type: llm
        prompt: "step 2"
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert_eq!(pipeline.metadata.name, "test");
    }

    /// Smoke test: every simple stage type deserializes from YAML.
    #[test]
    fn test_parse_all_stage_types() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test-all-types
stages:
  - id: llm1
    type: llm
    prompt: "llm prompt"
    model: "gpt-4"
    temperature: 0.7
    max_tokens: 1000
  - id: compose1
    type: compose
    template: '{"result": "${stages.llm1}"}'
  - id: skill1
    type: skill
    skill_id: "research-skill"
    input:
      query: "${params.topic}"
  - id: hand1
    type: hand
    hand_id: "browser"
    action: "navigate"
    params:
      url: "https://example.com"
  - id: http1
    type: http
    url: "https://api.example.com/data"
    method: "POST"
    headers:
      Content-Type: "application/json"
    body: '{"query": "${params.query}"}'
  - id: setvar1
    type: set_var
    name: "customVar"
    value: "${stages.http1.result}"
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert_eq!(pipeline.metadata.name, "test-all-types");
        assert_eq!(pipeline.stages.len(), 6);
    }

    /// `try_parse` only engages on documents carrying the v2 marker.
    #[test]
    fn test_try_parse_v2() {
        // v2 format - should return Some
        let yaml_v2 = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: s1
    type: llm
    prompt: "test"
"#;
        assert!(PipelineParserV2::try_parse(yaml_v2).is_some());
        // v1 format - should return None
        let yaml_v1 = r#"
apiVersion: zclaw/v1
kind: Pipeline
metadata:
  name: test
spec:
  steps: []
"#;
        assert!(PipelineParserV2::try_parse(yaml_v1).is_none());
    }

    /// The optional `output` section deserializes into OutputConfig.
    #[test]
    fn test_parse_output_config() {
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: test
stages:
  - id: s1
    type: llm
    prompt: "test"
output:
  type: dynamic
  allowSwitch: true
  supportedTypes: [slideshow, quiz, document]
  defaultType: slideshow
"#;
        let pipeline = PipelineParserV2::parse(yaml).unwrap();
        assert!(pipeline.output.allow_switch);
        assert_eq!(pipeline.output.supported_types.len(), 3);
    }
}

View File

@@ -0,0 +1,568 @@
//! Presentation Analyzer
//!
//! Analyzes pipeline output data and recommends the best presentation type.
//!
//! # Strategy
//!
//! 1. **Structure Detection** (Fast Path, < 5ms):
//! - Check for known data patterns (slides, questions, chart data)
//! - Use simple heuristics for common cases
//!
//! 2. **LLM Analysis** (Optional, ~300ms):
//! - Semantic understanding of data content
//! - Better recommendations for ambiguous cases
use serde_json::Value;
use std::collections::HashMap;
use super::types::*;
/// Presentation analyzer.
///
/// Holds an ordered set of [`DetectionRule`]s that score raw pipeline
/// output (a JSON value) against each supported presentation type.
pub struct PresentationAnalyzer {
    /// Detection rules
    rules: Vec<DetectionRule>,
}
/// Detection rule for a presentation type
struct DetectionRule {
    /// Target presentation type
    type_: PresentationType,
    /// Detection function (pure: only inspects the JSON value)
    detector: fn(&Value) -> DetectionResult,
    /// Priority (higher = checked first; also breaks confidence ties,
    /// because the sorts in `analyze` are stable)
    priority: u32,
}
/// Result of a detection rule
struct DetectionResult {
    /// Confidence score (0.0 - 1.0)
    confidence: f32,
    /// Reason for detection (human-readable, Chinese UI copy)
    reason: String,
    /// Detected sub-type (e.g., "bar" for Chart, "choice" for Quiz)
    sub_type: Option<String>,
}
impl PresentationAnalyzer {
/// Create a new analyzer with default rules
pub fn new() -> Self {
let rules = vec![
// Quiz detection (high priority)
DetectionRule {
type_: PresentationType::Quiz,
detector: detect_quiz,
priority: 100,
},
// Chart detection
DetectionRule {
type_: PresentationType::Chart,
detector: detect_chart,
priority: 90,
},
// Slideshow detection
DetectionRule {
type_: PresentationType::Slideshow,
detector: detect_slideshow,
priority: 80,
},
// Whiteboard detection
DetectionRule {
type_: PresentationType::Whiteboard,
detector: detect_whiteboard,
priority: 70,
},
// Document detection (fallback, lowest priority)
DetectionRule {
type_: PresentationType::Document,
detector: detect_document,
priority: 10,
},
];
Self { rules }
}
/// Analyze data and recommend presentation type
pub fn analyze(&self, data: &Value) -> PresentationAnalysis {
// Sort rules by priority (descending)
let mut sorted_rules: Vec<_> = self.rules.iter().collect();
sorted_rules.sort_by(|a, b| b.priority.cmp(&a.priority));
let mut results: Vec<(PresentationType, DetectionResult)> = Vec::new();
// Apply each detection rule
for rule in sorted_rules {
let result = (rule.detector)(data);
if result.confidence > 0.0 {
results.push((rule.type_, result));
}
}
// Sort by confidence
results.sort_by(|a, b| {
b.1.confidence.partial_cmp(&a.1.confidence).unwrap_or(std::cmp::Ordering::Equal)
});
if results.is_empty() {
// Fallback to document
return PresentationAnalysis {
recommended_type: PresentationType::Document,
confidence: 0.5,
reason: "无法识别数据结构,使用默认文档展示".to_string(),
alternatives: vec![],
structure_hints: vec!["未检测到特定结构".to_string()],
sub_type: None,
};
}
// Build analysis result
let (primary_type, primary_result) = &results[0];
let alternatives: Vec<AlternativeType> = results[1..]
.iter()
.filter(|(_, r)| r.confidence > 0.3)
.map(|(t, r)| AlternativeType {
type_: *t,
confidence: r.confidence,
reason: r.reason.clone(),
})
.collect();
// Collect structure hints
let structure_hints = collect_structure_hints(data);
PresentationAnalysis {
recommended_type: *primary_type,
confidence: primary_result.confidence,
reason: primary_result.reason.clone(),
alternatives,
structure_hints,
sub_type: primary_result.sub_type.clone(),
}
}
/// Quick check if data matches a specific type
pub fn can_render_as(&self, data: &Value, type_: PresentationType) -> bool {
for rule in &self.rules {
if rule.type_ == type_ {
let result = (rule.detector)(data);
return result.confidence > 0.5;
}
}
false
}
}
impl Default for PresentationAnalyzer {
    /// Same as [`PresentationAnalyzer::new`]: the built-in rule set.
    fn default() -> Self {
        Self::new()
    }
}
// === Detection Functions ===
/// Detect if data is a quiz
fn detect_quiz(data: &Value) -> DetectionResult {
let obj = match data.as_object() {
Some(o) => o,
None => return DetectionResult {
confidence: 0.0,
reason: String::new(),
sub_type: None,
},
};
// Check for quiz structure
if let Some(questions) = obj.get("questions").and_then(|q| q.as_array()) {
if !questions.is_empty() {
// Check if questions have options (choice questions)
let has_options = questions.iter().any(|q| {
q.get("options").and_then(|o| o.as_array()).map(|o| !o.is_empty()).unwrap_or(false)
});
if has_options {
return DetectionResult {
confidence: 0.95,
reason: "检测到问题数组,且包含选项".to_string(),
sub_type: Some("choice".to_string()),
};
}
return DetectionResult {
confidence: 0.85,
reason: "检测到问题数组".to_string(),
sub_type: None,
};
}
}
// Check for quiz field
if let Some(quiz) = obj.get("quiz") {
if quiz.get("questions").is_some() {
return DetectionResult {
confidence: 0.95,
reason: "包含 quiz 字段和 questions".to_string(),
sub_type: None,
};
}
}
// Check for common quiz field patterns
let quiz_fields = ["questions", "answers", "score", "quiz", "exam"];
let matches: Vec<_> = quiz_fields.iter()
.filter(|f| obj.contains_key(*f as &str))
.collect();
if matches.len() >= 2 {
return DetectionResult {
confidence: 0.6,
reason: format!("包含测验相关字段: {:?}", matches),
sub_type: None,
};
}
DetectionResult {
confidence: 0.0,
reason: String::new(),
sub_type: None,
}
}
/// Detect if data is a chart.
///
/// Heuristics run in order of decreasing confidence; the first hit wins
/// via early return:
/// 1. explicit `chart`/`chartType` field (0.95)
/// 2. `xAxis`/`yAxis` keys (0.9)
/// 3. non-empty `labels` + `series` arrays (0.9)
/// 4. a `data` array that is mostly numeric (0.7)
/// 5. two or more `data*`/`*_data` keys (0.6)
fn detect_chart(data: &Value) -> DetectionResult {
    // Charts only make sense for JSON objects.
    let obj = match data.as_object() {
        Some(o) => o,
        None => return DetectionResult {
            confidence: 0.0,
            reason: String::new(),
            sub_type: None,
        },
    };
    // Check for explicit chart field
    if obj.contains_key("chart") || obj.contains_key("chartType") {
        // Sub-type defaults to "bar" when chartType is absent or non-string.
        let chart_type = obj.get("chartType")
            .and_then(|v| v.as_str())
            .unwrap_or("bar");
        return DetectionResult {
            confidence: 0.95,
            reason: "包含 chart/chartType 字段".to_string(),
            sub_type: Some(chart_type.to_string()),
        };
    }
    // Check for x/y axis
    if obj.contains_key("xAxis") || obj.contains_key("yAxis") {
        return DetectionResult {
            confidence: 0.9,
            reason: "包含坐标轴定义".to_string(),
            sub_type: Some("line".to_string()),
        };
    }
    // Check for labels + series pattern
    if let Some(labels) = obj.get("labels").and_then(|l| l.as_array()) {
        if let Some(series) = obj.get("series").and_then(|s| s.as_array()) {
            if !labels.is_empty() && !series.is_empty() {
                // Determine chart type
                // NOTE(review): ">3 series => line chart" is a heuristic
                // threshold — confirm it matches frontend expectations.
                let chart_type = if series.len() > 3 {
                    "line"
                } else {
                    "bar"
                };
                return DetectionResult {
                    confidence: 0.9,
                    reason: format!("包含 labels({}) 和 series({})", labels.len(), series.len()),
                    sub_type: Some(chart_type.to_string()),
                };
            }
        }
    }
    // Check for data array with numeric values
    if let Some(data_arr) = obj.get("data").and_then(|d| d.as_array()) {
        let numeric_count = data_arr.iter()
            .filter(|v| v.is_number())
            .count();
        // Strictly more than half numeric (an empty array never matches).
        if numeric_count > data_arr.len() / 2 {
            return DetectionResult {
                confidence: 0.7,
                reason: format!("data 数组包含 {} 个数值", numeric_count),
                sub_type: Some("bar".to_string()),
            };
        }
    }
    // Check for multiple data series
    let data_keys: Vec<_> = obj.keys()
        .filter(|k| k.starts_with("data") || k.ends_with("_data"))
        .collect();
    if data_keys.len() >= 2 {
        return DetectionResult {
            confidence: 0.6,
            reason: format!("包含多个数据系列: {:?}", data_keys),
            sub_type: Some("line".to_string()),
        };
    }
    // No chart signal at all.
    DetectionResult {
        confidence: 0.0,
        reason: String::new(),
        sub_type: None,
    }
}
/// Detect if data is a slideshow.
///
/// Ordered checks: explicit non-empty `slides` (0.95), a `sections` array
/// whose entries all carry `title` + `content` (0.85), a `scenes` array
/// (0.85, classroom style), then >= 2 presentation-flavored keys (0.7).
fn detect_slideshow(data: &Value) -> DetectionResult {
    let obj = match data.as_object() {
        Some(o) => o,
        None => return DetectionResult {
            confidence: 0.0,
            reason: String::new(),
            sub_type: None,
        },
    };
    // Check for slides array
    if let Some(slides) = obj.get("slides").and_then(|s| s.as_array()) {
        if !slides.is_empty() {
            return DetectionResult {
                confidence: 0.95,
                reason: format!("包含 {} 张幻灯片", slides.len()),
                sub_type: None,
            };
        }
    }
    // Check for sections array with title/content structure
    if let Some(sections) = obj.get("sections").and_then(|s| s.as_array()) {
        // `.all()` is vacuously true on an empty array, hence the extra
        // `!sections.is_empty()` guard below.
        let has_slides_structure = sections.iter().all(|s| {
            s.get("title").is_some() && s.get("content").is_some()
        });
        if has_slides_structure && !sections.is_empty() {
            return DetectionResult {
                confidence: 0.85,
                reason: format!("sections 数组包含 {} 个幻灯片结构", sections.len()),
                sub_type: None,
            };
        }
    }
    // Check for scenes array (classroom style)
    if let Some(scenes) = obj.get("scenes").and_then(|s| s.as_array()) {
        if !scenes.is_empty() {
            return DetectionResult {
                confidence: 0.85,
                reason: format!("包含 {} 个场景", scenes.len()),
                sub_type: Some("classroom".to_string()),
            };
        }
    }
    // Check for presentation-like fields
    let pres_fields = ["slides", "sections", "scenes", "outline", "chapters"];
    let matches: Vec<_> = pres_fields.iter()
        .filter(|f| obj.contains_key(*f as &str))
        .collect();
    if matches.len() >= 2 {
        return DetectionResult {
            confidence: 0.7,
            reason: format!("包含演示文稿字段: {:?}", matches),
            sub_type: None,
        };
    }
    // No slideshow signal.
    DetectionResult {
        confidence: 0.0,
        reason: String::new(),
        sub_type: None,
    }
}
/// Whiteboard detector: canvas-style structures (`canvas`/`elements`)
/// or raw drawing data (`strokes`).
fn detect_whiteboard(data: &Value) -> DetectionResult {
    if let Some(obj) = data.as_object() {
        if obj.contains_key("canvas") || obj.contains_key("elements") {
            return DetectionResult {
                confidence: 0.9,
                reason: "包含 canvas/elements 字段".to_string(),
                sub_type: None,
            };
        }
        if obj.contains_key("strokes") {
            return DetectionResult {
                confidence: 0.95,
                reason: "包含 strokes 绘图数据".to_string(),
                sub_type: None,
            };
        }
    }
    // Non-object data or no whiteboard signal.
    DetectionResult {
        confidence: 0.0,
        reason: String::new(),
        sub_type: None,
    }
}
/// Document detector. Always yields at least fallback confidence (0.5),
/// since any payload can be displayed as a document.
fn detect_document(data: &Value) -> DetectionResult {
    let obj = match data.as_object() {
        Some(o) => o,
        None => {
            return DetectionResult {
                confidence: 0.5,
                reason: "非对象数据,使用文档展示".to_string(),
                sub_type: None,
            }
        }
    };
    if obj.contains_key("markdown") || obj.contains_key("content") {
        // Explicit textual content: strongest document signal.
        DetectionResult {
            confidence: 0.8,
            reason: "包含 markdown/content 字段".to_string(),
            sub_type: Some("markdown".to_string()),
        }
    } else if obj.contains_key("summary") || obj.contains_key("report") {
        // Summary/report structures are document-like.
        DetectionResult {
            confidence: 0.7,
            reason: "包含 summary/report 字段".to_string(),
            sub_type: None,
        }
    } else {
        // Generic object: still renderable as a document.
        DetectionResult {
            confidence: 0.5,
            reason: "默认文档展示".to_string(),
            sub_type: None,
        }
    }
}
/// Produce human-readable hints about the shape of `data`: the length of
/// every top-level array field, followed by the presence of a few
/// well-known keys.
fn collect_structure_hints(data: &Value) -> Vec<String> {
    let Some(obj) = data.as_object() else {
        return Vec::new();
    };
    // One "key: len" hint per array-valued field, in map order.
    let mut hints: Vec<String> = obj
        .iter()
        .filter_map(|(key, value)| {
            value.as_array().map(|arr| format!("{}: {}", key, arr.len()))
        })
        .collect();
    // Presence hints for common metadata-ish fields.
    for (field, hint) in [
        ("title", "包含标题"),
        ("description", "包含描述"),
        ("metadata", "包含元数据"),
    ] {
        if obj.contains_key(field) {
            hints.push(hint.to_string());
        }
    }
    hints
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A questions-with-options payload is recognized as a quiz.
    #[test]
    fn test_analyze_quiz() {
        let analyzer = PresentationAnalyzer::new();
        let data = json!({
            "title": "Python 测验",
            "questions": [
                {
                    "id": "q1",
                    "text": "Python 是什么?",
                    "options": [
                        {"id": "a", "text": "编译型语言"},
                        {"id": "b", "text": "解释型语言"}
                    ]
                }
            ]
        });
        let result = analyzer.analyze(&data);
        assert_eq!(result.recommended_type, PresentationType::Quiz);
        assert!(result.confidence > 0.8);
    }

    /// An explicit chartType plus labels/series yields a bar chart.
    #[test]
    fn test_analyze_chart() {
        let analyzer = PresentationAnalyzer::new();
        let data = json!({
            "chartType": "bar",
            "title": "销售数据",
            "labels": ["一月", "二月", "三月"],
            "series": [
                {"name": "销售额", "data": [100, 150, 200]}
            ]
        });
        let result = analyzer.analyze(&data);
        assert_eq!(result.recommended_type, PresentationType::Chart);
        assert_eq!(result.sub_type, Some("bar".to_string()));
    }

    /// A slides array is recognized as a slideshow.
    #[test]
    fn test_analyze_slideshow() {
        let analyzer = PresentationAnalyzer::new();
        let data = json!({
            "title": "课程大纲",
            "slides": [
                {"title": "第一章", "content": "..."},
                {"title": "第二章", "content": "..."}
            ]
        });
        let result = analyzer.analyze(&data);
        assert_eq!(result.recommended_type, PresentationType::Slideshow);
    }

    /// Plain title/content data falls back to the document renderer.
    #[test]
    fn test_analyze_document_fallback() {
        let analyzer = PresentationAnalyzer::new();
        let data = json!({
            "title": "报告",
            "content": "这是一段文本内容..."
        });
        let result = analyzer.analyze(&data);
        assert_eq!(result.recommended_type, PresentationType::Document);
    }

    /// `can_render_as` answers per-type renderability questions.
    #[test]
    fn test_can_render_as() {
        let analyzer = PresentationAnalyzer::new();
        let quiz_data = json!({
            "questions": [{"id": "q1", "text": "问题"}]
        });
        assert!(analyzer.can_render_as(&quiz_data, PresentationType::Quiz));
        assert!(!analyzer.can_render_as(&quiz_data, PresentationType::Chart));
    }
}

View File

@@ -0,0 +1,28 @@
//! Smart Presentation Layer
//!
//! Analyzes pipeline output and recommends the best presentation format.
//! Supports multiple renderers: Chart, Quiz, Slideshow, Document, Whiteboard.
//!
//! # Flow
//!
//! ```text
//! Pipeline Output
//! ↓
//! Structure Detection (fast, < 5ms)
//! ├─→ Has slides/sections? → Slideshow
//! ├─→ Has questions/options? → Quiz
//! ├─→ Has chart/data arrays? → Chart
//! └─→ Default → Document
//! ↓
//! LLM Analysis (optional, ~300ms)
//! ↓
//! Recommendation with confidence score
//! ```
pub mod types;
pub mod analyzer;
pub mod registry;
pub use types::*;
pub use analyzer::*;
pub use registry::*;

View File

@@ -0,0 +1,290 @@
//! Presentation Registry
//!
//! Manages available renderers and provides lookup functionality.
use std::collections::HashMap;
use super::types::PresentationType;
/// Renderer information.
///
/// Static metadata describing one presentation renderer; name, icon, and
/// description are user-facing UI copy.
#[derive(Debug, Clone)]
pub struct RendererInfo {
    /// Renderer type
    pub type_: PresentationType,
    /// Display name (Chinese UI copy)
    pub name: String,
    /// Icon (emoji)
    pub icon: String,
    /// Description
    pub description: String,
    /// Supported export formats
    pub export_formats: Vec<ExportFormat>,
    /// Is this renderer available?
    pub available: bool,
}
/// Export format supported by a renderer
#[derive(Debug, Clone)]
pub struct ExportFormat {
    /// Format ID (e.g. "png", "pdf")
    pub id: String,
    /// Display name
    pub name: String,
    /// File extension (without the leading dot)
    pub extension: String,
    /// MIME type
    pub mime_type: String,
}
/// Presentation renderer registry.
///
/// Central catalogue of renderers the presentation layer can target,
/// keyed by [`PresentationType`]. Populated with five built-ins on
/// construction; more can be added via [`PresentationRegistry::register`].
pub struct PresentationRegistry {
    /// Registered renderers
    renderers: HashMap<PresentationType, RendererInfo>,
}
impl PresentationRegistry {
    /// Create a new registry with default renderers
    pub fn new() -> Self {
        let mut registry = Self {
            renderers: HashMap::new(),
        };
        // Register default renderers
        registry.register_defaults();
        registry
    }
    /// Register default renderers (all marked `available: true`).
    fn register_defaults(&mut self) {
        // Chart renderer
        self.register(RendererInfo {
            type_: PresentationType::Chart,
            name: "图表".to_string(),
            icon: "📈".to_string(),
            description: "数据可视化图表,支持折线图、柱状图、饼图等".to_string(),
            export_formats: vec![
                ExportFormat {
                    id: "png".to_string(),
                    name: "PNG 图片".to_string(),
                    extension: "png".to_string(),
                    mime_type: "image/png".to_string(),
                },
                ExportFormat {
                    id: "svg".to_string(),
                    name: "SVG 矢量图".to_string(),
                    extension: "svg".to_string(),
                    mime_type: "image/svg+xml".to_string(),
                },
                ExportFormat {
                    id: "json".to_string(),
                    name: "JSON 数据".to_string(),
                    extension: "json".to_string(),
                    mime_type: "application/json".to_string(),
                },
            ],
            available: true,
        });
        // Quiz renderer
        self.register(RendererInfo {
            type_: PresentationType::Quiz,
            name: "测验".to_string(),
            // NOTE(review): this icon is an empty string while every other
            // renderer has an emoji — looks like a glyph was lost during
            // encoding; confirm the intended icon (see also
            // PresentationType::icon()).
            icon: "".to_string(),
            description: "互动测验,支持选择题、判断题、填空题等".to_string(),
            export_formats: vec![
                ExportFormat {
                    id: "json".to_string(),
                    name: "JSON 数据".to_string(),
                    extension: "json".to_string(),
                    mime_type: "application/json".to_string(),
                },
                ExportFormat {
                    id: "pdf".to_string(),
                    name: "PDF 文档".to_string(),
                    extension: "pdf".to_string(),
                    mime_type: "application/pdf".to_string(),
                },
                ExportFormat {
                    id: "html".to_string(),
                    name: "HTML 页面".to_string(),
                    extension: "html".to_string(),
                    mime_type: "text/html".to_string(),
                },
            ],
            available: true,
        });
        // Slideshow renderer
        self.register(RendererInfo {
            type_: PresentationType::Slideshow,
            name: "幻灯片".to_string(),
            icon: "📊".to_string(),
            description: "演示幻灯片,支持多种布局和动画效果".to_string(),
            export_formats: vec![
                ExportFormat {
                    id: "pptx".to_string(),
                    name: "PowerPoint".to_string(),
                    extension: "pptx".to_string(),
                    mime_type: "application/vnd.openxmlformats-officedocument.presentationml.presentation".to_string(),
                },
                ExportFormat {
                    id: "pdf".to_string(),
                    name: "PDF 文档".to_string(),
                    extension: "pdf".to_string(),
                    mime_type: "application/pdf".to_string(),
                },
                ExportFormat {
                    id: "html".to_string(),
                    name: "HTML 页面".to_string(),
                    extension: "html".to_string(),
                    mime_type: "text/html".to_string(),
                },
            ],
            available: true,
        });
        // Document renderer
        self.register(RendererInfo {
            type_: PresentationType::Document,
            name: "文档".to_string(),
            icon: "📄".to_string(),
            description: "Markdown 文档渲染,支持代码高亮和数学公式".to_string(),
            export_formats: vec![
                ExportFormat {
                    id: "md".to_string(),
                    name: "Markdown".to_string(),
                    extension: "md".to_string(),
                    mime_type: "text/markdown".to_string(),
                },
                ExportFormat {
                    id: "pdf".to_string(),
                    name: "PDF 文档".to_string(),
                    extension: "pdf".to_string(),
                    mime_type: "application/pdf".to_string(),
                },
                ExportFormat {
                    id: "html".to_string(),
                    name: "HTML 页面".to_string(),
                    extension: "html".to_string(),
                    mime_type: "text/html".to_string(),
                },
            ],
            available: true,
        });
        // Whiteboard renderer
        self.register(RendererInfo {
            type_: PresentationType::Whiteboard,
            name: "白板".to_string(),
            icon: "🎨".to_string(),
            description: "交互式白板,支持绘图和标注".to_string(),
            export_formats: vec![
                ExportFormat {
                    id: "png".to_string(),
                    name: "PNG 图片".to_string(),
                    extension: "png".to_string(),
                    mime_type: "image/png".to_string(),
                },
                ExportFormat {
                    id: "svg".to_string(),
                    name: "SVG 矢量图".to_string(),
                    extension: "svg".to_string(),
                    mime_type: "image/svg+xml".to_string(),
                },
                ExportFormat {
                    id: "json".to_string(),
                    name: "JSON 数据".to_string(),
                    extension: "json".to_string(),
                    mime_type: "application/json".to_string(),
                },
            ],
            available: true,
        });
    }
    /// Register a renderer (replaces any existing entry of the same type).
    pub fn register(&mut self, info: RendererInfo) {
        self.renderers.insert(info.type_, info);
    }
    /// Get renderer info by type
    pub fn get(&self, type_: PresentationType) -> Option<&RendererInfo> {
        self.renderers.get(&type_)
    }
    /// Get all renderers currently flagged as available.
    pub fn all(&self) -> Vec<&RendererInfo> {
        self.renderers.values()
            .filter(|r| r.available)
            .collect()
    }
    /// Get export formats for a renderer type (empty when unregistered).
    pub fn get_export_formats(&self, type_: PresentationType) -> Vec<&ExportFormat> {
        self.renderers.get(&type_)
            .map(|r| r.export_formats.iter().collect())
            .unwrap_or_default()
    }
    /// Check if a renderer type is registered AND flagged available.
    pub fn is_available(&self, type_: PresentationType) -> bool {
        self.renderers.get(&type_)
            .map(|r| r.available)
            .unwrap_or(false)
    }
}
impl Default for PresentationRegistry {
    /// Same as [`PresentationRegistry::new`]: the five built-in renderers.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in presentation type is registered by default.
    #[test]
    fn test_registry_defaults() {
        let registry = PresentationRegistry::new();
        assert!(registry.get(PresentationType::Chart).is_some());
        assert!(registry.get(PresentationType::Quiz).is_some());
        assert!(registry.get(PresentationType::Slideshow).is_some());
        assert!(registry.get(PresentationType::Document).is_some());
        assert!(registry.get(PresentationType::Whiteboard).is_some());
    }

    /// Export format lookup returns the renderer's declared formats.
    #[test]
    fn test_get_export_formats() {
        let registry = PresentationRegistry::new();
        let formats = registry.get_export_formats(PresentationType::Chart);
        assert!(!formats.is_empty());
        // Chart should support PNG
        assert!(formats.iter().any(|f| f.id == "png"));
    }

    /// All five built-ins are available out of the box.
    #[test]
    fn test_all_available() {
        let registry = PresentationRegistry::new();
        let available = registry.all();
        assert_eq!(available.len(), 5);
    }

    /// Renderer metadata (name/icon) is preserved as registered.
    #[test]
    fn test_renderer_info() {
        let registry = PresentationRegistry::new();
        let chart = registry.get(PresentationType::Chart).unwrap();
        assert_eq!(chart.name, "图表");
        assert_eq!(chart.icon, "📈");
    }
}

View File

@@ -0,0 +1,575 @@
//! Presentation Types
//!
//! Defines presentation types, data structures, and interfaces
//! for the smart presentation layer.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Supported presentation types
///
/// Serialized in lowercase via serde (e.g. `"slideshow"`, `"quiz"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PresentationType {
    /// Slideshow presentation (reveal.js style)
    Slideshow,
    /// Interactive quiz with questions and answers
    Quiz,
    /// Data visualization charts
    Chart,
    /// Document/Markdown rendering
    Document,
    /// Interactive whiteboard/canvas
    Whiteboard,
    /// Default fallback
    #[default]
    Auto,
}
// Re-export as Quiz for consistency
impl PresentationType {
    /// Quiz type alias
    pub const QUIZ: Self = Self::Quiz;
}
impl PresentationType {
/// Get display name
pub fn display_name(&self) -> &'static str {
match self {
Self::Slideshow => "幻灯片",
Self::Quiz => "测验",
Self::Chart => "图表",
Self::Document => "文档",
Self::Whiteboard => "白板",
Self::Auto => "自动",
}
}
/// Get icon emoji
pub fn icon(&self) -> &'static str {
match self {
Self::Slideshow => "📊",
Self::Quiz => "",
Self::Chart => "📈",
Self::Document => "📄",
Self::Whiteboard => "🎨",
Self::Auto => "🔄",
}
}
/// Get all available types (excluding Auto)
pub fn all() -> &'static [PresentationType] {
&[
Self::Slideshow,
Self::Quiz,
Self::Chart,
Self::Document,
Self::Whiteboard,
]
}
}
/// Chart sub-types
///
/// Serialized in camelCase via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ChartType {
    /// Line chart
    Line,
    /// Bar chart
    Bar,
    /// Pie chart
    Pie,
    /// Scatter plot
    Scatter,
    /// Area chart
    Area,
    /// Radar chart
    Radar,
    /// Heatmap
    Heatmap,
}
/// Quiz question types
///
/// Serialized in camelCase via serde (e.g. `"singleChoice"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum QuestionType {
    /// Single choice
    SingleChoice,
    /// Multiple choice
    MultipleChoice,
    /// True/False
    TrueFalse,
    /// Fill in the blank
    FillBlank,
    /// Short answer
    ShortAnswer,
    /// Matching
    Matching,
    /// Ordering
    Ordering,
}
/// Presentation analysis result
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PresentationAnalysis {
    /// Recommended presentation type
    pub recommended_type: PresentationType,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// Reason for recommendation
    pub reason: String,
    /// Alternative types that could work
    pub alternatives: Vec<AlternativeType>,
    /// Detected data structure hints
    pub structure_hints: Vec<String>,
    /// Specific sub-type recommendation (e.g., "line" for Chart)
    pub sub_type: Option<String>,
}
/// Alternative presentation type with confidence
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AlternativeType {
    // Trailing underscore avoids clashing with the `type` keyword;
    // serialized as `type_` -> `type_` is camelCased by serde to `type`?
    // No — serde camelCases the literal field name, so the wire key is "type_".
    // TODO(review): confirm the frontend expects "type_" here.
    pub type_: PresentationType,
    pub confidence: f32,
    pub reason: String,
}
/// Chart data structure for ChartRenderer
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ChartData {
    /// Chart type
    pub chart_type: ChartType,
    /// Chart title
    pub title: Option<String>,
    /// X-axis labels
    pub labels: Vec<String>,
    /// Data series
    pub series: Vec<ChartSeries>,
    /// X-axis configuration
    pub x_axis: Option<AxisConfig>,
    /// Y-axis configuration
    pub y_axis: Option<AxisConfig>,
    /// Legend configuration
    pub legend: Option<LegendConfig>,
    /// Additional options
    #[serde(default)]
    pub options: HashMap<String, serde_json::Value>,
}
/// Chart series data
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ChartSeries {
    /// Series name
    pub name: String,
    /// Data values
    pub data: Vec<f64>,
    /// Series color
    pub color: Option<String>,
    /// Series type (for mixed charts)
    pub series_type: Option<ChartType>,
}
/// Axis configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AxisConfig {
    /// Axis label
    pub label: Option<String>,
    /// Min value
    pub min: Option<f64>,
    /// Max value
    pub max: Option<f64>,
    /// Show grid lines (defaults to true when absent)
    #[serde(default = "default_true")]
    pub show_grid: bool,
}
/// Legend configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LegendConfig {
    /// Show legend (defaults to true when absent)
    #[serde(default = "default_true")]
    pub show: bool,
    /// Legend position: top, bottom, left, right
    pub position: Option<String>,
}
/// Quiz data structure for QuizRenderer
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QuizData {
    /// Quiz title
    pub title: Option<String>,
    /// Quiz description
    pub description: Option<String>,
    /// Questions
    pub questions: Vec<QuizQuestion>,
    /// Time limit in seconds (optional)
    pub time_limit: Option<u32>,
    /// Show correct answers after submission (defaults to true)
    #[serde(default = "default_true")]
    pub show_answers: bool,
    /// Allow retry (defaults to true)
    #[serde(default = "default_true")]
    pub allow_retry: bool,
    /// Passing score percentage (0-100)
    pub passing_score: Option<u32>,
}
/// Quiz question
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QuizQuestion {
    /// Question ID
    pub id: String,
    /// Question text
    pub text: String,
    /// Question type (wire key is "type")
    #[serde(rename = "type")]
    pub question_type: QuestionType,
    /// Options for choice questions
    #[serde(default)]
    pub options: Vec<QuestionOption>,
    /// Correct answer(s)
    /// - Single choice: single index or value
    /// - Multiple choice: array of indices
    /// - Fill blank: the expected text
    pub correct_answer: serde_json::Value,
    /// Explanation shown after answering
    pub explanation: Option<String>,
    /// Points for this question (defaults to 1)
    #[serde(default = "default_points")]
    pub points: u32,
    /// Image URL (optional)
    pub image: Option<String>,
    /// Hint text
    pub hint: Option<String>,
}
/// Serde default for [`QuizQuestion::points`].
fn default_points() -> u32 {
    1
}
/// Question option for choice questions
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QuestionOption {
    /// Option ID (a, b, c, d or 0, 1, 2, 3)
    pub id: String,
    /// Option text
    pub text: String,
    /// Optional image
    pub image: Option<String>,
}
/// Slideshow data structure for SlideshowRenderer
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SlideshowData {
    /// Presentation title
    pub title: String,
    /// Presentation subtitle
    pub subtitle: Option<String>,
    /// Author
    pub author: Option<String>,
    /// Slides
    pub slides: Vec<Slide>,
    /// Theme
    pub theme: Option<SlideshowTheme>,
    /// Transition effect
    pub transition: Option<String>,
}
/// Single slide
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Slide {
    /// Slide ID
    pub id: String,
    /// Slide title
    pub title: Option<String>,
    /// Slide content
    pub content: SlideContent,
    /// Speaker notes
    pub notes: Option<String>,
    /// Background color or image
    pub background: Option<String>,
    /// Transition for this slide
    pub transition: Option<String>,
}
/// Slide content types
///
/// Internally tagged: the wire format carries a `"type"` discriminator in
/// snake_case (e.g. `"two_columns"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SlideContent {
    /// Title slide
    Title {
        heading: String,
        subheading: Option<String>,
    },
    /// Bullet points
    Bullets {
        items: Vec<String>,
    },
    /// Two columns
    TwoColumns {
        left: Vec<String>,
        right: Vec<String>,
    },
    /// Image with caption
    Image {
        url: String,
        caption: Option<String>,
        alt: Option<String>,
    },
    /// Code block
    Code {
        language: String,
        code: String,
        filename: Option<String>,
    },
    /// Quote
    Quote {
        text: String,
        author: Option<String>,
    },
    /// Table
    Table {
        headers: Vec<String>,
        rows: Vec<Vec<String>>,
    },
    /// Chart (embedded)
    Chart {
        chart_data: ChartData,
    },
    /// Quiz (embedded)
    Quiz {
        quiz_data: QuizData,
    },
    /// Custom HTML/Markdown
    Custom {
        html: Option<String>,
        markdown: Option<String>,
    },
}
/// Slideshow theme
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SlideshowTheme {
    /// Primary color
    pub primary_color: Option<String>,
    /// Secondary color
    pub secondary_color: Option<String>,
    /// Background color
    pub background_color: Option<String>,
    /// Text color
    pub text_color: Option<String>,
    /// Font family
    pub font_family: Option<String>,
    /// Code font
    pub code_font: Option<String>,
}
/// Whiteboard data structure
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WhiteboardData {
    /// Canvas width
    pub width: u32,
    /// Canvas height
    pub height: u32,
    /// Background color
    pub background: Option<String>,
    /// Drawing elements
    pub elements: Vec<WhiteboardElement>,
}
/// Whiteboard element
///
/// Internally tagged on `"type"` in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WhiteboardElement {
    /// Path/stroke
    Path {
        id: String,
        points: Vec<Point>,
        color: String,
        width: f32,
        opacity: f32,
    },
    /// Text
    Text {
        id: String,
        text: String,
        position: Point,
        font_size: u32,
        color: String,
    },
    /// Rectangle
    Rectangle {
        id: String,
        x: f32,
        y: f32,
        width: f32,
        height: f32,
        fill: Option<String>,
        stroke: Option<String>,
        stroke_width: f32,
    },
    /// Circle/Ellipse
    Circle {
        id: String,
        cx: f32,
        cy: f32,
        radius: f32,
        fill: Option<String>,
        stroke: Option<String>,
        stroke_width: f32,
    },
    /// Image
    Image {
        id: String,
        url: String,
        x: f32,
        y: f32,
        width: f32,
        height: f32,
    },
}
/// 2D Point
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Point {
    pub x: f32,
    pub y: f32,
}
/// Shared serde default for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_presentation_type_display() {
        assert_eq!(PresentationType::Slideshow.display_name(), "幻灯片");
        assert_eq!(PresentationType::Chart.display_name(), "图表");
    }

    #[test]
    fn test_presentation_type_icon() {
        // NOTE(review): the expected Quiz icon is an empty string, which looks
        // like an emoji lost to encoding — confirm the intended glyph.
        assert_eq!(PresentationType::Quiz.icon(), "");
        assert_eq!(PresentationType::Document.icon(), "📄");
    }

    #[test]
    fn test_quiz_data_deserialize() {
        let json = r#"{
            "title": "Python 基础测验",
            "questions": [
                {
                    "id": "q1",
                    "text": "Python 是什么类型的语言?",
                    "type": "singleChoice",
                    "options": [
                        {"id": "a", "text": "编译型"},
                        {"id": "b", "text": "解释型"}
                    ],
                    "correctAnswer": "b"
                }
            ]
        }"#;
        let quiz: QuizData = serde_json::from_str(json).unwrap();
        assert_eq!(quiz.questions.len(), 1);
    }

    #[test]
    fn test_chart_data_deserialize() {
        let json = r#"{
            "chartType": "bar",
            "title": "月度销售",
            "labels": ["一月", "二月", "三月"],
            "series": [
                {"name": "销售额", "data": [100, 150, 200]}
            ]
        }"#;
        let chart: ChartData = serde_json::from_str(json).unwrap();
        assert_eq!(chart.labels.len(), 3);
        assert_eq!(chart.series[0].data.len(), 3);
    }
}

View File

@@ -62,6 +62,21 @@ impl ExecutionContext {
Self::new(inputs_map)
}
/// Create from parent context data (for parallel execution)
///
/// The child context takes ownership of snapshots of the parent's inputs,
/// step outputs and variables; `loop_context` is deliberately reset to
/// `None`, and a fresh interpolation regex is compiled.
pub fn from_parent(
    inputs: HashMap<String, Value>,
    steps_output: HashMap<String, Value>,
    variables: HashMap<String, Value>,
) -> Self {
    Self {
        inputs,
        steps_output,
        variables,
        loop_context: None,
        // Matches `${...}` interpolation expressions; the pattern is a valid
        // literal, so `unwrap` cannot fail here.
        expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
    }
}
/// Get an input value
///
/// Returns `None` when no input named `name` exists.
pub fn get_input(&self, name: &str) -> Option<&Value> {
    self.inputs.get(name)
}
@@ -264,6 +279,16 @@ impl ExecutionContext {
&self.steps_output
}
/// Get all inputs (borrowed view of the full input map).
pub fn inputs(&self) -> &HashMap<String, Value> {
    &self.inputs
}
/// Get all variables (borrowed view of the full variable map).
pub fn all_vars(&self) -> &HashMap<String, Value> {
    &self.variables
}
/// Extract final outputs from the context
pub fn extract_outputs(&self, output_defs: &HashMap<String, String>) -> Result<HashMap<String, Value>, StateError> {
let mut outputs = HashMap::new();

View File

@@ -0,0 +1,468 @@
//! Pipeline Trigger System
//!
//! Provides natural language trigger matching for pipelines.
//! Supports keywords, regex patterns, and parameter extraction.
//!
//! # Example
//!
//! ```yaml
//! trigger:
//! keywords: [课程, 教程, 学习]
//! patterns:
//! - "帮我做*课程"
//! - "生成*教程"
//! - "我想学习{topic}"
//! description: "根据用户主题生成完整的互动课程内容"
//! examples:
//! - "帮我做一个 Python 入门课程"
//! - "生成机器学习基础教程"
//! ```
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trigger definition for a pipeline
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct Trigger {
    /// Quick match keywords
    #[serde(default)]
    pub keywords: Vec<String>,
    /// Regex patterns with optional capture groups
    /// Supports glob-style wildcards: * (any chars), {param} (named capture)
    #[serde(default)]
    pub patterns: Vec<String>,
    /// Description for LLM semantic matching
    #[serde(default)]
    pub description: Option<String>,
    /// Example inputs (helps LLM understand intent)
    #[serde(default)]
    pub examples: Vec<String>,
}
/// Compiled trigger for efficient matching
///
/// Built from a [`Trigger`] by [`compile_trigger`].
#[derive(Debug, Clone)]
pub struct CompiledTrigger {
    /// Pipeline ID this trigger belongs to
    pub pipeline_id: String,
    /// Pipeline display name
    pub display_name: Option<String>,
    /// Keywords for quick matching
    pub keywords: Vec<String>,
    /// Compiled regex patterns
    pub patterns: Vec<CompiledPattern>,
    /// Description for semantic matching
    pub description: Option<String>,
    /// Example inputs
    pub examples: Vec<String>,
    /// Parameter definitions (from pipeline inputs)
    pub param_defs: Vec<TriggerParam>,
}
/// Compiled regex pattern with named captures
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    /// Original pattern string
    pub original: String,
    /// Compiled regex
    pub regex: Regex,
    /// Named capture group names
    pub capture_names: Vec<String>,
}
/// Parameter definition for trigger matching
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TriggerParam {
    /// Parameter name
    pub name: String,
    /// Parameter type (wire key "type"; defaults to "string")
    #[serde(rename = "type", default = "default_param_type")]
    pub param_type: String,
    /// Is this parameter required?
    #[serde(default)]
    pub required: bool,
    /// Human-readable label
    #[serde(default)]
    pub label: Option<String>,
    /// Default value
    #[serde(default)]
    pub default: Option<serde_json::Value>,
}
/// Serde default for [`TriggerParam::param_type`]: parameters are plain
/// strings unless declared otherwise.
fn default_param_type() -> String {
    String::from("string")
}
/// Result of trigger matching
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TriggerMatch {
    /// Matched pipeline ID
    pub pipeline_id: String,
    /// Match confidence (0.0 - 1.0)
    pub confidence: f32,
    /// Match type
    pub match_type: MatchType,
    /// Extracted parameters (empty for keyword matches)
    pub params: HashMap<String, serde_json::Value>,
    /// Which pattern matched (if any); for keyword matches this is the keyword
    pub matched_pattern: Option<String>,
}
/// Type of match
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MatchType {
    /// Exact keyword match
    Keyword,
    /// Regex pattern match
    Pattern,
    /// LLM semantic match
    Semantic,
    /// No match
    None,
}
/// Trigger parser and matcher
pub struct TriggerParser {
    /// Compiled triggers, in registration order
    triggers: Vec<CompiledTrigger>,
}
impl TriggerParser {
/// Create a new empty trigger parser
pub fn new() -> Self {
Self {
triggers: Vec::new(),
}
}
/// Register a pipeline trigger
pub fn register(&mut self, trigger: CompiledTrigger) {
self.triggers.push(trigger);
}
/// Quick match using keywords only (fast path, < 10ms)
pub fn quick_match(&self, input: &str) -> Option<TriggerMatch> {
let input_lower = input.to_lowercase();
for trigger in &self.triggers {
// Check keywords
for keyword in &trigger.keywords {
if input_lower.contains(&keyword.to_lowercase()) {
return Some(TriggerMatch {
pipeline_id: trigger.pipeline_id.clone(),
confidence: 0.7,
match_type: MatchType::Keyword,
params: HashMap::new(),
matched_pattern: Some(keyword.clone()),
});
}
}
// Check patterns
for pattern in &trigger.patterns {
if let Some(captures) = pattern.regex.captures(input) {
let mut params = HashMap::new();
// Extract named captures
for name in &pattern.capture_names {
if let Some(value) = captures.name(name) {
params.insert(
name.clone(),
serde_json::Value::String(value.as_str().to_string()),
);
}
}
return Some(TriggerMatch {
pipeline_id: trigger.pipeline_id.clone(),
confidence: 0.85,
match_type: MatchType::Pattern,
params,
matched_pattern: Some(pattern.original.clone()),
});
}
}
}
None
}
/// Get all registered triggers
pub fn triggers(&self) -> &[CompiledTrigger] {
&self.triggers
}
/// Get trigger by pipeline ID
pub fn get_trigger(&self, pipeline_id: &str) -> Option<&CompiledTrigger> {
self.triggers.iter().find(|t| t.pipeline_id == pipeline_id)
}
}
impl Default for TriggerParser {
    /// Same as [`TriggerParser::new`]: an empty match set.
    fn default() -> Self {
        Self::new()
    }
}
/// Compile a glob-style pattern to regex
///
/// Supports:
/// - `*` - match any characters (greedy)
/// - `{name}` - named capture group
/// - `{name:type}` - typed capture (the type is currently parsed and ignored)
///
/// Examples:
/// - "帮我做*课程" -> "帮我做(.*)课程"
/// - "我想学习{topic}" -> "我想学习(?P<topic>.+)"
///
/// Capture names are sanitized (non-alphanumeric chars become `_`), and the
/// SAME sanitized name is recorded in `capture_names` and used as the regex
/// group name, so `captures.name(...)` lookups in `quick_match` always find
/// the group. (Previously the raw name was recorded while the group was
/// created under the sanitized name, so such lookups silently failed for any
/// name containing non-alphanumeric characters.)
///
/// # Errors
///
/// Returns [`PatternError::InvalidRegex`] when the generated regex fails to
/// compile (e.g. duplicate capture names within one pattern).
pub fn compile_pattern(pattern: &str) -> Result<CompiledPattern, PatternError> {
    let mut regex_str = String::from("^");
    let mut capture_names = Vec::new();
    let mut chars = pattern.chars();
    while let Some(ch) = chars.next() {
        match ch {
            '*' => {
                // Greedy match any characters
                regex_str.push_str("(.*)");
            }
            '{' => {
                // Named capture group: `{name}` or `{name:type}`
                let mut name = String::new();
                while let Some(c) = chars.next() {
                    match c {
                        '}' => break,
                        ':' => {
                            // Skip the type annotation up to and including '}'
                            for nc in chars.by_ref() {
                                if nc == '}' {
                                    break;
                                }
                            }
                            break;
                        }
                        _ => name.push(c),
                    }
                }
                if name.is_empty() {
                    // `{}` degenerates to an anonymous capture
                    regex_str.push_str("(.+)");
                } else {
                    // Sanitize once; use the sanitized name both as the regex
                    // group name and as the recorded lookup key.
                    let group_name = regex_escape(&name);
                    regex_str.push_str(&format!("(?P<{}>.+)", group_name));
                    capture_names.push(group_name);
                }
            }
            '[' | ']' | '(' | ')' | '\\' | '^' | '$' | '.' | '|' | '?' | '+' => {
                // Escape regex special characters
                regex_str.push('\\');
                regex_str.push(ch);
            }
            _ => {
                regex_str.push(ch);
            }
        }
    }
    regex_str.push('$');
    let regex = Regex::new(&regex_str).map_err(|e| PatternError::InvalidRegex {
        pattern: pattern.to_string(),
        error: e.to_string(),
    })?;
    Ok(CompiledPattern {
        original: pattern.to_string(),
        regex,
        capture_names,
    })
}
/// Sanitize a string for use as a regex capture-group name: every
/// non-alphanumeric character is replaced with an underscore.
fn regex_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        out.push(if c.is_alphanumeric() { c } else { '_' });
    }
    out
}
/// Compile a [`Trigger`] definition into a [`CompiledTrigger`] for the given
/// pipeline, attaching its parameter definitions.
///
/// Fails with [`PatternError`] if any pattern does not compile; in that case
/// no trigger is produced.
pub fn compile_trigger(
    pipeline_id: String,
    display_name: Option<String>,
    trigger: &Trigger,
    param_defs: Vec<TriggerParam>,
) -> Result<CompiledTrigger, PatternError> {
    let patterns = trigger
        .patterns
        .iter()
        .map(|p| compile_pattern(p))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(CompiledTrigger {
        pipeline_id,
        display_name,
        keywords: trigger.keywords.clone(),
        patterns,
        description: trigger.description.clone(),
        examples: trigger.examples.clone(),
        param_defs,
    })
}
/// Pattern compilation error
#[derive(Debug, thiserror::Error)]
pub enum PatternError {
    /// The expanded pattern produced a regex that failed to compile.
    #[error("Invalid regex in pattern '{pattern}': {error}")]
    InvalidRegex { pattern: String, error: String },
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compile_pattern_wildcard() {
    let pattern = compile_pattern("帮我做*课程").unwrap();
    assert!(pattern.regex.is_match("帮我做一个Python课程"));
    assert!(pattern.regex.is_match("帮我做机器学习课程"));
    assert!(!pattern.regex.is_match("生成一个课程"));
    // Test capture: `*` expands to an anonymous group 1.
    let captures = pattern.regex.captures("帮我做一个Python课程").unwrap();
    assert_eq!(captures.get(1).unwrap().as_str(), "一个Python");
}

#[test]
fn test_compile_pattern_named_capture() {
    let pattern = compile_pattern("我想学习{topic}").unwrap();
    assert!(pattern.capture_names.contains(&"topic".to_string()));
    let captures = pattern.regex.captures("我想学习Python编程").unwrap();
    assert_eq!(
        captures.name("topic").unwrap().as_str(),
        "Python编程"
    );
}

#[test]
fn test_compile_pattern_mixed() {
    let pattern = compile_pattern("生成{level}级别的{topic}教程").unwrap();
    assert!(pattern.capture_names.contains(&"level".to_string()));
    assert!(pattern.capture_names.contains(&"topic".to_string()));
    let captures = pattern
        .regex
        .captures("生成入门级别的机器学习教程")
        .unwrap();
    assert_eq!(captures.name("level").unwrap().as_str(), "入门");
    assert_eq!(captures.name("topic").unwrap().as_str(), "机器学习");
}

#[test]
fn test_trigger_parser_quick_match() {
    let mut parser = TriggerParser::new();
    let trigger = CompiledTrigger {
        pipeline_id: "course-generator".to_string(),
        display_name: Some("课程生成器".to_string()),
        keywords: vec!["课程".to_string(), "教程".to_string()],
        patterns: vec![compile_pattern("帮我做*课程").unwrap()],
        description: Some("生成课程".to_string()),
        examples: vec![],
        param_defs: vec![],
    };
    parser.register(trigger);
    // Test keyword match
    let result = parser.quick_match("我想学习一个课程");
    assert!(result.is_some());
    let match_result = result.unwrap();
    assert_eq!(match_result.pipeline_id, "course-generator");
    assert_eq!(match_result.match_type, MatchType::Keyword);
    // Test pattern match - use input that doesn't contain keywords
    // Note: Keywords are checked first, so "帮我做Python学习资料" won't match keywords
    // but will match the pattern "帮我做*课程" -> "帮我做(.*)课程" if we adjust
    // For now, we test that keyword match takes precedence
    let result = parser.quick_match("帮我做一个Python课程");
    assert!(result.is_some());
    let match_result = result.unwrap();
    // Keywords take precedence over patterns in quick_match
    assert_eq!(match_result.match_type, MatchType::Keyword);
    // Test no match
    let result = parser.quick_match("今天天气真好");
    assert!(result.is_none());
}
#[test]
fn test_trigger_param_extraction() {
    // Pattern "生成{level}难度的{topic}教程" is chosen so the literal
    // separators ("难度的", "教程") do not overlap the captured text.
    let pattern = compile_pattern("生成{level}难度的{topic}教程").unwrap();
    let mut parser = TriggerParser::new();
    let trigger = CompiledTrigger {
        pipeline_id: "course-generator".to_string(),
        display_name: Some("课程生成器".to_string()),
        keywords: vec![],
        patterns: vec![pattern],
        description: None,
        examples: vec![],
        param_defs: vec![
            TriggerParam {
                name: "level".to_string(),
                param_type: "string".to_string(),
                required: false,
                label: Some("难度级别".to_string()),
                default: Some(serde_json::Value::String("入门".to_string())),
            },
            TriggerParam {
                name: "topic".to_string(),
                param_type: "string".to_string(),
                required: true,
                label: Some("课程主题".to_string()),
                default: None,
            },
        ],
    };
    parser.register(trigger);
    let result = parser.quick_match("生成高难度的机器学习教程").unwrap();
    // `{level}` compiles to `(?P<level>.+)`, which must capture at least one
    // character — here the "高" before the literal "难度的". The previous
    // expectation of "" could never be produced by a `.+` group.
    assert_eq!(result.params.get("level").unwrap(), "高");
    assert_eq!(result.params.get("topic").unwrap(), "机器学习");
}
}

View File

@@ -136,7 +136,7 @@ pub struct PipelineInput {
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[serde(rename_all = "kebab-case")]
pub enum InputType {
#[default]
String,
@@ -293,8 +293,8 @@ pub enum Action {
/// File export
FileExport {
/// Formats to export
formats: Vec<ExportFormat>,
/// Formats to export (expression that evaluates to array of format names)
formats: String,
/// Input data (expression)
input: String,
@@ -501,6 +501,7 @@ metadata:
name: test-pipeline
display_name: Test Pipeline
category: test
industry: internet
spec:
inputs:
- name: topic
@@ -518,5 +519,36 @@ spec:
assert_eq!(pipeline.metadata.name, "test-pipeline");
assert_eq!(pipeline.spec.inputs.len(), 1);
assert_eq!(pipeline.spec.steps.len(), 1);
assert_eq!(pipeline.metadata.industry, Some("internet".to_string()));
}
// A `formats` field written as an interpolation expression must survive YAML
// parsing verbatim inside the FileExport action.
#[test]
fn test_file_export_with_expression() {
    let yaml = r#"
apiVersion: zclaw/v1
kind: Pipeline
metadata:
  name: export-test
spec:
  inputs:
    - name: formats
      type: multi-select
      default: [html]
      options: [html, pdf]
  steps:
    - id: export
      action:
        type: file_export
        formats: ${inputs.formats}
        input: "test"
"#;
    let pipeline: Pipeline = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(pipeline.metadata.name, "export-test");
    let action = &pipeline.spec.steps[0].action;
    if let Action::FileExport { formats, .. } = action {
        assert_eq!(formats, "${inputs.formats}");
    } else {
        panic!("Expected FileExport action");
    }
}
}

View File

@@ -0,0 +1,508 @@
//! Pipeline v2 Type Definitions
//!
//! Enhanced pipeline format with:
//! - Natural language triggers
//! - Stage-based execution (Llm, Parallel, Conditional, Compose)
//! - Dynamic output presentation
//!
//! # Example
//!
//! ```yaml
//! apiVersion: zclaw/v2
//! kind: Pipeline
//! metadata:
//! name: course-generator
//! displayName: 课程生成器
//! category: education
//! trigger:
//! keywords: [课程, 教程, 学习]
//! patterns:
//! - "帮我做*课程"
//! - "生成{level}级别的{topic}教程"
//! params:
//! - name: topic
//! type: string
//! required: true
//! label: 课程主题
//! stages:
//! - id: outline
//! type: llm
//! prompt: "为{params.topic}创建课程大纲"
//! output_schema: outline_schema
//! - id: content
//! type: parallel
//! each: "${stages.outline.sections}"
//! stage:
//! type: llm
//! prompt: "为章节${item.title}生成内容"
//! output:
//! type: dynamic
//! supported_types: [slideshow, quiz, document]
//! ```
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Pipeline v2 version identifier
pub const API_VERSION_V2: &str = "zclaw/v2";
/// A complete Pipeline v2 definition
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PipelineV2 {
    /// API version (must be "zclaw/v2"; wire key "apiVersion")
    pub api_version: String,
    /// Resource kind (must be "Pipeline")
    pub kind: String,
    /// Pipeline metadata
    pub metadata: PipelineMetadataV2,
    /// Trigger configuration
    #[serde(default)]
    pub trigger: TriggerConfig,
    /// Input mode configuration
    #[serde(default)]
    pub input: InputConfig,
    /// Parameter definitions
    #[serde(default)]
    pub params: Vec<ParamDef>,
    /// Execution stages
    pub stages: Vec<Stage>,
    /// Output configuration
    #[serde(default)]
    pub output: OutputConfig,
}
/// Pipeline v2 metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PipelineMetadataV2 {
    /// Unique identifier
    pub name: String,
    /// Human-readable display name
    #[serde(default)]
    pub display_name: Option<String>,
    /// Description
    #[serde(default)]
    pub description: Option<String>,
    /// Category for grouping
    #[serde(default)]
    pub category: Option<String>,
    /// Industry classification
    #[serde(default)]
    pub industry: Option<String>,
    /// Icon (emoji or icon name)
    #[serde(default)]
    pub icon: Option<String>,
    /// Tags for search
    #[serde(default)]
    pub tags: Vec<String>,
    /// Version (defaults to "1.0.0")
    #[serde(default = "default_version")]
    pub version: String,
}
/// Serde default for [`PipelineMetadataV2::version`].
fn default_version() -> String {
    "1.0.0".to_string()
}
/// Trigger configuration for natural language matching
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct TriggerConfig {
    /// Keywords for quick matching
    #[serde(default)]
    pub keywords: Vec<String>,
    /// Regex patterns with optional captures
    #[serde(default)]
    pub patterns: Vec<String>,
    /// Description for LLM semantic matching
    #[serde(default)]
    pub description: Option<String>,
    /// Example inputs
    #[serde(default)]
    pub examples: Vec<String>,
}
/// Input mode configuration
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct InputConfig {
    /// Input mode: conversation, form, hybrid, auto
    #[serde(default)]
    pub mode: InputMode,
    /// Complexity threshold for auto mode (switch to form when params > threshold)
    #[serde(default = "default_complexity_threshold")]
    pub complexity_threshold: usize,
}
/// Serde default for [`InputConfig::complexity_threshold`].
fn default_complexity_threshold() -> usize {
    3
}
/// Input mode for parameter collection
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum InputMode {
    /// Simple conversation-based collection
    Conversation,
    /// Form-based collection
    Form,
    /// Hybrid - start with conversation, switch to form if needed
    Hybrid,
    /// Auto - system decides based on complexity
    #[default]
    Auto,
}
/// Parameter definition
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ParamDef {
    /// Parameter name
    pub name: String,
    /// Parameter type (wire key "type")
    #[serde(rename = "type", default)]
    pub param_type: ParamType,
    /// Is this parameter required?
    #[serde(default)]
    pub required: bool,
    /// Human-readable label
    #[serde(default)]
    pub label: Option<String>,
    /// Description
    #[serde(default)]
    pub description: Option<String>,
    /// Placeholder text
    #[serde(default)]
    pub placeholder: Option<String>,
    /// Default value
    #[serde(default)]
    pub default: Option<serde_json::Value>,
    /// Options for select/multi-select
    #[serde(default)]
    pub options: Vec<String>,
}
/// Parameter type
///
/// Note: serialized lowercase, so `MultiSelect` becomes `"multiselect"`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ParamType {
    #[default]
    String,
    Number,
    Boolean,
    Select,
    MultiSelect,
    File,
    Text,
}
/// Stage definition - the core execution unit
///
/// Internally tagged: the wire format carries a `"type"` discriminator in
/// snake_case (e.g. `"set_var"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Stage {
    /// LLM generation stage
    Llm {
        /// Stage ID
        id: String,
        /// Prompt template with variable interpolation
        prompt: String,
        /// Model override
        #[serde(default)]
        model: Option<String>,
        /// Temperature override
        #[serde(default)]
        temperature: Option<f32>,
        /// Max tokens
        #[serde(default)]
        max_tokens: Option<u32>,
        /// JSON schema for structured output
        #[serde(default)]
        output_schema: Option<serde_json::Value>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Parallel execution stage
    Parallel {
        /// Stage ID
        id: String,
        /// Expression to iterate over (e.g., "${stages.outline.sections}")
        each: String,
        /// Stage template to execute for each item (boxed: recursive type)
        stage: Box<Stage>,
        /// Maximum concurrent workers (defaults to 3)
        #[serde(default = "default_max_workers")]
        max_workers: usize,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Sequential sub-stages
    Sequential {
        /// Stage ID
        id: String,
        /// Sub-stages to execute in sequence
        stages: Vec<Stage>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Conditional branching
    Conditional {
        /// Stage ID
        id: String,
        /// Condition expression (e.g., "${params.level} == 'advanced'")
        condition: String,
        /// Branch stages
        branches: Vec<ConditionalBranch>,
        /// Default stage if no branch matches
        #[serde(default)]
        default: Option<Box<Stage>>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Compose/assemble results
    Compose {
        /// Stage ID
        id: String,
        /// Template for composing (JSON template with variable interpolation)
        template: String,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Skill execution
    Skill {
        /// Stage ID
        id: String,
        /// Skill ID to execute
        skill_id: String,
        /// Input parameters (expressions)
        #[serde(default)]
        input: HashMap<String, String>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Hand execution
    Hand {
        /// Stage ID
        id: String,
        /// Hand ID
        hand_id: String,
        /// Action to perform
        action: String,
        /// Parameters (expressions)
        #[serde(default)]
        params: HashMap<String, String>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// HTTP request
    Http {
        /// Stage ID
        id: String,
        /// URL (can be expression)
        url: String,
        /// HTTP method (defaults to "GET")
        #[serde(default = "default_http_method")]
        method: String,
        /// Headers
        #[serde(default)]
        headers: HashMap<String, String>,
        /// Request body (expression)
        #[serde(default)]
        body: Option<String>,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
    /// Set variable
    SetVar {
        /// Stage ID
        id: String,
        /// Variable name
        name: String,
        /// Value (expression)
        value: String,
        /// Description
        #[serde(default)]
        description: Option<String>,
    },
}
/// Serde default for `Parallel::max_workers`.
fn default_max_workers() -> usize {
    3
}
/// Serde default for `Http::method`.
fn default_http_method() -> String {
    "GET".to_string()
}
/// Conditional branch
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionalBranch {
/// Condition expression
pub when: String,
/// Stage to execute
pub then: Stage,
}
/// Output configuration
///
/// Controls how a pipeline's result is presented. Due to `rename_all`,
/// serialized field names are camelCase on the wire (`allowSwitch`,
/// `supportedTypes`, `defaultType`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct OutputConfig {
    /// Output type: static, dynamic (serialized as `type`)
    #[serde(rename = "type", default)]
    pub type_: OutputType,
    /// Allow user to switch presentation type (defaults to `true`)
    #[serde(default = "default_true")]
    pub allow_switch: bool,
    /// Supported presentation types (defaults to empty when omitted)
    #[serde(default)]
    pub supported_types: Vec<PresentationType>,
    /// Default presentation type; `None` when the manifest leaves it unset
    #[serde(default)]
    pub default_type: Option<PresentationType>,
}
/// Serde default helper: fields using this (e.g. `allow_switch`) default to `true`.
fn default_true() -> bool {
    true
}
/// Output type
///
/// Serialized as lowercase strings (`static`, `dynamic`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
    /// Static output (text, file); this is the default variant
    #[default]
    Static,
    /// Dynamic - LLM recommends presentation type at runtime
    Dynamic,
}
/// Presentation type
///
/// How generated content is rendered; serialized as lowercase strings
/// (`slideshow`, `quiz`, `chart`, `document`, `whiteboard`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum PresentationType {
    /// Slide-deck presentation
    Slideshow,
    /// Interactive quiz
    Quiz,
    /// Data chart
    Chart,
    /// Long-form document
    Document,
    /// Freeform whiteboard
    Whiteboard,
}
impl Stage {
    /// Returns this stage's unique identifier, regardless of variant.
    pub fn id(&self) -> &str {
        // Every variant carries an `id: String` field; an or-pattern binds
        // them all uniformly instead of one arm per variant.
        match self {
            Stage::Llm { id, .. }
            | Stage::Parallel { id, .. }
            | Stage::Sequential { id, .. }
            | Stage::Conditional { id, .. }
            | Stage::Compose { id, .. }
            | Stage::Skill { id, .. }
            | Stage::Hand { id, .. }
            | Stage::Http { id, .. }
            | Stage::SetVar { id, .. } => id,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Deserializes a representative v2 pipeline manifest and spot-checks
    // the top-level fields (apiVersion, metadata, stages, trigger).
    #[test]
    fn test_pipeline_v2_deserialize() {
        // NOTE(review): the `output` section uses snake_case keys
        // (`supported_types`) while `OutputConfig` declares
        // `rename_all = "camelCase"` (`supportedTypes`). serde ignores
        // unknown fields by default, so the key is silently dropped and
        // the assertions below never cover it — confirm intended casing.
        let yaml = r#"
apiVersion: zclaw/v2
kind: Pipeline
metadata:
  name: course-generator
  displayName: 课程生成器
  category: education
trigger:
  keywords: [课程, 教程]
  patterns:
    - "帮我做*课程"
params:
  - name: topic
    type: string
    required: true
    label: 课程主题
stages:
  - id: outline
    type: llm
    prompt: "为{params.topic}创建课程大纲"
  - id: content
    type: parallel
    each: "${stages.outline.sections}"
    stage:
      type: llm
      id: section_content
      prompt: "生成章节内容"
output:
  type: dynamic
  supported_types: [slideshow, quiz]
"#;
        let pipeline: PipelineV2 = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(pipeline.api_version, "zclaw/v2");
        assert_eq!(pipeline.metadata.name, "course-generator");
        assert_eq!(pipeline.stages.len(), 2);
        assert_eq!(pipeline.trigger.keywords.len(), 2);
    }
    // `Stage::id()` returns the variant's `id` field verbatim.
    #[test]
    fn test_stage_id() {
        let stage = Stage::Llm {
            id: "test".to_string(),
            prompt: "test".to_string(),
            model: None,
            temperature: None,
            max_tokens: None,
            output_schema: None,
            description: None,
        };
        assert_eq!(stage.id(), "test");
    }
}

View File

@@ -10,6 +10,7 @@ description = "ZCLAW runtime with LLM drivers and agent loop"
[dependencies]
zclaw-types = { workspace = true }
zclaw-memory = { workspace = true }
zclaw-growth = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }

View File

@@ -0,0 +1,315 @@
//! Growth System Integration for ZCLAW Runtime
//!
//! This module provides integration between the AgentLoop and the Growth System,
//! enabling automatic memory retrieval before conversations and memory extraction
//! after conversations.
//!
//! # Usage
//!
//! ```rust,ignore
//! use zclaw_runtime::growth::GrowthIntegration;
//! use zclaw_growth::{VikingAdapter, MemoryExtractor, MemoryRetriever, PromptInjector};
//!
//! // Create growth integration
//! let viking = Arc::new(VikingAdapter::in_memory());
//! let growth = GrowthIntegration::new(viking);
//!
//! // Before conversation: enhance system prompt
//! let enhanced_prompt = growth.enhance_prompt(&agent_id, &base_prompt, &user_input).await?;
//!
//! // After conversation: extract and store memories
//! growth.process_conversation(&agent_id, &messages, session_id).await?;
//! ```
use std::sync::Arc;
use zclaw_growth::{
GrowthTracker, InjectionFormat, LlmDriverForExtraction,
MemoryExtractor, MemoryRetriever, PromptInjector, RetrievalResult,
VikingAdapter,
};
use zclaw_types::{AgentId, Message, Result, SessionId};
/// Growth system integration for AgentLoop
///
/// This struct wraps the growth system components and provides
/// a simplified interface for integration with the agent loop:
/// retrieval + injection before a turn, extraction + tracking after.
pub struct GrowthIntegration {
    /// Memory retriever for fetching relevant memories before a turn
    retriever: MemoryRetriever,
    /// Memory extractor for extracting memories from conversations
    extractor: MemoryExtractor,
    /// Prompt injector for weaving retrieved memories into the system prompt
    injector: PromptInjector,
    /// Growth tracker for recording growth metrics (learning events)
    tracker: GrowthTracker,
    /// Runtime on/off configuration
    config: GrowthConfigInner,
}
/// Internal configuration for growth integration
///
/// Private to this module; toggled through `set_enabled` / `set_auto_extract`.
#[derive(Debug, Clone)]
struct GrowthConfigInner {
    /// Enable/disable the whole growth system (retrieval and extraction)
    pub enabled: bool,
    /// Auto-extract memories after each conversation
    pub auto_extract: bool,
}
impl Default for GrowthConfigInner {
fn default() -> Self {
Self {
enabled: true,
auto_extract: true,
}
}
}
impl GrowthIntegration {
    /// Create a new growth integration with in-memory storage
    ///
    /// Convenience constructor for tests and ephemeral agents; nothing
    /// survives a restart.
    pub fn in_memory() -> Self {
        let viking = Arc::new(VikingAdapter::in_memory());
        Self::new(viking)
    }
    /// Create a new growth integration with the given Viking adapter
    ///
    /// All subsystems (extractor, retriever, tracker) share the same
    /// storage adapter.
    pub fn new(viking: Arc<VikingAdapter>) -> Self {
        // Create extractor without LLM driver - can be set later via
        // `with_llm_driver`.
        let extractor = MemoryExtractor::new_without_driver()
            .with_viking(viking.clone());
        let retriever = MemoryRetriever::new(viking.clone());
        let injector = PromptInjector::new();
        let tracker = GrowthTracker::new(viking);
        Self {
            retriever,
            extractor,
            injector,
            tracker,
            config: GrowthConfigInner::default(),
        }
    }
    /// Set the injection format (builder-style; consumes and returns self)
    pub fn with_format(mut self, format: InjectionFormat) -> Self {
        self.injector = self.injector.with_format(format);
        self
    }
    /// Set the LLM driver for memory extraction (builder-style)
    pub fn with_llm_driver(mut self, driver: Arc<dyn LlmDriverForExtraction>) -> Self {
        self.extractor = self.extractor.with_llm_driver(driver);
        self
    }
    /// Enable or disable growth system
    pub fn set_enabled(&mut self, enabled: bool) {
        self.config.enabled = enabled;
    }
    /// Check if growth system is enabled
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
    /// Enable or disable auto extraction (post-conversation memory extraction)
    pub fn set_auto_extract(&mut self, auto_extract: bool) {
        self.config.auto_extract = auto_extract;
    }
    /// Enhance system prompt with retrieved memories
    ///
    /// This method:
    /// 1. Retrieves relevant memories based on user input
    /// 2. Injects them into the system prompt using configured format
    ///
    /// Retrieval failures are non-fatal: they are logged and treated as
    /// "no memories", so the conversation proceeds with the base prompt.
    ///
    /// Returns the enhanced prompt or the original if growth is disabled
    pub async fn enhance_prompt(
        &self,
        agent_id: &AgentId,
        base_prompt: &str,
        user_input: &str,
    ) -> Result<String> {
        // Disabled growth short-circuits to the unchanged base prompt.
        if !self.config.enabled {
            return Ok(base_prompt.to_string());
        }
        tracing::debug!(
            "[GrowthIntegration] Enhancing prompt for agent: {}",
            agent_id
        );
        // Retrieve relevant memories (best-effort: errors become an empty
        // result rather than failing the turn)
        let memories = self
            .retriever
            .retrieve(agent_id, user_input)
            .await
            .unwrap_or_else(|e| {
                tracing::warn!("[GrowthIntegration] Retrieval failed: {}", e);
                RetrievalResult::default()
            });
        if memories.is_empty() {
            tracing::debug!("[GrowthIntegration] No memories retrieved");
            return Ok(base_prompt.to_string());
        }
        tracing::info!(
            "[GrowthIntegration] Injecting {} memories ({} tokens)",
            memories.total_count(),
            memories.total_tokens
        );
        // Inject memories into prompt
        let enhanced = self.injector.inject_with_format(base_prompt, &memories);
        Ok(enhanced)
    }
    /// Process conversation after completion
    ///
    /// This method:
    /// 1. Extracts memories from the conversation using LLM (if driver available)
    /// 2. Stores the extracted memories
    /// 3. Updates growth metrics
    ///
    /// Extraction failures are non-fatal (logged, treated as empty), but
    /// storage/tracking errors propagate to the caller via `?`.
    ///
    /// Returns the number of memories extracted
    pub async fn process_conversation(
        &self,
        agent_id: &AgentId,
        messages: &[Message],
        session_id: SessionId,
    ) -> Result<usize> {
        // Skip entirely when growth or auto-extraction is turned off.
        if !self.config.enabled || !self.config.auto_extract {
            return Ok(0);
        }
        tracing::debug!(
            "[GrowthIntegration] Processing conversation for agent: {}",
            agent_id
        );
        // Extract memories from conversation (best-effort: errors -> empty)
        let extracted = self
            .extractor
            .extract(messages, session_id.clone())
            .await
            .unwrap_or_else(|e| {
                tracing::warn!("[GrowthIntegration] Extraction failed: {}", e);
                Vec::new()
            });
        if extracted.is_empty() {
            tracing::debug!("[GrowthIntegration] No memories extracted");
            return Ok(0);
        }
        tracing::info!(
            "[GrowthIntegration] Extracted {} memories",
            extracted.len()
        );
        // Store extracted memories first, then record the learning event —
        // this ordering means metrics never count memories that failed to
        // persist.
        let count = extracted.len();
        self.extractor
            .store_memories(&agent_id.to_string(), &extracted)
            .await?;
        // Track learning event
        self.tracker
            .record_learning(agent_id, &session_id.to_string(), count)
            .await?;
        Ok(count)
    }
    /// Retrieve memories for a query without injection
    pub async fn retrieve_memories(
        &self,
        agent_id: &AgentId,
        query: &str,
    ) -> Result<RetrievalResult> {
        self.retriever.retrieve(agent_id, query).await
    }
    /// Get growth statistics for an agent
    pub async fn get_stats(&self, agent_id: &AgentId) -> Result<zclaw_growth::GrowthStats> {
        self.tracker.get_stats(agent_id).await
    }
    /// Warm up cache with hot memories
    ///
    /// Returns a count — presumably the number of memories loaded into the
    /// cache; confirm against `MemoryRetriever::warmup_cache`.
    pub async fn warmup_cache(&self, agent_id: &AgentId) -> Result<usize> {
        self.retriever.warmup_cache(agent_id).await
    }
    /// Clear the semantic index
    pub async fn clear_index(&self) {
        self.retriever.clear_index().await;
    }
}
impl Default for GrowthIntegration {
fn default() -> Self {
Self::in_memory()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // A freshly constructed integration is enabled by default.
    #[tokio::test]
    async fn test_growth_integration_creation() {
        let growth = GrowthIntegration::in_memory();
        assert!(growth.is_enabled());
    }
    // With no stored memories, the base prompt passes through unchanged.
    #[tokio::test]
    async fn test_enhance_prompt_empty() {
        let growth = GrowthIntegration::in_memory();
        let agent_id = AgentId::new();
        let base = "You are helpful.";
        let user_input = "Hello";
        let enhanced = growth
            .enhance_prompt(&agent_id, base, user_input)
            .await
            .unwrap();
        // Without any stored memories, should return base prompt
        assert_eq!(enhanced, base);
    }
    // Disabling the whole system short-circuits enhancement.
    #[tokio::test]
    async fn test_disabled_growth() {
        let mut growth = GrowthIntegration::in_memory();
        growth.set_enabled(false);
        let agent_id = AgentId::new();
        let base = "You are helpful.";
        let enhanced = growth
            .enhance_prompt(&agent_id, base, "test")
            .await
            .unwrap();
        assert_eq!(enhanced, base);
    }
    // With auto-extract off, post-conversation processing extracts nothing.
    #[tokio::test]
    async fn test_process_conversation_disabled() {
        let mut growth = GrowthIntegration::in_memory();
        growth.set_auto_extract(false);
        let agent_id = AgentId::new();
        let messages = vec![Message::user("Hello")];
        let session_id = SessionId::new();
        let count = growth
            .process_conversation(&agent_id, &messages, session_id)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }
}

View File

@@ -11,6 +11,7 @@ pub mod tool;
pub mod loop_runner;
pub mod loop_guard;
pub mod stream;
pub mod growth;
// Re-export main types
pub use driver::{
@@ -21,3 +22,4 @@ pub use tool::{Tool, ToolRegistry, ToolContext};
pub use loop_runner::{AgentLoop, AgentLoopResult, LoopEvent};
pub use loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
pub use stream::{StreamEvent, StreamSender};
pub use growth::GrowthIntegration;

View File

@@ -10,6 +10,7 @@ use crate::stream::StreamChunk;
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor};
use crate::tool::builtin::PathValidator;
use crate::loop_guard::LoopGuard;
use crate::growth::GrowthIntegration;
use zclaw_memory::MemoryStore;
/// Agent loop runner
@@ -26,6 +27,8 @@ pub struct AgentLoop {
temperature: f32,
skill_executor: Option<Arc<dyn SkillExecutor>>,
path_validator: Option<PathValidator>,
/// Growth system integration (optional)
growth: Option<GrowthIntegration>,
}
impl AgentLoop {
@@ -47,6 +50,7 @@ impl AgentLoop {
temperature: 0.7,
skill_executor: None,
path_validator: None,
growth: None,
}
}
@@ -86,6 +90,22 @@ impl AgentLoop {
self
}
/// Enable growth system integration
pub fn with_growth(mut self, growth: GrowthIntegration) -> Self {
self.growth = Some(growth);
self
}
/// Set growth system (mutable)
pub fn set_growth(&mut self, growth: GrowthIntegration) {
self.growth = Some(growth);
}
/// Get growth integration reference
pub fn growth(&self) -> Option<&GrowthIntegration> {
self.growth.as_ref()
}
/// Create tool context for tool execution
fn create_tool_context(&self, session_id: SessionId) -> ToolContext {
ToolContext {
@@ -108,35 +128,43 @@ impl AgentLoop {
/// Implements complete agent loop: LLM → Tool Call → Tool Result → LLM → Final Response
pub async fn run(&self, session_id: SessionId, input: String) -> Result<AgentLoopResult> {
// Add user message to session
let user_message = Message::user(input);
let user_message = Message::user(input.clone());
self.memory.append_message(&session_id, &user_message).await?;
// Get all messages for context
let mut messages = self.memory.get_messages(&session_id).await?;
// Enhance system prompt with growth memories
let enhanced_prompt = if let Some(ref growth) = self.growth {
let base = self.system_prompt.as_deref().unwrap_or("");
growth.enhance_prompt(&self.agent_id, base, &input).await?
} else {
self.system_prompt.clone().unwrap_or_default()
};
let max_iterations = 10;
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
loop {
let result = loop {
iterations += 1;
if iterations > max_iterations {
// Save the state before returning
let error_msg = "达到最大迭代次数,请简化请求";
self.memory.append_message(&session_id, &Message::assistant(error_msg)).await?;
return Ok(AgentLoopResult {
break AgentLoopResult {
response: error_msg.to_string(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations,
});
};
}
// Build completion request
let request = CompletionRequest {
model: self.model.clone(),
system: self.system_prompt.clone(),
system: Some(enhanced_prompt.clone()),
messages: messages.clone(),
tools: self.tools.definitions(),
max_tokens: Some(self.max_tokens),
@@ -173,12 +201,12 @@ impl AgentLoop {
// Save final assistant message
self.memory.append_message(&session_id, &Message::assistant(&text)).await?;
return Ok(AgentLoopResult {
break AgentLoopResult {
response: text,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations,
});
};
}
// There are tool calls - add assistant message with tool calls to history
@@ -204,7 +232,18 @@ impl AgentLoop {
}
// Continue the loop - LLM will process tool results and generate final response
};
// Process conversation for memory extraction (post-conversation)
if let Some(ref growth) = self.growth {
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
}
}
}
Ok(result)
}
/// Run the agent loop with streaming
@@ -217,12 +256,20 @@ impl AgentLoop {
let (tx, rx) = mpsc::channel(100);
// Add user message to session
let user_message = Message::user(input);
let user_message = Message::user(input.clone());
self.memory.append_message(&session_id, &user_message).await?;
// Get all messages for context
let messages = self.memory.get_messages(&session_id).await?;
// Enhance system prompt with growth memories
let enhanced_prompt = if let Some(ref growth) = self.growth {
let base = self.system_prompt.as_deref().unwrap_or("");
growth.enhance_prompt(&self.agent_id, base, &input).await?
} else {
self.system_prompt.clone().unwrap_or_default()
};
// Clone necessary data for the async task
let session_id_clone = session_id.clone();
let memory = self.memory.clone();
@@ -231,7 +278,6 @@ impl AgentLoop {
let skill_executor = self.skill_executor.clone();
let path_validator = self.path_validator.clone();
let agent_id = self.agent_id.clone();
let system_prompt = self.system_prompt.clone();
let model = self.model.clone();
let max_tokens = self.max_tokens;
let temperature = self.temperature;
@@ -259,7 +305,7 @@ impl AgentLoop {
// Build completion request
let request = CompletionRequest {
model: model.clone(),
system: system_prompt.clone(),
system: Some(enhanced_prompt.clone()),
messages: messages.clone(),
tools: tools.definitions(),
max_tokens: Some(max_tokens),