refactor: 统一项目名称从OpenFang到ZCLAW
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
重构所有代码和文档中的项目名称,将OpenFang统一更新为ZCLAW。包括: - 配置文件中的项目名称 - 代码注释和文档引用 - 环境变量和路径 - 类型定义和接口名称 - 测试用例和模拟数据 同时优化部分代码结构,移除未使用的模块,并更新相关依赖项。
This commit is contained in:
32
desktop/src-tauri/src/embedding_adapter.rs
Normal file
32
desktop/src-tauri/src/embedding_adapter.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
//! Embedding Adapter - Bridges Tauri LLM EmbeddingClient to Growth System trait
|
||||
//!
|
||||
//! Implements zclaw_growth::retrieval::semantic::EmbeddingClient
|
||||
//! by wrapping the concrete llm::EmbeddingClient.
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_growth::retrieval::semantic::EmbeddingClient;
|
||||
|
||||
/// Adapter wrapping Tauri's llm::EmbeddingClient to implement the growth trait
///
/// Holds the concrete client behind an `Arc` so the adapter can be shared
/// across async tasks without cloning the client itself.
pub struct TauriEmbeddingAdapter {
    // The concrete Tauri-side embedding client being adapted.
    inner: Arc<crate::llm::EmbeddingClient>,
}
|
||||
|
||||
impl TauriEmbeddingAdapter {
|
||||
pub fn new(client: crate::llm::EmbeddingClient) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(client),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl EmbeddingClient for TauriEmbeddingAdapter {
|
||||
async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
|
||||
let response = self.inner.embed(text).await?;
|
||||
Ok(response.embedding)
|
||||
}
|
||||
|
||||
fn is_available(&self) -> bool {
|
||||
self.inner.is_configured()
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,6 @@
|
||||
//!
|
||||
//! NOTE: Some methods are reserved for future proactive features.
|
||||
|
||||
#![allow(dead_code)] // Methods reserved for future proactive features
|
||||
|
||||
use chrono::{Local, Timelike};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@@ -94,6 +92,7 @@ pub enum HeartbeatStatus {
|
||||
}
|
||||
|
||||
/// Type alias for heartbeat check function
///
/// A boxed async callback: given an agent id, it resolves to
/// `Some(HeartbeatAlert)` when the check fires, or `None` when healthy.
/// Boxed + pinned so heterogeneous checks can be stored in one collection.
#[allow(dead_code)] // Reserved for future proactive check registration
pub type HeartbeatCheckFn = Box<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<HeartbeatAlert>> + Send>> + Send + Sync>;
|
||||
|
||||
// === Default Config ===
|
||||
@@ -187,6 +186,7 @@ impl HeartbeatEngine {
|
||||
}
|
||||
|
||||
/// Check if the engine is running
|
||||
#[allow(dead_code)] // Reserved for UI status display
|
||||
pub async fn is_running(&self) -> bool {
|
||||
*self.running.lock().await
|
||||
}
|
||||
@@ -197,6 +197,7 @@ impl HeartbeatEngine {
|
||||
}
|
||||
|
||||
/// Subscribe to alerts
///
/// Returns a broadcast receiver that yields every alert emitted after this
/// call; alerts sent before subscribing are not replayed.
#[allow(dead_code)] // Reserved for future UI notification integration
pub fn subscribe(&self) -> broadcast::Receiver<HeartbeatAlert> {
    self.alert_sender.subscribe()
}
|
||||
@@ -355,7 +356,9 @@ static LAST_INTERACTION: OnceLock<RwLock<StdHashMap<String, String>>> = OnceLock
|
||||
/// Cached snapshot of memory statistics.
pub struct MemoryStatsCache {
    /// Number of distinct tasks tracked.
    pub task_count: usize,
    /// Total memory entries across all tasks.
    pub total_entries: usize,
    // NOTE(review): presumably bytes on disk/in store — confirm against writer.
    #[allow(dead_code)] // Reserved for UI display
    pub storage_size_bytes: usize,
    // Timestamp of the last cache refresh, if any (string-encoded).
    #[allow(dead_code)] // Reserved for UI display
    pub last_updated: Option<String>,
}
|
||||
|
||||
|
||||
@@ -1,397 +0,0 @@
|
||||
//! Adaptive Intelligence Mesh - Coordinates Memory, Pipeline, and Heartbeat
|
||||
//!
|
||||
//! This module provides proactive workflow recommendations based on user behavior patterns.
|
||||
//! It integrates with:
|
||||
//! - PatternDetector for behavior pattern detection
|
||||
//! - WorkflowRecommender for generating recommendations
|
||||
//! - HeartbeatEngine for periodic checks
|
||||
//! - PersistentMemoryStore for historical data
|
||||
//! - PipelineExecutor for workflow execution
|
||||
//!
|
||||
//! NOTE: Some methods are reserved for future integration with the UI.
|
||||
|
||||
#![allow(dead_code)] // Methods reserved for future UI integration
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, Mutex};
|
||||
|
||||
use super::pattern_detector::{BehaviorPattern, PatternContext, PatternDetector};
|
||||
use super::recommender::WorkflowRecommender;
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Workflow recommendation generated by the mesh
///
/// Serialized across the Tauri bridge, so field names are part of the
/// frontend contract.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRecommendation {
    /// Unique recommendation identifier
    pub id: String,
    /// Pipeline ID to recommend
    pub pipeline_id: String,
    /// Confidence score (0.0-1.0)
    pub confidence: f32,
    /// Human-readable reason for recommendation
    pub reason: String,
    /// Suggested input values
    pub suggested_inputs: HashMap<String, serde_json::Value>,
    /// Pattern IDs that matched
    pub patterns_matched: Vec<String>,
    /// When this recommendation was generated
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Mesh coordinator configuration
///
/// See `Default` impl for the baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshConfig {
    /// Enable mesh recommendations
    pub enabled: bool,
    /// Minimum confidence threshold for recommendations (0.0-1.0)
    pub min_confidence: f32,
    /// Maximum recommendations to generate per analysis
    pub max_recommendations: usize,
    /// Hours to look back for pattern analysis
    pub analysis_window_hours: u64,
}
|
||||
|
||||
impl Default for MeshConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
min_confidence: 0.6,
|
||||
max_recommendations: 5,
|
||||
analysis_window_hours: 24,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Analysis result from mesh coordinator
///
/// Returned by `MeshCoordinator::analyze` and exposed to the frontend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshAnalysisResult {
    /// Generated recommendations (already filtered and truncated)
    pub recommendations: Vec<WorkflowRecommendation>,
    /// Patterns detected (count before confidence filtering of recommendations)
    pub patterns_detected: usize,
    /// Analysis timestamp
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
// === Mesh Coordinator ===
|
||||
|
||||
/// Main mesh coordinator that integrates pattern detection and recommendations
///
/// All interior state is behind `tokio::sync::Mutex` so the coordinator can be
/// shared (`&self` methods) across async Tauri command invocations.
pub struct MeshCoordinator {
    /// Agent ID
    #[allow(dead_code)] // Reserved for multi-agent scenarios
    agent_id: String,
    /// Configuration
    config: Arc<Mutex<MeshConfig>>,
    /// Pattern detector
    pattern_detector: Arc<Mutex<PatternDetector>>,
    /// Workflow recommender
    recommender: Arc<Mutex<WorkflowRecommender>>,
    /// Recommendation sender
    #[allow(dead_code)] // Reserved for real-time recommendation streaming
    recommendation_sender: broadcast::Sender<WorkflowRecommendation>,
    /// Last analysis timestamp (None until the first analyze() call)
    last_analysis: Arc<Mutex<Option<DateTime<Utc>>>>,
}
|
||||
|
||||
impl MeshCoordinator {
|
||||
/// Create a new mesh coordinator
|
||||
pub fn new(agent_id: String, config: Option<MeshConfig>) -> Self {
|
||||
let (sender, _) = broadcast::channel(100);
|
||||
let config = config.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
agent_id,
|
||||
config: Arc::new(Mutex::new(config)),
|
||||
pattern_detector: Arc::new(Mutex::new(PatternDetector::new(None))),
|
||||
recommender: Arc::new(Mutex::new(WorkflowRecommender::new(None))),
|
||||
recommendation_sender: sender,
|
||||
last_analysis: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze current context and generate recommendations
|
||||
pub async fn analyze(&self) -> Result<MeshAnalysisResult, String> {
|
||||
let config = self.config.lock().await.clone();
|
||||
|
||||
if !config.enabled {
|
||||
return Ok(MeshAnalysisResult {
|
||||
recommendations: vec![],
|
||||
patterns_detected: 0,
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
// Get patterns from detector (clone to avoid borrow issues)
|
||||
let patterns: Vec<BehaviorPattern> = {
|
||||
let detector = self.pattern_detector.lock().await;
|
||||
let patterns_ref = detector.get_patterns();
|
||||
patterns_ref.into_iter().cloned().collect()
|
||||
};
|
||||
let patterns_detected = patterns.len();
|
||||
|
||||
// Generate recommendations from patterns
|
||||
let recommender = self.recommender.lock().await;
|
||||
let pattern_refs: Vec<&BehaviorPattern> = patterns.iter().collect();
|
||||
let mut recommendations = recommender.recommend(&pattern_refs);
|
||||
|
||||
// Filter by confidence
|
||||
recommendations.retain(|r| r.confidence >= config.min_confidence);
|
||||
|
||||
// Limit count
|
||||
recommendations.truncate(config.max_recommendations);
|
||||
|
||||
// Update timestamps
|
||||
for rec in &mut recommendations {
|
||||
rec.timestamp = Utc::now();
|
||||
}
|
||||
|
||||
// Update last analysis time
|
||||
*self.last_analysis.lock().await = Some(Utc::now());
|
||||
|
||||
Ok(MeshAnalysisResult {
|
||||
recommendations: recommendations.clone(),
|
||||
patterns_detected,
|
||||
timestamp: Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Record user activity for pattern detection
|
||||
pub async fn record_activity(
|
||||
&self,
|
||||
activity_type: ActivityType,
|
||||
context: PatternContext,
|
||||
) -> Result<(), String> {
|
||||
let mut detector = self.pattern_detector.lock().await;
|
||||
|
||||
match activity_type {
|
||||
ActivityType::SkillUsed { skill_ids } => {
|
||||
detector.record_skill_usage(skill_ids);
|
||||
}
|
||||
ActivityType::PipelineExecuted {
|
||||
task_type,
|
||||
pipeline_id,
|
||||
} => {
|
||||
detector.record_pipeline_execution(&task_type, &pipeline_id, context);
|
||||
}
|
||||
ActivityType::InputReceived { keywords, intent } => {
|
||||
detector.record_input_pattern(keywords, &intent, context);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Subscribe to recommendations
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<WorkflowRecommendation> {
|
||||
self.recommendation_sender.subscribe()
|
||||
}
|
||||
|
||||
/// Get current patterns
|
||||
pub async fn get_patterns(&self) -> Vec<BehaviorPattern> {
|
||||
let detector = self.pattern_detector.lock().await;
|
||||
detector.get_patterns().into_iter().cloned().collect()
|
||||
}
|
||||
|
||||
/// Decay old patterns (call periodically)
|
||||
pub async fn decay_patterns(&self) {
|
||||
let mut detector = self.pattern_detector.lock().await;
|
||||
detector.decay_patterns();
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub async fn update_config(&self, config: MeshConfig) {
|
||||
*self.config.lock().await = config;
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub async fn get_config(&self) -> MeshConfig {
|
||||
self.config.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Record a user correction (for pattern refinement)
|
||||
pub async fn record_correction(&self, correction_type: &str) {
|
||||
let mut detector = self.pattern_detector.lock().await;
|
||||
// Record as input pattern with negative signal
|
||||
detector.record_input_pattern(
|
||||
vec![format!("correction:{}", correction_type)],
|
||||
"user_preference",
|
||||
PatternContext::default(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Get recommendation count
|
||||
pub async fn recommendation_count(&self) -> usize {
|
||||
let recommender = self.recommender.lock().await;
|
||||
recommender.recommendation_count()
|
||||
}
|
||||
|
||||
/// Accept a recommendation (returns the accepted recommendation)
|
||||
pub async fn accept_recommendation(&self, recommendation_id: &str) -> Option<WorkflowRecommendation> {
|
||||
let mut recommender = self.recommender.lock().await;
|
||||
recommender.accept_recommendation(recommendation_id)
|
||||
}
|
||||
|
||||
/// Dismiss a recommendation (returns true if found and dismissed)
|
||||
pub async fn dismiss_recommendation(&self, recommendation_id: &str) -> bool {
|
||||
let mut recommender = self.recommender.lock().await;
|
||||
recommender.dismiss_recommendation(recommendation_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Types of user activities that can be recorded
///
/// Serialized with an adjacent "type" tag in snake_case so the frontend can
/// construct variants as plain JSON objects.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ActivityType {
    /// Skills were used together
    SkillUsed { skill_ids: Vec<String> },
    /// A pipeline was executed
    PipelineExecuted { task_type: String, pipeline_id: String },
    /// User input was received
    InputReceived { keywords: Vec<String>, intent: String },
}
|
||||
|
||||
// === Tauri Commands ===
|
||||
|
||||
/// Mesh coordinator state for Tauri
///
/// Maps agent id -> coordinator; shared via `tauri::State` by all commands.
pub type MeshCoordinatorState = Arc<Mutex<HashMap<String, MeshCoordinator>>>;
|
||||
|
||||
/// Initialize mesh coordinator for an agent
|
||||
#[tauri::command]
|
||||
pub async fn mesh_init(
|
||||
agent_id: String,
|
||||
config: Option<MeshConfig>,
|
||||
state: tauri::State<'_, MeshCoordinatorState>,
|
||||
) -> Result<(), String> {
|
||||
let coordinator = MeshCoordinator::new(agent_id.clone(), config);
|
||||
let mut coordinators = state.lock().await;
|
||||
coordinators.insert(agent_id, coordinator);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Analyze and get recommendations
///
/// Errors if `mesh_init` has not been called for `agent_id`.
#[tauri::command]
pub async fn mesh_analyze(
    agent_id: String,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<MeshAnalysisResult, String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    coordinator.analyze().await
}
|
||||
|
||||
/// Record user activity
///
/// Forwards the activity into the agent's pattern detector.
/// Errors if `mesh_init` has not been called for `agent_id`.
#[tauri::command]
pub async fn mesh_record_activity(
    agent_id: String,
    activity_type: ActivityType,
    context: PatternContext,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<(), String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    coordinator.record_activity(activity_type, context).await
}
|
||||
|
||||
/// Get current patterns
///
/// Returns a cloned snapshot of the agent's detected behavior patterns.
/// Errors if `mesh_init` has not been called for `agent_id`.
#[tauri::command]
pub async fn mesh_get_patterns(
    agent_id: String,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<Vec<BehaviorPattern>, String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    Ok(coordinator.get_patterns().await)
}
|
||||
|
||||
/// Update mesh configuration
///
/// Replaces the agent's whole `MeshConfig`.
/// Errors if `mesh_init` has not been called for `agent_id`.
#[tauri::command]
pub async fn mesh_update_config(
    agent_id: String,
    config: MeshConfig,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<(), String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    coordinator.update_config(config).await;
    Ok(())
}
|
||||
|
||||
/// Decay old patterns
///
/// Intended to be invoked periodically (e.g. from a frontend timer).
/// Errors if `mesh_init` has not been called for `agent_id`.
#[tauri::command]
pub async fn mesh_decay_patterns(
    agent_id: String,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<(), String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    coordinator.decay_patterns().await;
    Ok(())
}
|
||||
|
||||
/// Accept a recommendation (removes it and returns the accepted recommendation)
///
/// `Ok(None)` means the id was not found; `Err` means the agent has no
/// initialized coordinator.
#[tauri::command]
pub async fn mesh_accept_recommendation(
    agent_id: String,
    recommendation_id: String,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<Option<WorkflowRecommendation>, String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    Ok(coordinator.accept_recommendation(&recommendation_id).await)
}
|
||||
|
||||
/// Dismiss a recommendation (removes it without acting on it)
///
/// `Ok(false)` means the id was not found; `Err` means the agent has no
/// initialized coordinator.
#[tauri::command]
pub async fn mesh_dismiss_recommendation(
    agent_id: String,
    recommendation_id: String,
    state: tauri::State<'_, MeshCoordinatorState>,
) -> Result<bool, String> {
    let coordinators = state.lock().await;
    let coordinator = coordinators
        .get(&agent_id)
        .ok_or_else(|| format!("Mesh coordinator not initialized for agent: {}", agent_id))?;
    Ok(coordinator.dismiss_recommendation(&recommendation_id).await)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must come up enabled with the documented 0.6 floor.
    #[test]
    fn test_mesh_config_default() {
        let config = MeshConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_confidence, 0.6);
    }

    // A coordinator built without explicit config uses the defaults.
    #[tokio::test]
    async fn test_mesh_coordinator_creation() {
        let coordinator = MeshCoordinator::new("test_agent".to_string(), None);
        let config = coordinator.get_config().await;
        assert!(config.enabled);
    }

    // analyze() on a fresh coordinator must succeed (empty result, no error).
    #[tokio::test]
    async fn test_mesh_analysis() {
        let coordinator = MeshCoordinator::new("test_agent".to_string(), None);
        let result = coordinator.analyze().await;
        assert!(result.is_ok());
    }
}
|
||||
@@ -1,421 +0,0 @@
|
||||
//! Pattern Detector - Behavior pattern detection for Adaptive Intelligence Mesh
|
||||
//!
|
||||
//! Detects patterns from user activities including:
|
||||
//! - Skill combinations (frequently used together)
|
||||
//! - Temporal triggers (time-based patterns)
|
||||
//! - Task-pipeline mappings (task types mapped to pipelines)
|
||||
//! - Input patterns (keyword/intent patterns)
|
||||
//!
|
||||
//! NOTE: Analysis and export methods are reserved for future dashboard integration.
|
||||
|
||||
#![allow(dead_code)] // Analysis and export methods reserved for future dashboard features
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// === Pattern Types ===
|
||||
|
||||
/// Unique identifier for a pattern
///
/// Built from the pattern's content, e.g. "skill_combo:a|b" or
/// "task_pipeline:<task>:<pipeline>", so identical activity maps to the
/// same id.
pub type PatternId = String;
|
||||
|
||||
/// Behavior pattern detected from user activities
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BehaviorPattern {
    /// Unique pattern identifier
    pub id: PatternId,
    /// Type of pattern detected
    pub pattern_type: PatternType,
    /// How many times this pattern has occurred
    pub frequency: usize,
    /// When this pattern was last detected
    pub last_occurrence: DateTime<Utc>,
    /// When this pattern was first detected
    pub first_occurrence: DateTime<Utc>,
    /// Confidence score (0.0-1.0); decays with inactivity
    pub confidence: f32,
    /// Context when pattern was detected (most recent occurrence)
    pub context: PatternContext,
}
|
||||
|
||||
/// Types of detectable patterns
///
/// Tagged serialization ("type" field, snake_case) for the frontend.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PatternType {
    /// Skills frequently used together
    SkillCombination {
        skill_ids: Vec<String>,
    },
    /// Time-based trigger pattern
    TemporalTrigger {
        // NOTE(review): "hand_id" looks like a leftover from the
        // OpenFang->ZCLAW rename — confirm the intended identifier.
        hand_id: String,
        time_pattern: String, // Cron-like pattern or time range
    },
    /// Task type mapped to a pipeline
    TaskPipelineMapping {
        task_type: String,
        pipeline_id: String,
    },
    /// Input keyword/intent pattern
    InputPattern {
        keywords: Vec<String>,
        intent: String,
    },
}
|
||||
|
||||
/// Context information when pattern was detected
///
/// All fields are optional; `Default` yields an empty context.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PatternContext {
    /// Skills involved in the session
    pub skill_ids: Option<Vec<String>>,
    /// Topics discussed recently
    pub recent_topics: Option<Vec<String>>,
    /// Detected intent
    pub intent: Option<String>,
    /// Time of day when detected (hour 0-23)
    pub time_of_day: Option<u8>,
    /// Day of week (0=Monday, 6=Sunday)
    pub day_of_week: Option<u8>,
}
|
||||
|
||||
/// Pattern detection configuration
///
/// See the `Default` impl for baseline values (3 / 0.5 / 30 / 100).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternDetectorConfig {
    /// Minimum occurrences before pattern is recognized
    pub min_frequency: usize,
    /// Minimum confidence threshold
    pub min_confidence: f32,
    /// Days after which pattern confidence decays
    pub decay_days: u32,
    /// Maximum patterns to keep
    pub max_patterns: usize,
}
|
||||
|
||||
impl Default for PatternDetectorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_frequency: 3,
|
||||
min_confidence: 0.5,
|
||||
decay_days: 30,
|
||||
max_patterns: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Pattern Detector ===
|
||||
|
||||
/// Pattern detector that identifies behavior patterns from activities
pub struct PatternDetector {
    /// Detected patterns, keyed by `PatternId`
    patterns: HashMap<PatternId, BehaviorPattern>,
    /// Configuration
    config: PatternDetectorConfig,
    /// Skill combination history for pattern detection
    /// (append-only, chronological; trimmed to bound memory)
    skill_combination_history: Vec<(Vec<String>, DateTime<Utc>)>,
}
|
||||
|
||||
impl PatternDetector {
|
||||
/// Create a new pattern detector
|
||||
pub fn new(config: Option<PatternDetectorConfig>) -> Self {
|
||||
Self {
|
||||
patterns: HashMap::new(),
|
||||
config: config.unwrap_or_default(),
|
||||
skill_combination_history: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record skill usage for combination detection
|
||||
pub fn record_skill_usage(&mut self, skill_ids: Vec<String>) {
|
||||
let now = Utc::now();
|
||||
self.skill_combination_history.push((skill_ids, now));
|
||||
|
||||
// Keep only recent history (last 1000 entries)
|
||||
if self.skill_combination_history.len() > 1000 {
|
||||
self.skill_combination_history.drain(0..500);
|
||||
}
|
||||
|
||||
// Detect patterns
|
||||
self.detect_skill_combinations();
|
||||
}
|
||||
|
||||
/// Record a pipeline execution for task mapping detection
|
||||
pub fn record_pipeline_execution(
|
||||
&mut self,
|
||||
task_type: &str,
|
||||
pipeline_id: &str,
|
||||
context: PatternContext,
|
||||
) {
|
||||
let pattern_key = format!("task_pipeline:{}:{}", task_type, pipeline_id);
|
||||
|
||||
self.update_or_create_pattern(
|
||||
&pattern_key,
|
||||
PatternType::TaskPipelineMapping {
|
||||
task_type: task_type.to_string(),
|
||||
pipeline_id: pipeline_id.to_string(),
|
||||
},
|
||||
context,
|
||||
);
|
||||
}
|
||||
|
||||
/// Record an input pattern
|
||||
pub fn record_input_pattern(
|
||||
&mut self,
|
||||
keywords: Vec<String>,
|
||||
intent: &str,
|
||||
context: PatternContext,
|
||||
) {
|
||||
let pattern_key = format!("input_pattern:{}:{}", keywords.join(","), intent);
|
||||
|
||||
self.update_or_create_pattern(
|
||||
&pattern_key,
|
||||
PatternType::InputPattern {
|
||||
keywords,
|
||||
intent: intent.to_string(),
|
||||
},
|
||||
context,
|
||||
);
|
||||
}
|
||||
|
||||
/// Update existing pattern or create new one
|
||||
fn update_or_create_pattern(
|
||||
&mut self,
|
||||
key: &str,
|
||||
pattern_type: PatternType,
|
||||
context: PatternContext,
|
||||
) {
|
||||
let now = Utc::now();
|
||||
let decay_days = self.config.decay_days;
|
||||
|
||||
if let Some(pattern) = self.patterns.get_mut(key) {
|
||||
// Update existing pattern
|
||||
pattern.frequency += 1;
|
||||
pattern.last_occurrence = now;
|
||||
pattern.context = context;
|
||||
|
||||
// Recalculate confidence inline to avoid borrow issues
|
||||
let days_since_last = (now - pattern.last_occurrence).num_days() as f32;
|
||||
let frequency_score = (pattern.frequency as f32 / 10.0).min(1.0);
|
||||
let decay_factor = if days_since_last > decay_days as f32 {
|
||||
0.5
|
||||
} else {
|
||||
1.0 - (days_since_last / decay_days as f32) * 0.3
|
||||
};
|
||||
pattern.confidence = (frequency_score * decay_factor).min(1.0);
|
||||
} else {
|
||||
// Create new pattern
|
||||
let pattern = BehaviorPattern {
|
||||
id: key.to_string(),
|
||||
pattern_type,
|
||||
frequency: 1,
|
||||
first_occurrence: now,
|
||||
last_occurrence: now,
|
||||
confidence: 0.1, // Low initial confidence
|
||||
context,
|
||||
};
|
||||
|
||||
self.patterns.insert(key.to_string(), pattern);
|
||||
|
||||
// Enforce max patterns limit
|
||||
self.enforce_max_patterns();
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect skill combination patterns from history
|
||||
fn detect_skill_combinations(&mut self) {
|
||||
// Group skill combinations
|
||||
let mut combination_counts: HashMap<String, (Vec<String>, usize, DateTime<Utc>)> =
|
||||
HashMap::new();
|
||||
|
||||
for (skills, time) in &self.skill_combination_history {
|
||||
if skills.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Sort skills for consistent grouping
|
||||
let mut sorted_skills = skills.clone();
|
||||
sorted_skills.sort();
|
||||
let key = sorted_skills.join("|");
|
||||
|
||||
let entry = combination_counts.entry(key).or_insert((
|
||||
sorted_skills,
|
||||
0,
|
||||
*time,
|
||||
));
|
||||
entry.1 += 1;
|
||||
entry.2 = *time; // Update last occurrence
|
||||
}
|
||||
|
||||
// Create patterns for combinations meeting threshold
|
||||
for (key, (skills, count, last_time)) in combination_counts {
|
||||
if count >= self.config.min_frequency {
|
||||
let pattern = BehaviorPattern {
|
||||
id: format!("skill_combo:{}", key),
|
||||
pattern_type: PatternType::SkillCombination { skill_ids: skills },
|
||||
frequency: count,
|
||||
first_occurrence: last_time,
|
||||
last_occurrence: last_time,
|
||||
confidence: self.calculate_confidence_from_frequency(count),
|
||||
context: PatternContext::default(),
|
||||
};
|
||||
|
||||
self.patterns.insert(pattern.id.clone(), pattern);
|
||||
}
|
||||
}
|
||||
|
||||
self.enforce_max_patterns();
|
||||
}
|
||||
|
||||
/// Calculate confidence based on frequency and recency
|
||||
fn calculate_confidence(&self, pattern: &BehaviorPattern) -> f32 {
|
||||
let now = Utc::now();
|
||||
let days_since_last = (now - pattern.last_occurrence).num_days() as f32;
|
||||
|
||||
// Base confidence from frequency (capped at 1.0)
|
||||
let frequency_score = (pattern.frequency as f32 / 10.0).min(1.0);
|
||||
|
||||
// Decay factor based on time since last occurrence
|
||||
let decay_factor = if days_since_last > self.config.decay_days as f32 {
|
||||
0.5 // Significant decay for old patterns
|
||||
} else {
|
||||
1.0 - (days_since_last / self.config.decay_days as f32) * 0.3
|
||||
};
|
||||
|
||||
(frequency_score * decay_factor).min(1.0)
|
||||
}
|
||||
|
||||
/// Calculate confidence from frequency alone
|
||||
fn calculate_confidence_from_frequency(&self, frequency: usize) -> f32 {
|
||||
(frequency as f32 / self.config.min_frequency.max(1) as f32).min(1.0)
|
||||
}
|
||||
|
||||
/// Enforce maximum patterns limit by removing lowest confidence patterns
|
||||
fn enforce_max_patterns(&mut self) {
|
||||
if self.patterns.len() <= self.config.max_patterns {
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort patterns by confidence and remove lowest
|
||||
let mut patterns_vec: Vec<_> = self.patterns.drain().collect();
|
||||
patterns_vec.sort_by(|a, b| b.1.confidence.partial_cmp(&a.1.confidence).unwrap());
|
||||
|
||||
// Keep top patterns
|
||||
self.patterns = patterns_vec
|
||||
.into_iter()
|
||||
.take(self.config.max_patterns)
|
||||
.collect();
|
||||
}
|
||||
|
||||
/// Get all patterns above confidence threshold
|
||||
pub fn get_patterns(&self) -> Vec<&BehaviorPattern> {
|
||||
self.patterns
|
||||
.values()
|
||||
.filter(|p| p.confidence >= self.config.min_confidence)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get patterns of a specific type
|
||||
pub fn get_patterns_by_type(&self, pattern_type: &PatternType) -> Vec<&BehaviorPattern> {
|
||||
self.patterns
|
||||
.values()
|
||||
.filter(|p| std::mem::discriminant(&p.pattern_type) == std::mem::discriminant(pattern_type))
|
||||
.filter(|p| p.confidence >= self.config.min_confidence)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get patterns sorted by confidence
|
||||
pub fn get_patterns_sorted(&self) -> Vec<&BehaviorPattern> {
|
||||
let mut patterns: Vec<_> = self.get_patterns();
|
||||
patterns.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
|
||||
patterns
|
||||
}
|
||||
|
||||
/// Decay old patterns (should be called periodically)
|
||||
pub fn decay_patterns(&mut self) {
|
||||
let now = Utc::now();
|
||||
|
||||
for pattern in self.patterns.values_mut() {
|
||||
let days_since_last = (now - pattern.last_occurrence).num_days() as f32;
|
||||
|
||||
if days_since_last > self.config.decay_days as f32 {
|
||||
// Reduce confidence for old patterns
|
||||
let decay_amount = 0.1 * (days_since_last / self.config.decay_days as f32);
|
||||
pattern.confidence = (pattern.confidence - decay_amount).max(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove patterns below threshold
|
||||
self.patterns
|
||||
.retain(|_, p| p.confidence >= self.config.min_confidence * 0.5);
|
||||
}
|
||||
|
||||
/// Clear all patterns
|
||||
pub fn clear(&mut self) {
|
||||
self.patterns.clear();
|
||||
self.skill_combination_history.clear();
|
||||
}
|
||||
|
||||
/// Get pattern count
|
||||
pub fn pattern_count(&self) -> usize {
|
||||
self.patterns.len()
|
||||
}
|
||||
|
||||
/// Export patterns for persistence
|
||||
pub fn export_patterns(&self) -> Vec<BehaviorPattern> {
|
||||
self.patterns.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Import patterns from persistence
|
||||
pub fn import_patterns(&mut self, patterns: Vec<BehaviorPattern>) {
|
||||
for pattern in patterns {
|
||||
self.patterns.insert(pattern.id.clone(), pattern);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh detector starts with no patterns.
    #[test]
    fn test_pattern_creation() {
        let detector = PatternDetector::new(None);
        assert_eq!(detector.pattern_count(), 0);
    }

    // With min_frequency lowered to 2, two identical skill pairs are enough
    // to surface a SkillCombination pattern.
    #[test]
    fn test_skill_combination_detection() {
        let mut detector = PatternDetector::new(Some(PatternDetectorConfig {
            min_frequency: 2,
            ..Default::default()
        }));

        // Record skill usage multiple times
        detector.record_skill_usage(vec!["skill_a".to_string(), "skill_b".to_string()]);
        detector.record_skill_usage(vec!["skill_a".to_string(), "skill_b".to_string()]);

        // Should detect pattern after 2 occurrences
        let patterns = detector.get_patterns();
        assert!(!patterns.is_empty());
    }

    // Confidence for a just-seen pattern stays inside (0, 1].
    #[test]
    fn test_confidence_calculation() {
        let detector = PatternDetector::new(None);

        let pattern = BehaviorPattern {
            id: "test".to_string(),
            pattern_type: PatternType::TaskPipelineMapping {
                task_type: "test".to_string(),
                pipeline_id: "pipeline".to_string(),
            },
            frequency: 5,
            first_occurrence: Utc::now(),
            last_occurrence: Utc::now(),
            confidence: 0.5,
            context: PatternContext::default(),
        };

        let confidence = detector.calculate_confidence(&pattern);
        assert!(confidence > 0.0 && confidence <= 1.0);
    }
}
|
||||
@@ -1,819 +0,0 @@
|
||||
//! Persona Evolver - Memory-powered persona evolution system
|
||||
//!
|
||||
//! Automatically evolves agent persona based on:
|
||||
//! - User interaction patterns (preferences, communication style)
|
||||
//! - Reflection insights (positive/negative patterns)
|
||||
//! - Memory accumulation (facts, lessons, context)
|
||||
//!
|
||||
//! Key features:
|
||||
//! - Automatic user_profile enrichment from preferences
|
||||
//! - Instruction refinement proposals based on patterns
|
||||
//! - Soul evolution suggestions (requires explicit user approval)
|
||||
//!
|
||||
//! Phase 4 of Intelligence Layer - P1 Innovation Task.
|
||||
//!
|
||||
//! NOTE: Tauri commands defined here are not yet registered with the app.
|
||||
|
||||
#![allow(dead_code)] // Tauri commands not yet registered with application
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::reflection::{ReflectionResult, Sentiment, MemoryEntryForAnalysis};
|
||||
use super::identity::{IdentityFiles, IdentityFile, ProposalStatus};
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Persona evolution configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonaEvolverConfig {
|
||||
/// Enable automatic user_profile updates
|
||||
#[serde(default = "default_auto_profile_update")]
|
||||
pub auto_profile_update: bool,
|
||||
/// Minimum preferences before suggesting profile update
|
||||
#[serde(default = "default_min_preferences")]
|
||||
pub min_preferences_for_update: usize,
|
||||
/// Minimum conversations before evolution
|
||||
#[serde(default = "default_min_conversations")]
|
||||
pub min_conversations_for_evolution: usize,
|
||||
/// Enable instruction refinement proposals
|
||||
#[serde(default = "default_enable_instruction_refinement")]
|
||||
pub enable_instruction_refinement: bool,
|
||||
/// Enable soul evolution (requires explicit approval)
|
||||
#[serde(default = "default_enable_soul_evolution")]
|
||||
pub enable_soul_evolution: bool,
|
||||
/// Maximum proposals per evolution cycle
|
||||
#[serde(default = "default_max_proposals")]
|
||||
pub max_proposals_per_cycle: usize,
|
||||
}
|
||||
|
||||
fn default_auto_profile_update() -> bool { true }
|
||||
fn default_min_preferences() -> usize { 3 }
|
||||
fn default_min_conversations() -> usize { 5 }
|
||||
fn default_enable_instruction_refinement() -> bool { true }
|
||||
fn default_enable_soul_evolution() -> bool { true }
|
||||
fn default_max_proposals() -> usize { 3 }
|
||||
|
||||
impl Default for PersonaEvolverConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auto_profile_update: true,
|
||||
min_preferences_for_update: 3,
|
||||
min_conversations_for_evolution: 5,
|
||||
enable_instruction_refinement: true,
|
||||
enable_soul_evolution: true,
|
||||
max_proposals_per_cycle: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Persona evolution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EvolutionResult {
|
||||
/// Agent ID
|
||||
pub agent_id: String,
|
||||
/// Timestamp
|
||||
pub timestamp: String,
|
||||
/// Profile updates applied (auto)
|
||||
pub profile_updates: Vec<ProfileUpdate>,
|
||||
/// Proposals generated (require approval)
|
||||
pub proposals: Vec<EvolutionProposal>,
|
||||
/// Evolution insights
|
||||
pub insights: Vec<EvolutionInsight>,
|
||||
/// Whether evolution occurred
|
||||
pub evolved: bool,
|
||||
}
|
||||
|
||||
/// Profile update (auto-applied)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProfileUpdate {
|
||||
pub section: String,
|
||||
pub previous: String,
|
||||
pub updated: String,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
/// Evolution proposal (requires approval)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EvolutionProposal {
|
||||
pub id: String,
|
||||
pub agent_id: String,
|
||||
pub target_file: IdentityFile,
|
||||
pub change_type: EvolutionChangeType,
|
||||
pub reason: String,
|
||||
pub current_content: String,
|
||||
pub proposed_content: String,
|
||||
pub confidence: f32,
|
||||
pub evidence: Vec<String>,
|
||||
pub status: ProposalStatus,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Type of evolution change
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EvolutionChangeType {
|
||||
/// Add new instruction section
|
||||
InstructionAddition,
|
||||
/// Refine existing instruction
|
||||
InstructionRefinement,
|
||||
/// Add personality trait
|
||||
TraitAddition,
|
||||
/// Communication style adjustment
|
||||
StyleAdjustment,
|
||||
/// Knowledge domain expansion
|
||||
DomainExpansion,
|
||||
}
|
||||
|
||||
/// Evolution insight
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EvolutionInsight {
|
||||
pub category: InsightCategory,
|
||||
pub observation: String,
|
||||
pub recommendation: String,
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum InsightCategory {
|
||||
CommunicationStyle,
|
||||
TechnicalExpertise,
|
||||
TaskEfficiency,
|
||||
UserPreference,
|
||||
KnowledgeGap,
|
||||
}
|
||||
|
||||
/// Persona evolution state
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonaEvolverState {
|
||||
pub last_evolution: Option<String>,
|
||||
pub total_evolutions: usize,
|
||||
pub pending_proposals: usize,
|
||||
pub profile_enrichment_score: f32,
|
||||
}
|
||||
|
||||
impl Default for PersonaEvolverState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
last_evolution: None,
|
||||
total_evolutions: 0,
|
||||
pending_proposals: 0,
|
||||
profile_enrichment_score: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Persona Evolver ===
|
||||
|
||||
pub struct PersonaEvolver {
|
||||
config: PersonaEvolverConfig,
|
||||
state: PersonaEvolverState,
|
||||
evolution_history: Vec<EvolutionResult>,
|
||||
}
|
||||
|
||||
impl PersonaEvolver {
|
||||
pub fn new(config: Option<PersonaEvolverConfig>) -> Self {
|
||||
Self {
|
||||
config: config.unwrap_or_default(),
|
||||
state: PersonaEvolverState::default(),
|
||||
evolution_history: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run evolution cycle for an agent
|
||||
pub fn evolve(
|
||||
&mut self,
|
||||
agent_id: &str,
|
||||
memories: &[MemoryEntryForAnalysis],
|
||||
reflection_result: &ReflectionResult,
|
||||
current_identity: &IdentityFiles,
|
||||
) -> EvolutionResult {
|
||||
let mut profile_updates = Vec::new();
|
||||
let mut proposals = Vec::new();
|
||||
#[allow(unused_assignments)] // Overwritten by generate_insights below
|
||||
let mut insights = Vec::new();
|
||||
|
||||
// 1. Extract user preferences and auto-update profile
|
||||
if self.config.auto_profile_update {
|
||||
profile_updates = self.extract_profile_updates(memories, current_identity);
|
||||
}
|
||||
|
||||
// 2. Generate instruction refinement proposals
|
||||
if self.config.enable_instruction_refinement {
|
||||
let instruction_proposals = self.generate_instruction_proposals(
|
||||
agent_id,
|
||||
reflection_result,
|
||||
current_identity,
|
||||
);
|
||||
proposals.extend(instruction_proposals);
|
||||
}
|
||||
|
||||
// 3. Generate soul evolution proposals (rare, high bar)
|
||||
if self.config.enable_soul_evolution {
|
||||
let soul_proposals = self.generate_soul_proposals(
|
||||
agent_id,
|
||||
reflection_result,
|
||||
current_identity,
|
||||
);
|
||||
proposals.extend(soul_proposals);
|
||||
}
|
||||
|
||||
// 4. Generate insights
|
||||
insights = self.generate_insights(memories, reflection_result);
|
||||
|
||||
// 5. Limit proposals
|
||||
proposals.truncate(self.config.max_proposals_per_cycle);
|
||||
|
||||
// 6. Update state
|
||||
let evolved = !profile_updates.is_empty() || !proposals.is_empty();
|
||||
if evolved {
|
||||
self.state.last_evolution = Some(Utc::now().to_rfc3339());
|
||||
self.state.total_evolutions += 1;
|
||||
self.state.pending_proposals += proposals.len();
|
||||
self.state.profile_enrichment_score = self.calculate_profile_score(memories);
|
||||
}
|
||||
|
||||
let result = EvolutionResult {
|
||||
agent_id: agent_id.to_string(),
|
||||
timestamp: Utc::now().to_rfc3339(),
|
||||
profile_updates,
|
||||
proposals,
|
||||
insights,
|
||||
evolved,
|
||||
};
|
||||
|
||||
// Store in history
|
||||
self.evolution_history.push(result.clone());
|
||||
if self.evolution_history.len() > 20 {
|
||||
self.evolution_history = self.evolution_history.split_off(10);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Extract profile updates from memory
|
||||
fn extract_profile_updates(
|
||||
&self,
|
||||
memories: &[MemoryEntryForAnalysis],
|
||||
current_identity: &IdentityFiles,
|
||||
) -> Vec<ProfileUpdate> {
|
||||
let mut updates = Vec::new();
|
||||
|
||||
// Extract preferences
|
||||
let preferences: Vec<_> = memories
|
||||
.iter()
|
||||
.filter(|m| m.memory_type == "preference")
|
||||
.collect();
|
||||
|
||||
if preferences.len() >= self.config.min_preferences_for_update {
|
||||
// Check if user_profile needs updating
|
||||
let current_profile = ¤t_identity.user_profile;
|
||||
let default_profile = "尚未收集到用户偏好信息";
|
||||
|
||||
if current_profile.contains(default_profile) || current_profile.len() < 100 {
|
||||
// Build new profile from preferences
|
||||
let mut sections = Vec::new();
|
||||
|
||||
// Group preferences by category
|
||||
let mut categories: HashMap<String, Vec<String>> = HashMap::new();
|
||||
for pref in &preferences {
|
||||
// Simple categorization based on keywords
|
||||
let category = self.categorize_preference(&pref.content);
|
||||
categories
|
||||
.entry(category)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(pref.content.clone());
|
||||
}
|
||||
|
||||
// Build sections
|
||||
for (category, items) in categories {
|
||||
if !items.is_empty() {
|
||||
sections.push(format!("### {}\n{}", category, items.iter()
|
||||
.map(|i| format!("- {}", i))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")));
|
||||
}
|
||||
}
|
||||
|
||||
if !sections.is_empty() {
|
||||
let new_profile = format!("# 用户画像\n\n{}\n\n_自动生成于 {}_",
|
||||
sections.join("\n\n"),
|
||||
Utc::now().format("%Y-%m-%d")
|
||||
);
|
||||
|
||||
updates.push(ProfileUpdate {
|
||||
section: "user_profile".to_string(),
|
||||
previous: current_profile.clone(),
|
||||
updated: new_profile,
|
||||
source: format!("{} 个偏好记忆", preferences.len()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updates
|
||||
}
|
||||
|
||||
/// Categorize a preference
|
||||
fn categorize_preference(&self, content: &str) -> String {
|
||||
let content_lower = content.to_lowercase();
|
||||
|
||||
if content_lower.contains("语言") || content_lower.contains("沟通") || content_lower.contains("回复") {
|
||||
"沟通偏好".to_string()
|
||||
} else if content_lower.contains("技术") || content_lower.contains("框架") || content_lower.contains("工具") {
|
||||
"技术栈".to_string()
|
||||
} else if content_lower.contains("项目") || content_lower.contains("工作") || content_lower.contains("任务") {
|
||||
"工作习惯".to_string()
|
||||
} else if content_lower.contains("格式") || content_lower.contains("风格") || content_lower.contains("风格") {
|
||||
"输出风格".to_string()
|
||||
} else {
|
||||
"其他偏好".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate instruction refinement proposals
|
||||
fn generate_instruction_proposals(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
reflection_result: &ReflectionResult,
|
||||
current_identity: &IdentityFiles,
|
||||
) -> Vec<EvolutionProposal> {
|
||||
let mut proposals = Vec::new();
|
||||
|
||||
// Only propose if there are negative patterns
|
||||
let negative_patterns: Vec<_> = reflection_result.patterns
|
||||
.iter()
|
||||
.filter(|p| matches!(p.sentiment, Sentiment::Negative))
|
||||
.collect();
|
||||
|
||||
if negative_patterns.is_empty() {
|
||||
return proposals;
|
||||
}
|
||||
|
||||
// Check if instructions already contain these warnings
|
||||
let current_instructions = ¤t_identity.instructions;
|
||||
|
||||
// Build proposed additions
|
||||
let mut additions = Vec::new();
|
||||
let mut evidence = Vec::new();
|
||||
|
||||
for pattern in &negative_patterns {
|
||||
// Check if this pattern is already addressed
|
||||
let key_phrase = &pattern.observation;
|
||||
if !current_instructions.contains(key_phrase) {
|
||||
additions.push(format!("- **注意事项**: {}", pattern.observation));
|
||||
evidence.extend(pattern.evidence.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if !additions.is_empty() {
|
||||
let proposed = format!(
|
||||
"{}\n\n## 🔄 自我改进建议\n\n{}\n\n_基于交互模式分析自动生成_",
|
||||
current_instructions.trim_end(),
|
||||
additions.join("\n")
|
||||
);
|
||||
|
||||
proposals.push(EvolutionProposal {
|
||||
id: format!("evo_inst_{}", Utc::now().timestamp()),
|
||||
agent_id: agent_id.to_string(),
|
||||
target_file: IdentityFile::Instructions,
|
||||
change_type: EvolutionChangeType::InstructionAddition,
|
||||
reason: format!(
|
||||
"基于 {} 个负面模式观察,建议在指令中增加自我改进提醒",
|
||||
negative_patterns.len()
|
||||
),
|
||||
current_content: current_instructions.clone(),
|
||||
proposed_content: proposed,
|
||||
confidence: 0.7 + (negative_patterns.len() as f32 * 0.05).min(0.2),
|
||||
evidence,
|
||||
status: ProposalStatus::Pending,
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for improvement suggestions that could become instructions
|
||||
for improvement in &reflection_result.improvements {
|
||||
if current_instructions.contains(&improvement.suggestion) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// High priority improvements become instruction proposals
|
||||
if matches!(improvement.priority, super::reflection::Priority::High) {
|
||||
proposals.push(EvolutionProposal {
|
||||
id: format!("evo_inst_{}_{}", Utc::now().timestamp(), rand_suffix()),
|
||||
agent_id: agent_id.to_string(),
|
||||
target_file: IdentityFile::Instructions,
|
||||
change_type: EvolutionChangeType::InstructionRefinement,
|
||||
reason: format!("高优先级改进建议: {}", improvement.area),
|
||||
current_content: current_instructions.clone(),
|
||||
proposed_content: format!(
|
||||
"{}\n\n### {}\n\n{}",
|
||||
current_instructions.trim_end(),
|
||||
improvement.area,
|
||||
improvement.suggestion
|
||||
),
|
||||
confidence: 0.75,
|
||||
evidence: vec![improvement.suggestion.clone()],
|
||||
status: ProposalStatus::Pending,
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
proposals
|
||||
}
|
||||
|
||||
/// Generate soul evolution proposals (high bar)
|
||||
fn generate_soul_proposals(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
reflection_result: &ReflectionResult,
|
||||
current_identity: &IdentityFiles,
|
||||
) -> Vec<EvolutionProposal> {
|
||||
let mut proposals = Vec::new();
|
||||
|
||||
// Soul evolution requires strong positive patterns
|
||||
let positive_patterns: Vec<_> = reflection_result.patterns
|
||||
.iter()
|
||||
.filter(|p| matches!(p.sentiment, Sentiment::Positive))
|
||||
.collect();
|
||||
|
||||
// Need at least 3 strong positive patterns
|
||||
if positive_patterns.len() < 3 {
|
||||
return proposals;
|
||||
}
|
||||
|
||||
// Calculate overall confidence
|
||||
let avg_frequency: usize = positive_patterns.iter()
|
||||
.map(|p| p.frequency)
|
||||
.sum::<usize>() / positive_patterns.len();
|
||||
|
||||
if avg_frequency < 5 {
|
||||
return proposals;
|
||||
}
|
||||
|
||||
// Build soul enhancement proposal
|
||||
let current_soul = ¤t_identity.soul;
|
||||
let mut traits = Vec::new();
|
||||
let mut evidence = Vec::new();
|
||||
|
||||
for pattern in &positive_patterns {
|
||||
// Extract trait from observation
|
||||
if pattern.observation.contains("偏好") {
|
||||
traits.push("深入理解用户偏好");
|
||||
} else if pattern.observation.contains("经验") {
|
||||
traits.push("持续积累经验教训");
|
||||
} else if pattern.observation.contains("知识") {
|
||||
traits.push("构建核心知识体系");
|
||||
}
|
||||
evidence.extend(pattern.evidence.clone());
|
||||
}
|
||||
|
||||
if !traits.is_empty() {
|
||||
let traits_section = traits.iter()
|
||||
.map(|t| format!("- {}", t))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let proposed = format!(
|
||||
"{}\n\n## 🌱 成长特质\n\n{}\n\n_通过交互学习持续演化_",
|
||||
current_soul.trim_end(),
|
||||
traits_section
|
||||
);
|
||||
|
||||
proposals.push(EvolutionProposal {
|
||||
id: format!("evo_soul_{}", Utc::now().timestamp()),
|
||||
agent_id: agent_id.to_string(),
|
||||
target_file: IdentityFile::Soul,
|
||||
change_type: EvolutionChangeType::TraitAddition,
|
||||
reason: format!(
|
||||
"基于 {} 个强正面模式,建议增加成长特质",
|
||||
positive_patterns.len()
|
||||
),
|
||||
current_content: current_soul.clone(),
|
||||
proposed_content: proposed,
|
||||
confidence: 0.85,
|
||||
evidence,
|
||||
status: ProposalStatus::Pending,
|
||||
created_at: Utc::now().to_rfc3339(),
|
||||
});
|
||||
}
|
||||
|
||||
proposals
|
||||
}
|
||||
|
||||
/// Generate evolution insights
|
||||
fn generate_insights(
|
||||
&self,
|
||||
memories: &[MemoryEntryForAnalysis],
|
||||
reflection_result: &ReflectionResult,
|
||||
) -> Vec<EvolutionInsight> {
|
||||
let mut insights = Vec::new();
|
||||
|
||||
// Communication style insight
|
||||
let comm_prefs: Vec<_> = memories
|
||||
.iter()
|
||||
.filter(|m| m.memory_type == "preference" &&
|
||||
(m.content.contains("回复") || m.content.contains("语言") || m.content.contains("简洁")))
|
||||
.collect();
|
||||
|
||||
if !comm_prefs.is_empty() {
|
||||
insights.push(EvolutionInsight {
|
||||
category: InsightCategory::CommunicationStyle,
|
||||
observation: format!("用户有 {} 个沟通风格偏好", comm_prefs.len()),
|
||||
recommendation: "在回复中应用这些偏好,提高用户满意度".to_string(),
|
||||
confidence: 0.8,
|
||||
});
|
||||
}
|
||||
|
||||
// Technical expertise insight
|
||||
let tech_memories: Vec<_> = memories
|
||||
.iter()
|
||||
.filter(|m| m.tags.iter().any(|t| t.contains("技术") || t.contains("代码")))
|
||||
.collect();
|
||||
|
||||
if tech_memories.len() >= 5 {
|
||||
insights.push(EvolutionInsight {
|
||||
category: InsightCategory::TechnicalExpertise,
|
||||
observation: format!("积累了 {} 个技术相关记忆", tech_memories.len()),
|
||||
recommendation: "考虑构建技术知识图谱,提高检索效率".to_string(),
|
||||
confidence: 0.7,
|
||||
});
|
||||
}
|
||||
|
||||
// Task efficiency insight from negative patterns
|
||||
let has_task_issues = reflection_result.patterns
|
||||
.iter()
|
||||
.any(|p| p.observation.contains("任务") && matches!(p.sentiment, Sentiment::Negative));
|
||||
|
||||
if has_task_issues {
|
||||
insights.push(EvolutionInsight {
|
||||
category: InsightCategory::TaskEfficiency,
|
||||
observation: "存在任务管理相关问题".to_string(),
|
||||
recommendation: "建议增加任务跟踪和提醒机制".to_string(),
|
||||
confidence: 0.75,
|
||||
});
|
||||
}
|
||||
|
||||
// Knowledge gap insight
|
||||
let lesson_count = memories.iter()
|
||||
.filter(|m| m.memory_type == "lesson")
|
||||
.count();
|
||||
|
||||
if lesson_count > 10 {
|
||||
insights.push(EvolutionInsight {
|
||||
category: InsightCategory::KnowledgeGap,
|
||||
observation: format!("已记录 {} 条经验教训", lesson_count),
|
||||
recommendation: "定期回顾教训,避免重复错误".to_string(),
|
||||
confidence: 0.8,
|
||||
});
|
||||
}
|
||||
|
||||
insights
|
||||
}
|
||||
|
||||
/// Calculate profile enrichment score
|
||||
fn calculate_profile_score(&self, memories: &[MemoryEntryForAnalysis]) -> f32 {
|
||||
let pref_count = memories.iter().filter(|m| m.memory_type == "preference").count();
|
||||
let fact_count = memories.iter().filter(|m| m.memory_type == "fact").count();
|
||||
|
||||
// Score based on diversity and quantity
|
||||
let pref_score = (pref_count as f32 / 10.0).min(1.0) * 0.5;
|
||||
let fact_score = (fact_count as f32 / 20.0).min(1.0) * 0.3;
|
||||
let diversity = if pref_count > 0 && fact_count > 0 { 0.2 } else { 0.0 };
|
||||
|
||||
pref_score + fact_score + diversity
|
||||
}
|
||||
|
||||
/// Get evolution history
|
||||
pub fn get_history(&self, limit: usize) -> Vec<&EvolutionResult> {
|
||||
self.evolution_history.iter().rev().take(limit).collect()
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub fn get_state(&self) -> &PersonaEvolverState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn get_config(&self) -> &PersonaEvolverConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub fn update_config(&mut self, config: PersonaEvolverConfig) {
|
||||
self.config = config;
|
||||
}
|
||||
|
||||
/// Mark proposal as handled (approved/rejected)
|
||||
pub fn proposal_handled(&mut self) {
|
||||
if self.state.pending_proposals > 0 {
|
||||
self.state.pending_proposals -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random suffix
|
||||
fn rand_suffix() -> String {
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
let count = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
format!("{:04x}", count % 0x10000)
|
||||
}
|
||||
|
||||
// === Tauri Commands ===
|
||||
|
||||
/// Type alias for Tauri state management (shared evolver handle)
|
||||
pub type PersonaEvolverStateHandle = Arc<Mutex<PersonaEvolver>>;
|
||||
|
||||
/// Initialize persona evolver
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolver_init(
|
||||
config: Option<PersonaEvolverConfig>,
|
||||
state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<bool, String> {
|
||||
let mut evolver = state.lock().await;
|
||||
if let Some(cfg) = config {
|
||||
evolver.update_config(cfg);
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Run evolution cycle
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolve(
|
||||
agent_id: String,
|
||||
memories: Vec<MemoryEntryForAnalysis>,
|
||||
reflection_state: tauri::State<'_, super::reflection::ReflectionEngineState>,
|
||||
identity_state: tauri::State<'_, super::identity::IdentityManagerState>,
|
||||
evolver_state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<EvolutionResult, String> {
|
||||
// 1. Run reflection first
|
||||
let mut reflection = reflection_state.lock().await;
|
||||
let reflection_result = reflection.reflect(&agent_id, &memories);
|
||||
drop(reflection);
|
||||
|
||||
// 2. Get current identity
|
||||
let mut identity = identity_state.lock().await;
|
||||
let current_identity = identity.get_identity(&agent_id);
|
||||
drop(identity);
|
||||
|
||||
// 3. Run evolution
|
||||
let mut evolver = evolver_state.lock().await;
|
||||
let result = evolver.evolve(&agent_id, &memories, &reflection_result, ¤t_identity);
|
||||
|
||||
// 4. Apply auto profile updates
|
||||
if !result.profile_updates.is_empty() {
|
||||
let mut identity = identity_state.lock().await;
|
||||
for update in &result.profile_updates {
|
||||
identity.update_user_profile(&agent_id, &update.updated);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get evolution history
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolution_history(
|
||||
limit: Option<usize>,
|
||||
state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<Vec<EvolutionResult>, String> {
|
||||
let evolver = state.lock().await;
|
||||
Ok(evolver.get_history(limit.unwrap_or(10)).into_iter().cloned().collect())
|
||||
}
|
||||
|
||||
/// Get evolver state
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolver_state(
|
||||
state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<PersonaEvolverState, String> {
|
||||
let evolver = state.lock().await;
|
||||
Ok(evolver.get_state().clone())
|
||||
}
|
||||
|
||||
/// Get evolver config
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolver_config(
|
||||
state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<PersonaEvolverConfig, String> {
|
||||
let evolver = state.lock().await;
|
||||
Ok(evolver.get_config().clone())
|
||||
}
|
||||
|
||||
/// Update evolver config
|
||||
#[tauri::command]
|
||||
pub async fn persona_evolver_update_config(
|
||||
config: PersonaEvolverConfig,
|
||||
state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<(), String> {
|
||||
let mut evolver = state.lock().await;
|
||||
evolver.update_config(config);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply evolution proposal (approve)
|
||||
#[tauri::command]
|
||||
pub async fn persona_apply_proposal(
|
||||
proposal: EvolutionProposal,
|
||||
identity_state: tauri::State<'_, super::identity::IdentityManagerState>,
|
||||
evolver_state: tauri::State<'_, PersonaEvolverStateHandle>,
|
||||
) -> Result<IdentityFiles, String> {
|
||||
// Apply the proposal through identity manager
|
||||
let mut identity = identity_state.lock().await;
|
||||
|
||||
let result = match proposal.target_file {
|
||||
IdentityFile::Soul => {
|
||||
identity.update_file(&proposal.agent_id, "soul", &proposal.proposed_content)
|
||||
}
|
||||
IdentityFile::Instructions => {
|
||||
identity.update_file(&proposal.agent_id, "instructions", &proposal.proposed_content)
|
||||
}
|
||||
};
|
||||
|
||||
if result.is_err() {
|
||||
return result.map(|_| IdentityFiles {
|
||||
soul: String::new(),
|
||||
instructions: String::new(),
|
||||
user_profile: String::new(),
|
||||
heartbeat: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Update evolver state
|
||||
let mut evolver = evolver_state.lock().await;
|
||||
evolver.proposal_handled();
|
||||
|
||||
// Return updated identity
|
||||
Ok(identity.get_identity(&proposal.agent_id))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_evolve_empty() {
|
||||
let mut evolver = PersonaEvolver::new(None);
|
||||
let memories = vec![];
|
||||
let reflection = ReflectionResult {
|
||||
patterns: vec![],
|
||||
improvements: vec![],
|
||||
identity_proposals: vec![],
|
||||
new_memories: 0,
|
||||
timestamp: Utc::now().to_rfc3339(),
|
||||
};
|
||||
let identity = IdentityFiles {
|
||||
soul: "Test soul".to_string(),
|
||||
instructions: "Test instructions".to_string(),
|
||||
user_profile: "Test profile".to_string(),
|
||||
heartbeat: None,
|
||||
};
|
||||
|
||||
let result = evolver.evolve("test-agent", &memories, &reflection, &identity);
|
||||
assert!(!result.evolved);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_profile_update() {
|
||||
let mut evolver = PersonaEvolver::new(None);
|
||||
let memories = vec![
|
||||
MemoryEntryForAnalysis {
|
||||
memory_type: "preference".to_string(),
|
||||
content: "喜欢简洁的回复".to_string(),
|
||||
importance: 7,
|
||||
access_count: 3,
|
||||
tags: vec!["沟通".to_string()],
|
||||
},
|
||||
MemoryEntryForAnalysis {
|
||||
memory_type: "preference".to_string(),
|
||||
content: "使用中文".to_string(),
|
||||
importance: 8,
|
||||
access_count: 5,
|
||||
tags: vec!["语言".to_string()],
|
||||
},
|
||||
MemoryEntryForAnalysis {
|
||||
memory_type: "preference".to_string(),
|
||||
content: "代码使用 TypeScript".to_string(),
|
||||
importance: 7,
|
||||
access_count: 2,
|
||||
tags: vec!["技术".to_string()],
|
||||
},
|
||||
];
|
||||
|
||||
let identity = IdentityFiles {
|
||||
soul: "Test".to_string(),
|
||||
instructions: "Test".to_string(),
|
||||
user_profile: "尚未收集到用户偏好信息".to_string(),
|
||||
heartbeat: None,
|
||||
};
|
||||
|
||||
let updates = evolver.extract_profile_updates(&memories, &identity);
|
||||
assert!(!updates.is_empty());
|
||||
assert!(updates[0].updated.contains("用户画像"));
|
||||
}
|
||||
}
|
||||
@@ -1,519 +0,0 @@
|
||||
//! Workflow Recommender - Generates workflow recommendations from detected patterns
|
||||
//!
|
||||
//! This module analyzes behavior patterns and generates actionable workflow recommendations.
|
||||
//! It maps detected patterns to pipelines and provides confidence scoring.
|
||||
//!
|
||||
//! NOTE: Some methods are reserved for future integration with the UI.
|
||||
|
||||
#![allow(dead_code)] // Methods reserved for future UI integration
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::mesh::WorkflowRecommendation;
|
||||
use super::pattern_detector::{BehaviorPattern, PatternType};
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Recommendation rule that maps patterns to pipelines
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RecommendationRule {
|
||||
/// Rule identifier
|
||||
pub id: String,
|
||||
/// Pattern types this rule matches
|
||||
pub pattern_types: Vec<String>,
|
||||
/// Pipeline to recommend
|
||||
pub pipeline_id: String,
|
||||
/// Base confidence for this rule
|
||||
pub base_confidence: f32,
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
/// Input mappings (pattern context field -> pipeline input)
|
||||
pub input_mappings: HashMap<String, String>,
|
||||
/// Priority (higher = more important)
|
||||
pub priority: u8,
|
||||
}
|
||||
|
||||
/// Recommender configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RecommenderConfig {
|
||||
/// Minimum confidence threshold
|
||||
pub min_confidence: f32,
|
||||
/// Maximum recommendations to generate
|
||||
pub max_recommendations: usize,
|
||||
/// Enable rule-based recommendations
|
||||
pub enable_rules: bool,
|
||||
/// Enable pattern-based recommendations
|
||||
pub enable_patterns: bool,
|
||||
}
|
||||
|
||||
impl Default for RecommenderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_confidence: 0.5,
|
||||
max_recommendations: 10,
|
||||
enable_rules: true,
|
||||
enable_patterns: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Workflow Recommender ===
|
||||
|
||||
/// Workflow recommendation engine
|
||||
pub struct WorkflowRecommender {
|
||||
/// Configuration
|
||||
config: RecommenderConfig,
|
||||
/// Recommendation rules
|
||||
rules: Vec<RecommendationRule>,
|
||||
/// Pipeline registry (pipeline_id -> metadata)
|
||||
#[allow(dead_code)] // Reserved for future pipeline-based recommendations
|
||||
pipeline_registry: HashMap<String, PipelineMetadata>,
|
||||
/// Generated recommendations cache
|
||||
recommendations_cache: Vec<WorkflowRecommendation>,
|
||||
}
|
||||
|
||||
/// Metadata about a registered pipeline
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PipelineMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub tags: Vec<String>,
|
||||
pub input_schema: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl WorkflowRecommender {
|
||||
/// Create a new workflow recommender
|
||||
pub fn new(config: Option<RecommenderConfig>) -> Self {
|
||||
let mut recommender = Self {
|
||||
config: config.unwrap_or_default(),
|
||||
rules: Vec::new(),
|
||||
pipeline_registry: HashMap::new(),
|
||||
recommendations_cache: Vec::new(),
|
||||
};
|
||||
|
||||
// Initialize with built-in rules
|
||||
recommender.initialize_default_rules();
|
||||
recommender
|
||||
}
|
||||
|
||||
/// Initialize default recommendation rules
|
||||
fn initialize_default_rules(&mut self) {
|
||||
// Rule: Research + Analysis -> Report Generation
|
||||
self.rules.push(RecommendationRule {
|
||||
id: "rule_research_report".to_string(),
|
||||
pattern_types: vec!["SkillCombination".to_string()],
|
||||
pipeline_id: "research-report-generator".to_string(),
|
||||
base_confidence: 0.7,
|
||||
description: "Generate comprehensive research report".to_string(),
|
||||
input_mappings: HashMap::new(),
|
||||
priority: 8,
|
||||
});
|
||||
|
||||
// Rule: Code + Test -> Quality Check Pipeline
|
||||
self.rules.push(RecommendationRule {
|
||||
id: "rule_code_quality".to_string(),
|
||||
pattern_types: vec!["SkillCombination".to_string()],
|
||||
pipeline_id: "code-quality-check".to_string(),
|
||||
base_confidence: 0.75,
|
||||
description: "Run code quality and test pipeline".to_string(),
|
||||
input_mappings: HashMap::new(),
|
||||
priority: 7,
|
||||
});
|
||||
|
||||
// Rule: Daily morning -> Daily briefing
|
||||
self.rules.push(RecommendationRule {
|
||||
id: "rule_morning_briefing".to_string(),
|
||||
pattern_types: vec!["TemporalTrigger".to_string()],
|
||||
pipeline_id: "daily-briefing".to_string(),
|
||||
base_confidence: 0.6,
|
||||
description: "Generate daily briefing".to_string(),
|
||||
input_mappings: HashMap::new(),
|
||||
priority: 5,
|
||||
});
|
||||
|
||||
// Rule: Task + Deadline -> Priority sort
|
||||
self.rules.push(RecommendationRule {
|
||||
id: "rule_task_priority".to_string(),
|
||||
pattern_types: vec!["InputPattern".to_string()],
|
||||
pipeline_id: "task-priority-sorter".to_string(),
|
||||
base_confidence: 0.65,
|
||||
description: "Sort and prioritize tasks".to_string(),
|
||||
input_mappings: HashMap::new(),
|
||||
priority: 6,
|
||||
});
|
||||
}
|
||||
|
||||
/// Generate recommendations from detected patterns
|
||||
pub fn recommend(&self, patterns: &[&BehaviorPattern]) -> Vec<WorkflowRecommendation> {
|
||||
let mut recommendations = Vec::new();
|
||||
|
||||
if patterns.is_empty() {
|
||||
return recommendations;
|
||||
}
|
||||
|
||||
// Rule-based recommendations
|
||||
if self.config.enable_rules {
|
||||
for rule in &self.rules {
|
||||
if let Some(rec) = self.apply_rule(rule, patterns) {
|
||||
if rec.confidence >= self.config.min_confidence {
|
||||
recommendations.push(rec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern-based recommendations (direct mapping)
|
||||
if self.config.enable_patterns {
|
||||
for pattern in patterns {
|
||||
if let Some(rec) = self.pattern_to_recommendation(pattern) {
|
||||
if rec.confidence >= self.config.min_confidence {
|
||||
recommendations.push(rec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by confidence (descending) and priority
|
||||
recommendations.sort_by(|a, b| {
|
||||
let priority_diff = self.get_priority_for_recommendation(b)
|
||||
.cmp(&self.get_priority_for_recommendation(a));
|
||||
if priority_diff != std::cmp::Ordering::Equal {
|
||||
return priority_diff;
|
||||
}
|
||||
b.confidence.partial_cmp(&a.confidence).unwrap()
|
||||
});
|
||||
|
||||
// Limit recommendations
|
||||
recommendations.truncate(self.config.max_recommendations);
|
||||
|
||||
recommendations
|
||||
}
|
||||
|
||||
/// Apply a recommendation rule to patterns
|
||||
fn apply_rule(
|
||||
&self,
|
||||
rule: &RecommendationRule,
|
||||
patterns: &[&BehaviorPattern],
|
||||
) -> Option<WorkflowRecommendation> {
|
||||
let mut matched_patterns: Vec<String> = Vec::new();
|
||||
let mut total_confidence = 0.0;
|
||||
let mut match_count = 0;
|
||||
|
||||
for pattern in patterns {
|
||||
let pattern_type_name = self.get_pattern_type_name(&pattern.pattern_type);
|
||||
|
||||
if rule.pattern_types.contains(&pattern_type_name) {
|
||||
matched_patterns.push(pattern.id.clone());
|
||||
total_confidence += pattern.confidence;
|
||||
match_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if matched_patterns.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Calculate combined confidence
|
||||
let avg_pattern_confidence = total_confidence / match_count as f32;
|
||||
let final_confidence = (rule.base_confidence * 0.6 + avg_pattern_confidence * 0.4).min(1.0);
|
||||
|
||||
// Build suggested inputs from pattern context
|
||||
let suggested_inputs = self.build_suggested_inputs(&matched_patterns, patterns, rule);
|
||||
|
||||
Some(WorkflowRecommendation {
|
||||
id: format!("rec_{}", Uuid::new_v4()),
|
||||
pipeline_id: rule.pipeline_id.clone(),
|
||||
confidence: final_confidence,
|
||||
reason: rule.description.clone(),
|
||||
suggested_inputs,
|
||||
patterns_matched: matched_patterns,
|
||||
timestamp: Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert a single pattern to a recommendation
|
||||
fn pattern_to_recommendation(&self, pattern: &BehaviorPattern) -> Option<WorkflowRecommendation> {
|
||||
let (pipeline_id, reason) = match &pattern.pattern_type {
|
||||
PatternType::TaskPipelineMapping { task_type, pipeline_id } => {
|
||||
(pipeline_id.clone(), format!("Detected task type: {}", task_type))
|
||||
}
|
||||
PatternType::SkillCombination { skill_ids } => {
|
||||
// Find a pipeline that uses these skills
|
||||
let pipeline_id = self.find_pipeline_for_skills(skill_ids)?;
|
||||
(pipeline_id, format!("Skills often used together: {}", skill_ids.join(", ")))
|
||||
}
|
||||
PatternType::InputPattern { keywords, intent } => {
|
||||
// Find a pipeline for this intent
|
||||
let pipeline_id = self.find_pipeline_for_intent(intent)?;
|
||||
(pipeline_id, format!("Intent detected: {} ({})", intent, keywords.join(", ")))
|
||||
}
|
||||
PatternType::TemporalTrigger { hand_id, time_pattern } => {
|
||||
(format!("scheduled_{}", hand_id), format!("Scheduled at: {}", time_pattern))
|
||||
}
|
||||
};
|
||||
|
||||
Some(WorkflowRecommendation {
|
||||
id: format!("rec_{}", Uuid::new_v4()),
|
||||
pipeline_id,
|
||||
confidence: pattern.confidence,
|
||||
reason,
|
||||
suggested_inputs: HashMap::new(),
|
||||
patterns_matched: vec![pattern.id.clone()],
|
||||
timestamp: Utc::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get string name for pattern type
|
||||
fn get_pattern_type_name(&self, pattern_type: &PatternType) -> String {
|
||||
match pattern_type {
|
||||
PatternType::SkillCombination { .. } => "SkillCombination".to_string(),
|
||||
PatternType::TemporalTrigger { .. } => "TemporalTrigger".to_string(),
|
||||
PatternType::TaskPipelineMapping { .. } => "TaskPipelineMapping".to_string(),
|
||||
PatternType::InputPattern { .. } => "InputPattern".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get priority for a recommendation
|
||||
fn get_priority_for_recommendation(&self, rec: &WorkflowRecommendation) -> u8 {
|
||||
self.rules
|
||||
.iter()
|
||||
.find(|r| r.pipeline_id == rec.pipeline_id)
|
||||
.map(|r| r.priority)
|
||||
.unwrap_or(5)
|
||||
}
|
||||
|
||||
/// Build suggested inputs from patterns and rule
|
||||
fn build_suggested_inputs(
|
||||
&self,
|
||||
matched_pattern_ids: &[String],
|
||||
patterns: &[&BehaviorPattern],
|
||||
rule: &RecommendationRule,
|
||||
) -> HashMap<String, serde_json::Value> {
|
||||
let mut inputs = HashMap::new();
|
||||
|
||||
for pattern_id in matched_pattern_ids {
|
||||
if let Some(pattern) = patterns.iter().find(|p| p.id == *pattern_id) {
|
||||
// Add context-based inputs
|
||||
if let Some(ref topics) = pattern.context.recent_topics {
|
||||
if !topics.is_empty() {
|
||||
inputs.insert(
|
||||
"topics".to_string(),
|
||||
serde_json::Value::Array(
|
||||
topics.iter().map(|t| serde_json::Value::String(t.clone())).collect()
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref intent) = pattern.context.intent {
|
||||
inputs.insert("intent".to_string(), serde_json::Value::String(intent.clone()));
|
||||
}
|
||||
|
||||
// Add pattern-specific inputs
|
||||
match &pattern.pattern_type {
|
||||
PatternType::InputPattern { keywords, .. } => {
|
||||
inputs.insert(
|
||||
"keywords".to_string(),
|
||||
serde_json::Value::Array(
|
||||
keywords.iter().map(|k| serde_json::Value::String(k.clone())).collect()
|
||||
),
|
||||
);
|
||||
}
|
||||
PatternType::SkillCombination { skill_ids } => {
|
||||
inputs.insert(
|
||||
"skills".to_string(),
|
||||
serde_json::Value::Array(
|
||||
skill_ids.iter().map(|s| serde_json::Value::String(s.clone())).collect()
|
||||
),
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply rule mappings
|
||||
for (source, target) in &rule.input_mappings {
|
||||
if let Some(value) = inputs.get(source) {
|
||||
inputs.insert(target.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
inputs
|
||||
}
|
||||
|
||||
/// Find a pipeline that uses the given skills
|
||||
fn find_pipeline_for_skills(&self, skill_ids: &[String]) -> Option<String> {
|
||||
// In production, this would query the pipeline registry
|
||||
// For now, return a default
|
||||
if skill_ids.len() >= 2 {
|
||||
Some("skill-orchestration-pipeline".to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Find a pipeline for an intent
|
||||
fn find_pipeline_for_intent(&self, intent: &str) -> Option<String> {
|
||||
// Map common intents to pipelines
|
||||
match intent {
|
||||
"research" => Some("research-pipeline".to_string()),
|
||||
"analysis" => Some("analysis-pipeline".to_string()),
|
||||
"report" => Some("report-generation".to_string()),
|
||||
"code" => Some("code-generation".to_string()),
|
||||
"task" | "tasks" => Some("task-management".to_string()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a pipeline
|
||||
pub fn register_pipeline(&mut self, metadata: PipelineMetadata) {
|
||||
self.pipeline_registry.insert(metadata.id.clone(), metadata);
|
||||
}
|
||||
|
||||
/// Unregister a pipeline
|
||||
pub fn unregister_pipeline(&mut self, pipeline_id: &str) {
|
||||
self.pipeline_registry.remove(pipeline_id);
|
||||
}
|
||||
|
||||
/// Add a custom recommendation rule
|
||||
pub fn add_rule(&mut self, rule: RecommendationRule) {
|
||||
self.rules.push(rule);
|
||||
// Sort by priority
|
||||
self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
}
|
||||
|
||||
/// Remove a rule
|
||||
pub fn remove_rule(&mut self, rule_id: &str) {
|
||||
self.rules.retain(|r| r.id != rule_id);
|
||||
}
|
||||
|
||||
/// Get all rules
|
||||
pub fn get_rules(&self) -> &[RecommendationRule] {
|
||||
&self.rules
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub fn update_config(&mut self, config: RecommenderConfig) {
|
||||
self.config = config;
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn get_config(&self) -> &RecommenderConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get recommendation count
|
||||
pub fn recommendation_count(&self) -> usize {
|
||||
self.recommendations_cache.len()
|
||||
}
|
||||
|
||||
/// Clear recommendation cache
|
||||
pub fn clear_cache(&mut self) {
|
||||
self.recommendations_cache.clear();
|
||||
}
|
||||
|
||||
/// Accept a recommendation (remove from cache and return it)
|
||||
/// Returns the accepted recommendation if found
|
||||
pub fn accept_recommendation(&mut self, recommendation_id: &str) -> Option<WorkflowRecommendation> {
|
||||
if let Some(pos) = self.recommendations_cache.iter().position(|r| r.id == recommendation_id) {
|
||||
Some(self.recommendations_cache.remove(pos))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Dismiss a recommendation (remove from cache without acting on it)
|
||||
/// Returns true if the recommendation was found and dismissed
|
||||
pub fn dismiss_recommendation(&mut self, recommendation_id: &str) -> bool {
|
||||
if let Some(pos) = self.recommendations_cache.iter().position(|r| r.id == recommendation_id) {
|
||||
self.recommendations_cache.remove(pos);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a recommendation by ID
|
||||
pub fn get_recommendation(&self, recommendation_id: &str) -> Option<&WorkflowRecommendation> {
|
||||
self.recommendations_cache.iter().find(|r| r.id == recommendation_id)
|
||||
}
|
||||
|
||||
/// Load recommendations from file
|
||||
pub fn load_from_file(&mut self, path: &str) -> Result<(), String> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
|
||||
let recommendations: Vec<WorkflowRecommendation> = serde_json::from_str(&content)
|
||||
.map_err(|e| format!("Failed to parse recommendations: {}", e))?;
|
||||
|
||||
self.recommendations_cache = recommendations;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Save recommendations to file
|
||||
pub fn save_to_file(&self, path: &str) -> Result<(), String> {
|
||||
let content = serde_json::to_string_pretty(&self.recommendations_cache)
|
||||
.map_err(|e| format!("Failed to serialize recommendations: {}", e))?;
|
||||
|
||||
std::fs::write(path, content)
|
||||
.map_err(|e| format!("Failed to write file: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_recommender_creation() {
|
||||
let recommender = WorkflowRecommender::new(None);
|
||||
assert!(!recommender.get_rules().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recommend_from_empty_patterns() {
|
||||
let recommender = WorkflowRecommender::new(None);
|
||||
let recommendations = recommender.recommend(&[]);
|
||||
assert!(recommendations.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_priority() {
|
||||
let mut recommender = WorkflowRecommender::new(None);
|
||||
|
||||
recommender.add_rule(RecommendationRule {
|
||||
id: "high_priority".to_string(),
|
||||
pattern_types: vec!["SkillCombination".to_string()],
|
||||
pipeline_id: "important-pipeline".to_string(),
|
||||
base_confidence: 0.9,
|
||||
description: "High priority rule".to_string(),
|
||||
input_mappings: HashMap::new(),
|
||||
priority: 10,
|
||||
});
|
||||
|
||||
let rules = recommender.get_rules();
|
||||
assert!(rules.iter().any(|r| r.priority == 10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_pipeline() {
|
||||
let mut recommender = WorkflowRecommender::new(None);
|
||||
|
||||
recommender.register_pipeline(PipelineMetadata {
|
||||
id: "test-pipeline".to_string(),
|
||||
name: "Test Pipeline".to_string(),
|
||||
description: Some("A test pipeline".to_string()),
|
||||
tags: vec!["test".to_string()],
|
||||
input_schema: None,
|
||||
});
|
||||
|
||||
assert!(recommender.pipeline_registry.contains_key("test-pipeline"));
|
||||
}
|
||||
}
|
||||
@@ -1,845 +0,0 @@
|
||||
//! Trigger Evaluator - Evaluates context-aware triggers for Hands
|
||||
//!
|
||||
//! This module extends the basic trigger system with semantic matching:
|
||||
//! Supports MemoryQuery, ContextCondition, and IdentityState triggers.
|
||||
//!
|
||||
//! NOTE: This module is not yet integrated into the main application.
|
||||
//! Components are still being developed and will be connected in a future release.
|
||||
|
||||
#![allow(dead_code)] // Module not yet integrated - components under development
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::Mutex;
|
||||
use chrono::{DateTime, Utc, Timelike, Datelike};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
use zclaw_memory::MemoryStore;
|
||||
|
||||
// === ReDoS Protection Constants ===
|
||||
|
||||
/// Maximum allowed length for regex patterns (prevents memory exhaustion)
|
||||
const MAX_REGEX_PATTERN_LENGTH: usize = 500;
|
||||
|
||||
/// Maximum allowed nesting depth for regex quantifiers/groups
|
||||
const MAX_REGEX_NESTING_DEPTH: usize = 10;
|
||||
|
||||
/// Error type for regex validation failures
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum RegexValidationError {
|
||||
/// Pattern exceeds maximum length
|
||||
TooLong { length: usize, max: usize },
|
||||
/// Pattern has excessive nesting depth
|
||||
TooDeeplyNested { depth: usize, max: usize },
|
||||
/// Pattern contains dangerous ReDoS-prone constructs
|
||||
DangerousPattern(String),
|
||||
/// Invalid regex syntax
|
||||
InvalidSyntax(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RegexValidationError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RegexValidationError::TooLong { length, max } => {
|
||||
write!(f, "Regex pattern too long: {} bytes (max: {})", length, max)
|
||||
}
|
||||
RegexValidationError::TooDeeplyNested { depth, max } => {
|
||||
write!(f, "Regex pattern too deeply nested: {} levels (max: {})", depth, max)
|
||||
}
|
||||
RegexValidationError::DangerousPattern(reason) => {
|
||||
write!(f, "Dangerous regex pattern detected: {}", reason)
|
||||
}
|
||||
RegexValidationError::InvalidSyntax(err) => {
|
||||
write!(f, "Invalid regex syntax: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RegexValidationError {}
|
||||
|
||||
/// Validate a regex pattern for ReDoS safety
|
||||
///
|
||||
/// This function checks for:
|
||||
/// 1. Pattern length (prevents memory exhaustion)
|
||||
/// 2. Nesting depth (prevents exponential backtracking)
|
||||
/// 3. Dangerous patterns (nested quantifiers on overlapping character classes)
|
||||
fn validate_regex_pattern(pattern: &str) -> Result<(), RegexValidationError> {
|
||||
// Check length
|
||||
if pattern.len() > MAX_REGEX_PATTERN_LENGTH {
|
||||
return Err(RegexValidationError::TooLong {
|
||||
length: pattern.len(),
|
||||
max: MAX_REGEX_PATTERN_LENGTH,
|
||||
});
|
||||
}
|
||||
|
||||
// Check nesting depth by counting unescaped parentheses and brackets
|
||||
let nesting_depth = calculate_nesting_depth(pattern);
|
||||
if nesting_depth > MAX_REGEX_NESTING_DEPTH {
|
||||
return Err(RegexValidationError::TooDeeplyNested {
|
||||
depth: nesting_depth,
|
||||
max: MAX_REGEX_NESTING_DEPTH,
|
||||
});
|
||||
}
|
||||
|
||||
// Check for dangerous ReDoS patterns:
|
||||
// - Nested quantifiers on overlapping patterns like (a+)+
|
||||
// - Alternation with overlapping patterns like (a|a)+
|
||||
if contains_dangerous_redos_pattern(pattern) {
|
||||
return Err(RegexValidationError::DangerousPattern(
|
||||
"Pattern contains nested quantifiers on overlapping character classes".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calculate the maximum nesting depth of groups in a regex pattern
|
||||
fn calculate_nesting_depth(pattern: &str) -> usize {
|
||||
let chars: Vec<char> = pattern.chars().collect();
|
||||
let mut max_depth = 0;
|
||||
let mut current_depth = 0;
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
let c = chars[i];
|
||||
|
||||
// Check for escape sequence
|
||||
if c == '\\' && i + 1 < chars.len() {
|
||||
// Skip the escaped character
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle character classes [...]
|
||||
if c == '[' {
|
||||
current_depth += 1;
|
||||
max_depth = max_depth.max(current_depth);
|
||||
// Find matching ]
|
||||
i += 1;
|
||||
while i < chars.len() {
|
||||
if chars[i] == '\\' && i + 1 < chars.len() {
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
if chars[i] == ']' {
|
||||
current_depth -= 1;
|
||||
break;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
// Handle groups (...)
|
||||
else if c == '(' {
|
||||
// Skip non-capturing groups and lookaheads for simplicity
|
||||
// (?:...), (?=...), (?!...), (?<=...), (?<!...), (?P<name>...)
|
||||
current_depth += 1;
|
||||
max_depth = max_depth.max(current_depth);
|
||||
} else if c == ')' {
|
||||
if current_depth > 0 {
|
||||
current_depth -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
max_depth
|
||||
}
|
||||
|
||||
/// Check for dangerous ReDoS patterns
|
||||
///
|
||||
/// Detects patterns like:
|
||||
/// - (a+)+ - nested quantifiers
|
||||
/// - (a*)+ - nested quantifiers
|
||||
/// - (a+)* - nested quantifiers
|
||||
/// - (.*)* - nested quantifiers on wildcard
|
||||
fn contains_dangerous_redos_pattern(pattern: &str) -> bool {
|
||||
let chars: Vec<char> = pattern.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
// Look for quantified patterns followed by another quantifier
|
||||
if i > 0 {
|
||||
let prev = chars[i - 1];
|
||||
|
||||
// Check if current char is a quantifier
|
||||
if matches!(chars[i], '+' | '*' | '?') {
|
||||
// Look back to see what's being quantified
|
||||
if prev == ')' {
|
||||
// Find the matching opening paren
|
||||
let mut depth = 1;
|
||||
let mut j = i - 2;
|
||||
while j > 0 && depth > 0 {
|
||||
if chars[j] == ')' {
|
||||
depth += 1;
|
||||
} else if chars[j] == '(' {
|
||||
depth -= 1;
|
||||
} else if chars[j] == '\\' && j > 0 {
|
||||
j -= 1; // Skip escaped char
|
||||
}
|
||||
j -= 1;
|
||||
}
|
||||
|
||||
// Check if the group content ends with a quantifier
|
||||
// This would indicate nested quantification
|
||||
// Note: j is usize, so we don't check >= 0 (always true)
|
||||
// The loop above ensures j is valid if depth reached 0
|
||||
let mut k = i - 2;
|
||||
while k > j + 1 {
|
||||
if chars[k] == '\\' && k > 0 {
|
||||
k -= 1;
|
||||
} else if matches!(chars[k], '+' | '*' | '?') {
|
||||
// Found nested quantifier
|
||||
return true;
|
||||
} else if chars[k] == ')' {
|
||||
// Skip nested groups
|
||||
let mut nested_depth = 1;
|
||||
k -= 1;
|
||||
while k > j + 1 && nested_depth > 0 {
|
||||
if chars[k] == ')' {
|
||||
nested_depth += 1;
|
||||
} else if chars[k] == '(' {
|
||||
nested_depth -= 1;
|
||||
} else if chars[k] == '\\' && k > 0 {
|
||||
k -= 1;
|
||||
}
|
||||
k -= 1;
|
||||
}
|
||||
}
|
||||
k -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Safely compile a regex pattern with ReDoS protection
|
||||
///
|
||||
/// This function validates the pattern for safety before compilation.
|
||||
/// Returns a compiled regex or an error describing why validation failed.
|
||||
pub fn compile_safe_regex(pattern: &str) -> Result<regex::Regex, RegexValidationError> {
|
||||
validate_regex_pattern(pattern)?;
|
||||
|
||||
regex::Regex::new(pattern).map_err(|e| RegexValidationError::InvalidSyntax(e.to_string()))
|
||||
}
|
||||
|
||||
// === Extended Trigger Types ===
|
||||
|
||||
/// Memory query trigger configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryQueryConfig {
|
||||
/// Memory type to filter (e.g., "task", "preference")
|
||||
pub memory_type: Option<String>,
|
||||
/// Content pattern to match (regex or substring)
|
||||
pub content_pattern: String,
|
||||
/// Minimum count of matching memories
|
||||
pub min_count: usize,
|
||||
/// Minimum importance threshold
|
||||
pub min_importance: Option<i32>,
|
||||
/// Time window for memories (hours)
|
||||
pub time_window_hours: Option<u64>,
|
||||
}
|
||||
|
||||
/// Context condition configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ContextConditionConfig {
|
||||
/// Conditions to check
|
||||
pub conditions: Vec<ContextConditionClause>,
|
||||
/// How to combine conditions (All, Any, None)
|
||||
pub combination: ConditionCombination,
|
||||
}
|
||||
|
||||
/// Single context condition clause
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ContextConditionClause {
|
||||
/// Field to check
|
||||
pub field: ContextField,
|
||||
/// Comparison operator
|
||||
pub operator: ComparisonOperator,
|
||||
/// Value to compare against
|
||||
pub value: JsonValue,
|
||||
}
|
||||
|
||||
/// Context fields that can be checked
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum ContextField {
|
||||
/// Current hour of day (0-23)
|
||||
TimeOfDay,
|
||||
/// Day of week (0=Monday, 6=Sunday)
|
||||
DayOfWeek,
|
||||
/// Currently active project (if any)
|
||||
ActiveProject,
|
||||
/// Topics discussed recently
|
||||
RecentTopic,
|
||||
/// Number of pending tasks
|
||||
PendingTasks,
|
||||
/// Count of memories in storage
|
||||
MemoryCount,
|
||||
/// Hours since last interaction
|
||||
LastInteractionHours,
|
||||
/// Current conversation intent
|
||||
ConversationIntent,
|
||||
}
|
||||
|
||||
/// Comparison operators for context conditions
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum ComparisonOperator {
|
||||
Equals,
|
||||
NotEquals,
|
||||
Contains,
|
||||
GreaterThan,
|
||||
LessThan,
|
||||
Exists,
|
||||
NotExists,
|
||||
Matches, // regex match
|
||||
}
|
||||
|
||||
/// How to combine multiple conditions
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum ConditionCombination {
|
||||
/// All conditions must true
|
||||
All,
|
||||
/// Any one condition being true is enough
|
||||
Any,
|
||||
/// None of the conditions should be true
|
||||
None,
|
||||
}
|
||||
|
||||
/// Identity state trigger configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IdentityStateConfig {
|
||||
/// Identity file to check
|
||||
pub file: IdentityFile,
|
||||
/// Content pattern to match (regex)
|
||||
pub content_pattern: Option<String>,
|
||||
/// Trigger on any change to the file
|
||||
pub any_change: bool,
|
||||
}
|
||||
|
||||
/// Identity files that can be monitored
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum IdentityFile {
|
||||
Soul,
|
||||
Instructions,
|
||||
User,
|
||||
}
|
||||
|
||||
/// Composite trigger configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompositeTriggerConfig {
|
||||
/// Sub-triggers to combine
|
||||
pub triggers: Vec<ExtendedTriggerType>,
|
||||
/// How to combine results
|
||||
pub combination: ConditionCombination,
|
||||
}
|
||||
|
||||
/// Extended trigger type that includes semantic triggers
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ExtendedTriggerType {
|
||||
/// Standard interval trigger
|
||||
Interval {
|
||||
/// Interval in seconds
|
||||
seconds: u64,
|
||||
},
|
||||
/// Time-of-day trigger
|
||||
TimeOfDay {
|
||||
/// Hour (0-23)
|
||||
hour: u8,
|
||||
/// Optional minute (0-59)
|
||||
minute: Option<u8>,
|
||||
},
|
||||
/// Memory query trigger
|
||||
MemoryQuery(MemoryQueryConfig),
|
||||
/// Context condition trigger
|
||||
ContextCondition(ContextConditionConfig),
|
||||
/// Identity state trigger
|
||||
IdentityState(IdentityStateConfig),
|
||||
/// Composite trigger
|
||||
Composite(CompositeTriggerConfig),
|
||||
}
|
||||
|
||||
// === Trigger Evaluator ===
|
||||
|
||||
/// Evaluator for context-aware triggers
|
||||
pub struct TriggerEvaluator {
|
||||
/// Memory store for memory queries
|
||||
memory_store: Arc<MemoryStore>,
|
||||
/// Identity manager for identity triggers
|
||||
identity_manager: Arc<Mutex<super::identity::AgentIdentityManager>>,
|
||||
/// Heartbeat engine for context
|
||||
heartbeat_engine: Arc<Mutex<super::heartbeat::HeartbeatEngine>>,
|
||||
/// Cached context data
|
||||
context_cache: Arc<Mutex<TriggerContextCache>>,
|
||||
}
|
||||
|
||||
/// Cached context for trigger evaluation
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TriggerContextCache {
|
||||
/// Last known active project
|
||||
pub active_project: Option<String>,
|
||||
/// Recent topics discussed
|
||||
pub recent_topics: Vec<String>,
|
||||
/// Last conversation intent
|
||||
pub conversation_intent: Option<String>,
|
||||
/// Last update time
|
||||
pub last_updated: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TriggerEvaluator {
|
||||
/// Create a new trigger evaluator
|
||||
pub fn new(
|
||||
memory_store: Arc<MemoryStore>,
|
||||
identity_manager: Arc<Mutex<super::identity::AgentIdentityManager>>,
|
||||
heartbeat_engine: Arc<Mutex<super::heartbeat::HeartbeatEngine>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
memory_store,
|
||||
identity_manager,
|
||||
heartbeat_engine,
|
||||
context_cache: Arc::new(Mutex::new(TriggerContextCache::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a trigger
|
||||
pub async fn evaluate(
|
||||
&self,
|
||||
trigger: &ExtendedTriggerType,
|
||||
agent_id: &str,
|
||||
) -> Result<bool, String> {
|
||||
match trigger {
|
||||
ExtendedTriggerType::Interval { .. } => Ok(true),
|
||||
ExtendedTriggerType::TimeOfDay { hour, minute } => {
|
||||
let now = Utc::now();
|
||||
let current_hour = now.hour() as u8;
|
||||
let current_minute = now.minute() as u8;
|
||||
|
||||
if current_hour != *hour {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if let Some(min) = minute {
|
||||
if current_minute != *min {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
ExtendedTriggerType::MemoryQuery(config) => {
|
||||
self.evaluate_memory_query(config, agent_id).await
|
||||
}
|
||||
ExtendedTriggerType::ContextCondition(config) => {
|
||||
self.evaluate_context_condition(config, agent_id).await
|
||||
}
|
||||
ExtendedTriggerType::IdentityState(config) => {
|
||||
self.evaluate_identity_state(config, agent_id).await
|
||||
}
|
||||
ExtendedTriggerType::Composite(config) => {
|
||||
self.evaluate_composite(config, agent_id, None).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate memory query trigger
|
||||
async fn evaluate_memory_query(
|
||||
&self,
|
||||
config: &MemoryQueryConfig,
|
||||
_agent_id: &str,
|
||||
) -> Result<bool, String> {
|
||||
// TODO: Implement proper memory search when MemoryStore supports it
|
||||
// For now, use KV store to check if we have enough keys matching pattern
|
||||
// This is a simplified implementation
|
||||
|
||||
// Memory search is not fully implemented in current MemoryStore
|
||||
// Return false to indicate no matches until proper search is available
|
||||
tracing::warn!(
|
||||
pattern = %config.content_pattern,
|
||||
min_count = config.min_count,
|
||||
"Memory query trigger evaluation not fully implemented"
|
||||
);
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Evaluate context condition trigger
|
||||
async fn evaluate_context_condition(
|
||||
&self,
|
||||
config: &ContextConditionConfig,
|
||||
agent_id: &str,
|
||||
) -> Result<bool, String> {
|
||||
let context = self.get_cached_context(agent_id).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for condition in &config.conditions {
|
||||
let result = self.evaluate_condition_clause(condition, &context);
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
// Combine results based on combination mode
|
||||
let final_result = match config.combination {
|
||||
ConditionCombination::All => results.iter().all(|r| *r),
|
||||
ConditionCombination::Any => results.iter().any(|r| *r),
|
||||
ConditionCombination::None => results.iter().all(|r| !*r),
|
||||
};
|
||||
|
||||
Ok(final_result)
|
||||
}
|
||||
|
||||
/// Evaluate a single condition clause
|
||||
fn evaluate_condition_clause(
|
||||
&self,
|
||||
clause: &ContextConditionClause,
|
||||
context: &TriggerContextCache,
|
||||
) -> bool {
|
||||
match clause.field {
|
||||
ContextField::TimeOfDay => {
|
||||
let now = Utc::now();
|
||||
let current_hour = now.hour() as i32;
|
||||
self.compare_values(current_hour, &clause.operator, &clause.value)
|
||||
}
|
||||
ContextField::DayOfWeek => {
|
||||
let now = Utc::now();
|
||||
let current_day = now.weekday().num_days_from_monday() as i32;
|
||||
self.compare_values(current_day, &clause.operator, &clause.value)
|
||||
}
|
||||
ContextField::ActiveProject => {
|
||||
if let Some(project) = &context.active_project {
|
||||
self.compare_values(project.clone(), &clause.operator, &clause.value)
|
||||
} else {
|
||||
matches!(clause.operator, ComparisonOperator::NotExists)
|
||||
}
|
||||
}
|
||||
ContextField::RecentTopic => {
|
||||
if let Some(topic) = context.recent_topics.first() {
|
||||
self.compare_values(topic.clone(), &clause.operator, &clause.value)
|
||||
} else {
|
||||
matches!(clause.operator, ComparisonOperator::NotExists)
|
||||
}
|
||||
}
|
||||
ContextField::PendingTasks => {
|
||||
// Would need to query memory store
|
||||
false // Not implemented yet
|
||||
}
|
||||
ContextField::MemoryCount => {
|
||||
// Would need to query memory store
|
||||
false // Not implemented yet
|
||||
}
|
||||
ContextField::LastInteractionHours => {
|
||||
if let Some(last_updated) = context.last_updated {
|
||||
let hours = (Utc::now() - last_updated).num_hours();
|
||||
self.compare_values(hours as i32, &clause.operator, &clause.value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ContextField::ConversationIntent => {
|
||||
if let Some(intent) = &context.conversation_intent {
|
||||
self.compare_values(intent.clone(), &clause.operator, &clause.value)
|
||||
} else {
|
||||
matches!(clause.operator, ComparisonOperator::NotExists)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values using operator
|
||||
fn compare_values<T>(&self, actual: T, operator: &ComparisonOperator, expected: &JsonValue) -> bool
|
||||
where
|
||||
T: Into<JsonValue>,
|
||||
{
|
||||
let actual_value = actual.into();
|
||||
|
||||
match operator {
|
||||
ComparisonOperator::Equals => &actual_value == expected,
|
||||
ComparisonOperator::NotEquals => &actual_value != expected,
|
||||
ComparisonOperator::Contains => {
|
||||
if let (Some(actual_str), Some(expected_str)) =
|
||||
(actual_value.as_str(), expected.as_str())
|
||||
{
|
||||
actual_str.contains(expected_str)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ComparisonOperator::GreaterThan => {
|
||||
if let (Some(actual_num), Some(expected_num)) =
|
||||
(actual_value.as_i64(), expected.as_i64())
|
||||
{
|
||||
actual_num > expected_num
|
||||
} else if let (Some(actual_num), Some(expected_num)) =
|
||||
(actual_value.as_f64(), expected.as_f64())
|
||||
{
|
||||
actual_num > expected_num
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ComparisonOperator::LessThan => {
|
||||
if let (Some(actual_num), Some(expected_num)) =
|
||||
(actual_value.as_i64(), expected.as_i64())
|
||||
{
|
||||
actual_num < expected_num
|
||||
} else if let (Some(actual_num), Some(expected_num)) =
|
||||
(actual_value.as_f64(), expected.as_f64())
|
||||
{
|
||||
actual_num < expected_num
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ComparisonOperator::Exists => !actual_value.is_null(),
|
||||
ComparisonOperator::NotExists => actual_value.is_null(),
|
||||
ComparisonOperator::Matches => {
|
||||
if let (Some(actual_str), Some(expected_str)) =
|
||||
(actual_value.as_str(), expected.as_str())
|
||||
{
|
||||
compile_safe_regex(expected_str)
|
||||
.map(|re| re.is_match(actual_str))
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(
|
||||
pattern = %expected_str,
|
||||
error = %e,
|
||||
"Regex pattern validation failed, treating as no match"
|
||||
);
|
||||
false
|
||||
})
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate identity state trigger
|
||||
async fn evaluate_identity_state(
|
||||
&self,
|
||||
config: &IdentityStateConfig,
|
||||
agent_id: &str,
|
||||
) -> Result<bool, String> {
|
||||
let mut manager = self.identity_manager.lock().await;
|
||||
let identity = manager.get_identity(agent_id);
|
||||
|
||||
// Get the target file content
|
||||
let content = match config.file {
|
||||
IdentityFile::Soul => identity.soul,
|
||||
IdentityFile::Instructions => identity.instructions,
|
||||
IdentityFile::User => identity.user_profile,
|
||||
};
|
||||
|
||||
// Check content pattern if specified
|
||||
if let Some(pattern) = &config.content_pattern {
|
||||
let re = compile_safe_regex(pattern)
|
||||
.map_err(|e| format!("Invalid regex pattern: {}", e))?;
|
||||
if !re.is_match(&content) {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// If any_change is true, we would need to track changes
|
||||
// For now, just return true
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Get cached context for an agent
|
||||
async fn get_cached_context(&self, _agent_id: &str) -> TriggerContextCache {
|
||||
self.context_cache.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Evaluate composite trigger
|
||||
fn evaluate_composite<'a>(
|
||||
&'a self,
|
||||
config: &'a CompositeTriggerConfig,
|
||||
agent_id: &'a str,
|
||||
_depth: Option<usize>,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = Result<bool, String>> + 'a>> {
|
||||
Box::pin(async move {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for trigger in &config.triggers {
|
||||
let result = self.evaluate(trigger, agent_id).await?;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
// Combine results based on combination mode
|
||||
let final_result = match config.combination {
|
||||
ConditionCombination::All => results.iter().all(|r| *r),
|
||||
ConditionCombination::Any => results.iter().any(|r| *r),
|
||||
ConditionCombination::None => results.iter().all(|r| !*r),
|
||||
};
|
||||
|
||||
Ok(final_result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
mod regex_validation {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_simple_pattern() {
|
||||
let pattern = r"hello";
|
||||
assert!(compile_safe_regex(pattern).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_pattern_with_quantifiers() {
|
||||
let pattern = r"\d+";
|
||||
assert!(compile_safe_regex(pattern).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_pattern_with_groups() {
|
||||
let pattern = r"(foo|bar)\d{2,4}";
|
||||
assert!(compile_safe_regex(pattern).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_character_class() {
|
||||
let pattern = r"[a-zA-Z0-9_]+";
|
||||
assert!(compile_safe_regex(pattern).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pattern_too_long() {
|
||||
let pattern = "a".repeat(501);
|
||||
let result = compile_safe_regex(&pattern);
|
||||
assert!(matches!(result, Err(RegexValidationError::TooLong { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pattern_at_max_length() {
|
||||
let pattern = "a".repeat(500);
|
||||
let result = compile_safe_regex(&pattern);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_quantifier_detection_simple() {
|
||||
// Classic ReDoS pattern: (a+)+
|
||||
// Our implementation detects this as dangerous
|
||||
let pattern = r"(a+)+";
|
||||
let result = validate_regex_pattern(pattern);
|
||||
assert!(
|
||||
matches!(result, Err(RegexValidationError::DangerousPattern(_))),
|
||||
"Expected nested quantifier pattern to be detected as dangerous"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deeply_nested_groups() {
|
||||
// Create a pattern with too many nested groups
|
||||
let pattern = "(".repeat(15) + &"a".repeat(10) + &")".repeat(15);
|
||||
let result = compile_safe_regex(&pattern);
|
||||
assert!(matches!(result, Err(RegexValidationError::TooDeeplyNested { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reasonably_nested_groups() {
|
||||
// Pattern with acceptable nesting
|
||||
let pattern = "(((foo|bar)))";
|
||||
let result = compile_safe_regex(pattern);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_regex_syntax() {
|
||||
let pattern = r"[unclosed";
|
||||
let result = compile_safe_regex(pattern);
|
||||
assert!(matches!(result, Err(RegexValidationError::InvalidSyntax(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escaped_characters_in_pattern() {
|
||||
let pattern = r"\[hello\]";
|
||||
let result = compile_safe_regex(pattern);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complex_valid_pattern() {
|
||||
// Email-like pattern (simplified)
|
||||
let pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}";
|
||||
let result = compile_safe_regex(pattern);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
mod nesting_depth_calculation {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_no_nesting() {
|
||||
assert_eq!(calculate_nesting_depth("abc"), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_group() {
|
||||
assert_eq!(calculate_nesting_depth("(abc)"), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_groups() {
|
||||
assert_eq!(calculate_nesting_depth("((abc))"), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_character_class() {
|
||||
assert_eq!(calculate_nesting_depth("[abc]"), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_nesting() {
|
||||
assert_eq!(calculate_nesting_depth("([a-z]+)"), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escaped_parens() {
|
||||
// Escaped parens shouldn't count toward nesting
|
||||
assert_eq!(calculate_nesting_depth(r"\(abc\)"), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_groups_same_level() {
|
||||
assert_eq!(calculate_nesting_depth("(abc)(def)"), 1);
|
||||
}
|
||||
}
|
||||
|
||||
mod dangerous_pattern_detection {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simple_quantifier_not_dangerous() {
|
||||
assert!(!contains_dangerous_redos_pattern(r"a+"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_group_not_dangerous() {
|
||||
assert!(!contains_dangerous_redos_pattern(r"(abc)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantified_group_not_dangerous() {
|
||||
assert!(!contains_dangerous_redos_pattern(r"(abc)+"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_alternation_not_dangerous() {
|
||||
assert!(!contains_dangerous_redos_pattern(r"(a|b)+"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
153
desktop/src-tauri/src/intelligence_hooks.rs
Normal file
153
desktop/src-tauri/src/intelligence_hooks.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
//! Intelligence Hooks - Pre/Post conversation integration
|
||||
//!
|
||||
//! Bridges the intelligence layer modules (identity, memory, heartbeat, reflection)
|
||||
//! into the kernel's chat flow at the Tauri command boundary.
|
||||
//!
|
||||
//! Architecture: kernel_commands.rs → intelligence_hooks → intelligence modules → Viking/Kernel
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use crate::intelligence::identity::IdentityManagerState;
|
||||
use crate::intelligence::heartbeat::HeartbeatEngineState;
|
||||
use crate::intelligence::reflection::ReflectionEngineState;
|
||||
|
||||
/// Run pre-conversation intelligence hooks
|
||||
///
|
||||
/// 1. Build memory context from VikingStorage (FTS5 + TF-IDF + Embedding)
|
||||
/// 2. Build identity-enhanced system prompt (SOUL.md + instructions)
|
||||
///
|
||||
/// Returns the enhanced system prompt that should be passed to the kernel.
|
||||
pub async fn pre_conversation_hook(
|
||||
agent_id: &str,
|
||||
user_message: &str,
|
||||
identity_state: &IdentityManagerState,
|
||||
) -> Result<String, String> {
|
||||
// Step 1: Build memory context from Viking storage
|
||||
let memory_context = build_memory_context(agent_id, user_message).await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Step 2: Build identity-enhanced system prompt
|
||||
let enhanced_prompt = build_identity_prompt(agent_id, &memory_context, identity_state)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(enhanced_prompt)
|
||||
}
|
||||
|
||||
/// Run post-conversation intelligence hooks
|
||||
///
|
||||
/// 1. Record interaction for heartbeat engine
|
||||
/// 2. Record conversation for reflection engine, trigger reflection if needed
|
||||
pub async fn post_conversation_hook(
|
||||
agent_id: &str,
|
||||
_heartbeat_state: &HeartbeatEngineState,
|
||||
reflection_state: &ReflectionEngineState,
|
||||
) {
|
||||
// Step 1: Record interaction for heartbeat
|
||||
crate::intelligence::heartbeat::record_interaction(agent_id);
|
||||
debug!("[intelligence_hooks] Recorded interaction for agent: {}", agent_id);
|
||||
|
||||
// Step 2: Record conversation for reflection
|
||||
// tokio::sync::Mutex::lock() returns MutexGuard directly (panics on poison)
|
||||
let mut engine = reflection_state.lock().await;
|
||||
|
||||
engine.record_conversation();
|
||||
debug!(
|
||||
"[intelligence_hooks] Conversation count updated for agent: {}",
|
||||
agent_id
|
||||
);
|
||||
|
||||
if engine.should_reflect() {
|
||||
debug!(
|
||||
"[intelligence_hooks] Reflection threshold reached for agent: {}",
|
||||
agent_id
|
||||
);
|
||||
let reflection_result = engine.reflect(agent_id, &[]);
|
||||
debug!(
|
||||
"[intelligence_hooks] Reflection completed: {} patterns, {} suggestions",
|
||||
reflection_result.patterns.len(),
|
||||
reflection_result.improvements.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Build memory context by searching VikingStorage for relevant memories
|
||||
async fn build_memory_context(
|
||||
agent_id: &str,
|
||||
user_message: &str,
|
||||
) -> Result<String, String> {
|
||||
// Try Viking storage (has FTS5 + TF-IDF + Embedding)
|
||||
let storage = crate::viking_commands::get_storage().await?;
|
||||
|
||||
// FindOptions from zclaw_growth
|
||||
let options = zclaw_growth::FindOptions {
|
||||
scope: Some(format!("agent://{}", agent_id)),
|
||||
limit: Some(8),
|
||||
min_similarity: Some(0.2),
|
||||
};
|
||||
|
||||
// find is on the VikingStorage trait — call via trait to dispatch correctly
|
||||
let results: Vec<zclaw_growth::MemoryEntry> =
|
||||
zclaw_growth::VikingStorage::find(storage.as_ref(), user_message, options)
|
||||
.await
|
||||
.map_err(|e| format!("Memory search failed: {}", e))?;
|
||||
|
||||
if results.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Format memories into context string
|
||||
let mut context = String::from("## 相关记忆\n\n");
|
||||
let mut token_estimate: usize = 0;
|
||||
let max_tokens: usize = 500;
|
||||
|
||||
for entry in &results {
|
||||
// Prefer overview (L1 summary) over full content
|
||||
// overview is Option<String> — use as_deref to get Option<&str>
|
||||
let overview_str = entry.overview.as_deref().unwrap_or("");
|
||||
let text = if !overview_str.is_empty() {
|
||||
overview_str
|
||||
} else {
|
||||
&entry.content
|
||||
};
|
||||
|
||||
// Truncate long entries
|
||||
let truncated = if text.len() > 100 {
|
||||
format!("{}...", &text[..100])
|
||||
} else {
|
||||
text.to_string()
|
||||
};
|
||||
|
||||
// Simple token estimate (~1.5 tokens per CJK char, ~0.25 per other)
|
||||
let tokens: usize = truncated.chars()
|
||||
.map(|c: char| if c.is_ascii() { 1 } else { 2 })
|
||||
.sum();
|
||||
|
||||
if token_estimate + tokens > max_tokens {
|
||||
break;
|
||||
}
|
||||
|
||||
context.push_str(&format!("- [{}] {}\n", entry.memory_type, truncated));
|
||||
token_estimate += tokens;
|
||||
}
|
||||
|
||||
Ok(context)
|
||||
}
|
||||
|
||||
/// Build identity-enhanced system prompt
|
||||
async fn build_identity_prompt(
|
||||
agent_id: &str,
|
||||
memory_context: &str,
|
||||
identity_state: &IdentityManagerState,
|
||||
) -> Result<String, String> {
|
||||
// IdentityManagerState is Arc<tokio::sync::Mutex<AgentIdentityManager>>
|
||||
// tokio::sync::Mutex::lock() returns MutexGuard directly
|
||||
let mut manager = identity_state.lock().await;
|
||||
|
||||
let prompt = manager.build_system_prompt(
|
||||
agent_id,
|
||||
if memory_context.is_empty() { None } else { Some(memory_context) },
|
||||
);
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
//! ZCLAW Kernel commands for Tauri
|
||||
//!
|
||||
//! These commands provide direct access to the internal ZCLAW Kernel,
|
||||
//! eliminating the need for external OpenFang process.
|
||||
//! eliminating the need for external ZCLAW process.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -416,6 +416,9 @@ pub struct StreamChatRequest {
|
||||
pub async fn agent_chat_stream(
|
||||
app: AppHandle,
|
||||
state: State<'_, KernelState>,
|
||||
identity_state: State<'_, crate::intelligence::IdentityManagerState>,
|
||||
heartbeat_state: State<'_, crate::intelligence::HeartbeatEngineState>,
|
||||
reflection_state: State<'_, crate::intelligence::ReflectionEngineState>,
|
||||
request: StreamChatRequest,
|
||||
) -> Result<(), String> {
|
||||
// Validate inputs
|
||||
@@ -428,7 +431,15 @@ pub async fn agent_chat_stream(
|
||||
.map_err(|_| "Invalid agent ID format".to_string())?;
|
||||
|
||||
let session_id = request.session_id.clone();
|
||||
let message = request.message;
|
||||
let agent_id_str = request.agent_id.clone();
|
||||
let message = request.message.clone();
|
||||
|
||||
// PRE-CONVERSATION: Build intelligence-enhanced system prompt
|
||||
let enhanced_prompt = crate::intelligence_hooks::pre_conversation_hook(
|
||||
&request.agent_id,
|
||||
&request.message,
|
||||
&identity_state,
|
||||
).await.unwrap_or_default();
|
||||
|
||||
// Get the streaming receiver while holding the lock, then release it
|
||||
let mut rx = {
|
||||
@@ -437,12 +448,18 @@ pub async fn agent_chat_stream(
|
||||
.ok_or_else(|| "Kernel not initialized. Call kernel_init first.".to_string())?;
|
||||
|
||||
// Start the stream - this spawns a background task
|
||||
kernel.send_message_stream(&id, message)
|
||||
// Use intelligence-enhanced system prompt if available
|
||||
let prompt_arg = if enhanced_prompt.is_empty() { None } else { Some(enhanced_prompt) };
|
||||
kernel.send_message_stream_with_prompt(&id, message, prompt_arg)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to start streaming: {}", e))?
|
||||
};
|
||||
// Lock is released here
|
||||
|
||||
// Clone Arc references before spawning (State<'_, T> borrows can't enter the spawn)
|
||||
let hb_state = heartbeat_state.inner().clone();
|
||||
let rf_state = reflection_state.inner().clone();
|
||||
|
||||
// Spawn a task to process stream events
|
||||
tokio::spawn(async move {
|
||||
use zclaw_runtime::LoopEvent;
|
||||
@@ -472,6 +489,12 @@ pub async fn agent_chat_stream(
|
||||
LoopEvent::Complete(result) => {
|
||||
println!("[agent_chat_stream] Complete: input_tokens={}, output_tokens={}",
|
||||
result.input_tokens, result.output_tokens);
|
||||
|
||||
// POST-CONVERSATION: record interaction + trigger reflection
|
||||
crate::intelligence_hooks::post_conversation_hook(
|
||||
&agent_id_str, &hb_state, &rf_state,
|
||||
).await;
|
||||
|
||||
StreamChatEvent::Complete {
|
||||
input_tokens: result.input_tokens,
|
||||
output_tokens: result.output_tokens,
|
||||
@@ -1078,3 +1101,155 @@ pub async fn approval_respond(
|
||||
kernel.respond_to_approval(&id, approved, reason).await
|
||||
.map_err(|e| format!("Failed to respond to approval: {}", e))
|
||||
}
|
||||
|
||||
/// Approve a hand execution (alias for approval_respond with approved=true)
|
||||
#[tauri::command]
|
||||
pub async fn hand_approve(
|
||||
state: State<'_, KernelState>,
|
||||
_hand_name: String,
|
||||
run_id: String,
|
||||
approved: bool,
|
||||
reason: Option<String>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized".to_string())?;
|
||||
|
||||
// run_id maps to approval id
|
||||
kernel.respond_to_approval(&run_id, approved, reason).await
|
||||
.map_err(|e| format!("Failed to approve hand: {}", e))?;
|
||||
|
||||
Ok(serde_json::json!({ "status": if approved { "approved" } else { "rejected" } }))
|
||||
}
|
||||
|
||||
/// Cancel a hand execution
|
||||
#[tauri::command]
|
||||
pub async fn hand_cancel(
|
||||
state: State<'_, KernelState>,
|
||||
_hand_name: String,
|
||||
run_id: String,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized".to_string())?;
|
||||
|
||||
kernel.cancel_approval(&run_id).await
|
||||
.map_err(|e| format!("Failed to cancel hand: {}", e))?;
|
||||
|
||||
Ok(serde_json::json!({ "status": "cancelled" }))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Scheduled Task Commands
|
||||
// ============================================================
|
||||
|
||||
/// Request to create a scheduled task (maps to kernel trigger)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateScheduledTaskRequest {
|
||||
pub name: String,
|
||||
pub schedule: String,
|
||||
pub schedule_type: String,
|
||||
pub target: Option<ScheduledTaskTarget>,
|
||||
pub description: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Target for a scheduled task
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ScheduledTaskTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub target_type: String,
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
/// Response for scheduled task creation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ScheduledTaskResponse {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub schedule: String,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
/// Create a scheduled task (backed by kernel TriggerManager)
|
||||
///
|
||||
/// Tasks are stored in the kernel's trigger system. Automatic execution
|
||||
/// requires a scheduler loop (not yet implemented in embedded kernel mode).
|
||||
#[tauri::command]
|
||||
pub async fn scheduled_task_create(
|
||||
state: State<'_, KernelState>,
|
||||
request: CreateScheduledTaskRequest,
|
||||
) -> Result<ScheduledTaskResponse, String> {
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized".to_string())?;
|
||||
|
||||
// Build TriggerConfig from request
|
||||
let trigger_type = match request.schedule_type.as_str() {
|
||||
"cron" | "schedule" => zclaw_hands::TriggerType::Schedule {
|
||||
cron: request.schedule.clone(),
|
||||
},
|
||||
"interval" => zclaw_hands::TriggerType::Schedule {
|
||||
cron: request.schedule.clone(), // interval as simplified cron
|
||||
},
|
||||
"once" => zclaw_hands::TriggerType::Schedule {
|
||||
cron: request.schedule.clone(),
|
||||
},
|
||||
_ => return Err(format!("Unsupported schedule type: {}", request.schedule_type)),
|
||||
};
|
||||
|
||||
let target_id = request.target.as_ref().map(|t| t.id.clone()).unwrap_or_default();
|
||||
let task_id = format!("sched_{}", chrono::Utc::now().timestamp_millis());
|
||||
|
||||
let config = zclaw_hands::TriggerConfig {
|
||||
id: task_id.clone(),
|
||||
name: request.name.clone(),
|
||||
hand_id: target_id,
|
||||
trigger_type,
|
||||
enabled: request.enabled.unwrap_or(true),
|
||||
max_executions_per_hour: 60,
|
||||
};
|
||||
|
||||
let entry = kernel.create_trigger(config).await
|
||||
.map_err(|e| format!("Failed to create scheduled task: {}", e))?;
|
||||
|
||||
Ok(ScheduledTaskResponse {
|
||||
id: entry.config.id,
|
||||
name: entry.config.name,
|
||||
schedule: request.schedule,
|
||||
status: "active".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// List all scheduled tasks (kernel triggers of Schedule type)
|
||||
#[tauri::command]
|
||||
pub async fn scheduled_task_list(
|
||||
state: State<'_, KernelState>,
|
||||
) -> Result<Vec<ScheduledTaskResponse>, String> {
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized".to_string())?;
|
||||
|
||||
let triggers = kernel.list_triggers().await;
|
||||
let tasks: Vec<ScheduledTaskResponse> = triggers
|
||||
.into_iter()
|
||||
.filter(|t| matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. }))
|
||||
.map(|t| {
|
||||
let schedule = match t.config.trigger_type {
|
||||
zclaw_hands::TriggerType::Schedule { cron } => cron,
|
||||
_ => String::new(),
|
||||
};
|
||||
ScheduledTaskResponse {
|
||||
id: t.config.id,
|
||||
name: t.config.name,
|
||||
schedule,
|
||||
status: if t.config.enabled { "active".to_string() } else { "paused".to_string() },
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
@@ -15,5 +15,6 @@ pub mod crypto;
|
||||
// Re-export main types for convenience
|
||||
pub use persistent::{
|
||||
PersistentMemory, PersistentMemoryStore, MemorySearchQuery, MemoryStats,
|
||||
generate_memory_id,
|
||||
generate_memory_id, configure_embedding_client, is_embedding_configured,
|
||||
EmbedFn,
|
||||
};
|
||||
|
||||
@@ -11,12 +11,69 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{Mutex, OnceCell};
|
||||
use uuid::Uuid;
|
||||
use tauri::Manager;
|
||||
use sqlx::{SqliteConnection, Connection, Row, sqlite::SqliteRow};
|
||||
use chrono::Utc;
|
||||
|
||||
/// Embedding function type: text -> vector of f32
|
||||
pub type EmbedFn = Arc<dyn Fn(&str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, String>> + Send>> + Send + Sync>;
|
||||
|
||||
/// Global embedding function for PersistentMemoryStore
|
||||
static EMBEDDING_FN: OnceCell<EmbedFn> = OnceCell::const_new();
|
||||
|
||||
/// Configure the global embedding function for memory search
|
||||
pub fn configure_embedding_client(fn_impl: EmbedFn) {
|
||||
let _ = EMBEDDING_FN.set(fn_impl);
|
||||
tracing::info!("[PersistentMemoryStore] Embedding client configured");
|
||||
}
|
||||
|
||||
/// Check if embedding is available
|
||||
pub fn is_embedding_configured() -> bool {
|
||||
EMBEDDING_FN.get().is_some()
|
||||
}
|
||||
|
||||
/// Generate embedding for text using the configured client
|
||||
async fn embed_text(text: &str) -> Result<Vec<f32>, String> {
|
||||
let client = EMBEDDING_FN.get()
|
||||
.ok_or_else(|| "Embedding client not configured".to_string())?;
|
||||
client(text).await
|
||||
}
|
||||
|
||||
/// Deserialize f32 vector from BLOB (4 bytes per f32, little-endian)
|
||||
fn deserialize_embedding(blob: &[u8]) -> Vec<f32> {
|
||||
blob.chunks_exact(4)
|
||||
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Serialize f32 vector to BLOB
|
||||
fn serialize_embedding(vec: &[f32]) -> Vec<u8> {
|
||||
let mut bytes = Vec::with_capacity(vec.len() * 4);
|
||||
for val in vec {
|
||||
bytes.extend_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.is_empty() || b.is_empty() || a.len() != b.len() {
|
||||
return 0.0;
|
||||
}
|
||||
let mut dot = 0.0f32;
|
||||
let mut norm_a = 0.0f32;
|
||||
let mut norm_b = 0.0f32;
|
||||
for i in 0..a.len() {
|
||||
dot += a[i] * b[i];
|
||||
norm_a += a[i] * a[i];
|
||||
norm_b += b[i] * b[i];
|
||||
}
|
||||
let denom = (norm_a * norm_b).sqrt();
|
||||
if denom == 0.0 { 0.0 } else { (dot / denom).clamp(0.0, 1.0) }
|
||||
}
|
||||
|
||||
/// Memory entry stored in SQLite
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersistentMemory {
|
||||
@@ -32,6 +89,7 @@ pub struct PersistentMemory {
|
||||
pub last_accessed_at: String,
|
||||
pub access_count: i32,
|
||||
pub embedding: Option<Vec<u8>>, // Vector embedding for semantic search
|
||||
pub overview: Option<String>, // L1 summary (1-2 sentences, ~200 tokens)
|
||||
}
|
||||
|
||||
// Manual implementation of FromRow since sqlx::FromRow derive has issues with Option<Vec<u8>>
|
||||
@@ -50,12 +108,13 @@ impl<'r> sqlx::FromRow<'r, SqliteRow> for PersistentMemory {
|
||||
last_accessed_at: row.try_get("last_accessed_at")?,
|
||||
access_count: row.try_get("access_count")?,
|
||||
embedding: row.try_get("embedding")?,
|
||||
overview: row.try_get("overview").ok(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory search options
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MemorySearchQuery {
|
||||
pub agent_id: Option<String>,
|
||||
pub memory_type: Option<String>,
|
||||
@@ -149,11 +208,34 @@ impl PersistentMemoryStore {
|
||||
.await
|
||||
.map_err(|e| format!("Failed to create schema: {}", e))?;
|
||||
|
||||
// Migration: add overview column (L1 summary)
|
||||
let _ = sqlx::query("ALTER TABLE memories ADD COLUMN overview TEXT")
|
||||
.execute(&mut *conn)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store a new memory
|
||||
pub async fn store(&self, memory: &PersistentMemory) -> Result<(), String> {
|
||||
// Generate embedding if client is configured and memory doesn't have one
|
||||
let embedding = if memory.embedding.is_some() {
|
||||
memory.embedding.clone()
|
||||
} else if is_embedding_configured() {
|
||||
match embed_text(&memory.content).await {
|
||||
Ok(vec) => {
|
||||
tracing::debug!("[PersistentMemoryStore] Generated embedding for {} ({} dims)", memory.id, vec.len());
|
||||
Some(serialize_embedding(&vec))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("[PersistentMemoryStore] Embedding generation failed: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut conn = self.conn.lock().await;
|
||||
|
||||
sqlx::query(
|
||||
@@ -161,8 +243,8 @@ impl PersistentMemoryStore {
|
||||
INSERT INTO memories (
|
||||
id, agent_id, memory_type, content, importance, source,
|
||||
tags, conversation_id, created_at, last_accessed_at,
|
||||
access_count, embedding
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
access_count, embedding, overview
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&memory.id)
|
||||
@@ -176,7 +258,8 @@ impl PersistentMemoryStore {
|
||||
.bind(&memory.created_at)
|
||||
.bind(&memory.last_accessed_at)
|
||||
.bind(memory.access_count)
|
||||
.bind(&memory.embedding)
|
||||
.bind(&embedding)
|
||||
.bind(&memory.overview)
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to store memory: {}", e))?;
|
||||
@@ -212,7 +295,7 @@ impl PersistentMemoryStore {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Search memories with simple query
|
||||
/// Search memories with semantic ranking when embeddings are available
|
||||
pub async fn search(&self, query: MemorySearchQuery) -> Result<Vec<PersistentMemory>, String> {
|
||||
let mut conn = self.conn.lock().await;
|
||||
|
||||
@@ -239,11 +322,14 @@ impl PersistentMemoryStore {
|
||||
params.push(format!("%{}%", query_text));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY created_at DESC");
|
||||
// When using embedding ranking, fetch more candidates
|
||||
let effective_limit = if query.query.is_some() && is_embedding_configured() {
|
||||
query.limit.unwrap_or(50).max(20) // Fetch more for re-ranking
|
||||
} else {
|
||||
query.limit.unwrap_or(50)
|
||||
};
|
||||
|
||||
if let Some(limit) = query.limit {
|
||||
sql.push_str(&format!(" LIMIT {}", limit));
|
||||
}
|
||||
sql.push_str(&format!(" LIMIT {}", effective_limit));
|
||||
|
||||
if let Some(offset) = query.offset {
|
||||
sql.push_str(&format!(" OFFSET {}", offset));
|
||||
@@ -255,11 +341,41 @@ impl PersistentMemoryStore {
|
||||
query_builder = query_builder.bind(param);
|
||||
}
|
||||
|
||||
let results = query_builder
|
||||
let mut results = query_builder
|
||||
.fetch_all(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to search memories: {}", e))?;
|
||||
|
||||
// Apply semantic ranking if query and embedding are available
|
||||
if let Some(query_text) = &query.query {
|
||||
if is_embedding_configured() {
|
||||
if let Ok(query_embedding) = embed_text(query_text).await {
|
||||
// Score each result by cosine similarity
|
||||
let mut scored: Vec<(f32, PersistentMemory)> = results
|
||||
.into_iter()
|
||||
.map(|mem| {
|
||||
let score = mem.embedding.as_ref()
|
||||
.map(|blob| {
|
||||
let vec = deserialize_embedding(blob);
|
||||
cosine_similarity(&query_embedding, &vec)
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
(score, mem)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by score descending
|
||||
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Apply the original limit
|
||||
results = scored.into_iter()
|
||||
.take(query.limit.unwrap_or(20))
|
||||
.map(|(_, mem)| mem)
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Phase 1 of Intelligence Layer Migration:
|
||||
//! Provides frontend API for memory storage and retrieval
|
||||
|
||||
use crate::memory::{PersistentMemory, PersistentMemoryStore, MemorySearchQuery, MemoryStats, generate_memory_id};
|
||||
use crate::memory::{PersistentMemory, PersistentMemoryStore, MemorySearchQuery, MemoryStats, generate_memory_id, configure_embedding_client, is_embedding_configured, EmbedFn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tauri::{AppHandle, State};
|
||||
@@ -52,6 +52,9 @@ pub async fn memory_init(
|
||||
}
|
||||
|
||||
/// Store a new memory
|
||||
///
|
||||
/// Writes to both PersistentMemoryStore (backward compat) and SqliteStorage (FTS5+Embedding).
|
||||
/// SqliteStorage write failure is logged but does not block the operation.
|
||||
#[tauri::command]
|
||||
pub async fn memory_store(
|
||||
entry: MemoryEntryInput,
|
||||
@@ -64,28 +67,61 @@ pub async fn memory_store(
|
||||
.ok_or_else(|| "Memory store not initialized. Call memory_init first.".to_string())?;
|
||||
|
||||
let now = Utc::now().to_rfc3339();
|
||||
let id = generate_memory_id();
|
||||
let memory = PersistentMemory {
|
||||
id: generate_memory_id(),
|
||||
agent_id: entry.agent_id,
|
||||
memory_type: entry.memory_type,
|
||||
content: entry.content,
|
||||
id: id.clone(),
|
||||
agent_id: entry.agent_id.clone(),
|
||||
memory_type: entry.memory_type.clone(),
|
||||
content: entry.content.clone(),
|
||||
importance: entry.importance.unwrap_or(5),
|
||||
source: entry.source.unwrap_or_else(|| "auto".to_string()),
|
||||
tags: serde_json::to_string(&entry.tags.unwrap_or_default())
|
||||
tags: serde_json::to_string(&entry.tags.clone().unwrap_or_default())
|
||||
.unwrap_or_else(|_| "[]".to_string()),
|
||||
conversation_id: entry.conversation_id,
|
||||
conversation_id: entry.conversation_id.clone(),
|
||||
created_at: now.clone(),
|
||||
last_accessed_at: now,
|
||||
access_count: 0,
|
||||
embedding: None,
|
||||
overview: None,
|
||||
};
|
||||
|
||||
let id = memory.id.clone();
|
||||
// Write to PersistentMemoryStore (primary)
|
||||
store.store(&memory).await?;
|
||||
|
||||
// Also write to SqliteStorage via VikingStorage for FTS5 + Embedding search
|
||||
if let Ok(storage) = crate::viking_commands::get_storage().await {
|
||||
let memory_type = parse_memory_type(&entry.memory_type);
|
||||
let keywords = entry.tags.unwrap_or_default();
|
||||
|
||||
let viking_entry = zclaw_growth::MemoryEntry::new(
|
||||
&entry.agent_id,
|
||||
memory_type,
|
||||
&entry.memory_type,
|
||||
entry.content,
|
||||
)
|
||||
.with_importance(entry.importance.unwrap_or(5) as u8)
|
||||
.with_keywords(keywords);
|
||||
|
||||
match zclaw_growth::VikingStorage::store(storage.as_ref(), &viking_entry).await {
|
||||
Ok(()) => tracing::debug!("[memory_store] Also stored in SqliteStorage"),
|
||||
Err(e) => tracing::warn!("[memory_store] SqliteStorage write failed (non-blocking): {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Parse a string memory_type into zclaw_growth::MemoryType
|
||||
fn parse_memory_type(type_str: &str) -> zclaw_growth::MemoryType {
|
||||
match type_str.to_lowercase().as_str() {
|
||||
"preference" => zclaw_growth::MemoryType::Preference,
|
||||
"knowledge" | "fact" | "task" | "todo" | "lesson" | "event" => zclaw_growth::MemoryType::Knowledge,
|
||||
"skill" | "experience" => zclaw_growth::MemoryType::Experience,
|
||||
"session" | "conversation" => zclaw_growth::MemoryType::Session,
|
||||
_ => zclaw_growth::MemoryType::Knowledge,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a memory by ID
|
||||
#[tauri::command]
|
||||
pub async fn memory_get(
|
||||
@@ -213,3 +249,223 @@ pub async fn memory_db_path(
|
||||
|
||||
Ok(store.path().to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
/// Configure embedding for PersistentMemoryStore (chat memory search)
|
||||
/// This is called alongside viking_configure_embedding to enable vector search in chat flow
|
||||
#[tauri::command]
|
||||
pub async fn memory_configure_embedding(
|
||||
provider: String,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
) -> Result<bool, String> {
|
||||
// Create an llm::EmbeddingClient and wrap it in Arc for the closure
|
||||
let config = crate::llm::EmbeddingConfig {
|
||||
provider,
|
||||
api_key,
|
||||
endpoint,
|
||||
model,
|
||||
};
|
||||
let client = std::sync::Arc::new(crate::llm::EmbeddingClient::new(config));
|
||||
|
||||
let embed_fn: EmbedFn = {
|
||||
let client = client.clone();
|
||||
Arc::new(move |text: &str| {
|
||||
let client = client.clone();
|
||||
let text = text.to_string();
|
||||
Box::pin(async move {
|
||||
let response = client.embed(&text).await?;
|
||||
Ok(response.embedding)
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
configure_embedding_client(embed_fn);
|
||||
|
||||
tracing::info!("[MemoryCommands] Embedding configured for PersistentMemoryStore");
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Check if embedding is configured for PersistentMemoryStore
|
||||
#[tauri::command]
|
||||
pub fn memory_is_embedding_configured() -> bool {
|
||||
is_embedding_configured()
|
||||
}
|
||||
|
||||
/// Build layered memory context for chat prompt injection
|
||||
///
|
||||
/// Uses SqliteStorage (FTS5 + TF-IDF + Embedding) for high-quality semantic search,
|
||||
/// with fallback to PersistentMemoryStore if Viking storage is unavailable.
|
||||
///
|
||||
/// Performs L0→L1→L2 progressive loading:
|
||||
/// - L0: Search all matching memories (vector similarity when available)
|
||||
/// - L1: Use overview/summary when available, fall back to truncated content
|
||||
/// - L2: Full content only for top-ranked items
|
||||
#[tauri::command]
|
||||
pub async fn memory_build_context(
|
||||
agent_id: String,
|
||||
query: String,
|
||||
max_tokens: Option<usize>,
|
||||
state: State<'_, MemoryStoreState>,
|
||||
) -> Result<BuildContextResult, String> {
|
||||
let budget = max_tokens.unwrap_or(500);
|
||||
|
||||
// Try SqliteStorage (Viking) first — has FTS5 + TF-IDF + Embedding
|
||||
let entries = match crate::viking_commands::get_storage().await {
|
||||
Ok(storage) => {
|
||||
let options = zclaw_growth::FindOptions {
|
||||
scope: Some(format!("agent://{}", agent_id)),
|
||||
limit: Some((budget / 25).max(8)),
|
||||
min_similarity: Some(0.2),
|
||||
};
|
||||
|
||||
match zclaw_growth::VikingStorage::find(storage.as_ref(), &query, options).await {
|
||||
Ok(entries) => entries,
|
||||
Err(e) => {
|
||||
tracing::warn!("[memory_build_context] Viking search failed, falling back: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::debug!("[memory_build_context] Viking storage unavailable, falling back to PersistentMemoryStore");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// If Viking found results, use them (they have overview/embedding ranking)
|
||||
if !entries.is_empty() {
|
||||
let mut used_tokens = 0;
|
||||
let mut items: Vec<String> = Vec::new();
|
||||
let mut memories_used = 0;
|
||||
|
||||
for entry in &entries {
|
||||
if used_tokens >= budget {
|
||||
break;
|
||||
}
|
||||
|
||||
// Prefer overview (L1 summary) over full content
|
||||
let overview_str = entry.overview.as_deref().unwrap_or("");
|
||||
let display_content = if !overview_str.is_empty() {
|
||||
overview_str.to_string()
|
||||
} else {
|
||||
truncate_for_l1(&entry.content)
|
||||
};
|
||||
|
||||
let item_tokens = estimate_tokens_text(&display_content);
|
||||
if used_tokens + item_tokens > budget {
|
||||
continue;
|
||||
}
|
||||
|
||||
items.push(format!("- [{}] {}", entry.memory_type, display_content));
|
||||
used_tokens += item_tokens;
|
||||
memories_used += 1;
|
||||
}
|
||||
|
||||
let system_prompt_addition = if items.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("## 相关记忆\n{}", items.join("\n"))
|
||||
};
|
||||
|
||||
return Ok(BuildContextResult {
|
||||
system_prompt_addition,
|
||||
total_tokens: used_tokens,
|
||||
memories_used,
|
||||
});
|
||||
}
|
||||
|
||||
// Fallback: PersistentMemoryStore (LIKE-based search)
|
||||
let state_guard = state.lock().await;
|
||||
let store = state_guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Memory store not initialized".to_string())?;
|
||||
|
||||
let limit = budget / 25;
|
||||
let search_query = MemorySearchQuery {
|
||||
agent_id: Some(agent_id.clone()),
|
||||
query: Some(query.clone()),
|
||||
limit: Some(limit.max(20)),
|
||||
min_importance: Some(3),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let memories = store.search(search_query).await?;
|
||||
|
||||
if memories.is_empty() {
|
||||
return Ok(BuildContextResult {
|
||||
system_prompt_addition: String::new(),
|
||||
total_tokens: 0,
|
||||
memories_used: 0,
|
||||
});
|
||||
}
|
||||
|
||||
// Build layered context with token budget
|
||||
let mut used_tokens = 0;
|
||||
let mut items: Vec<String> = Vec::new();
|
||||
let mut memories_used = 0;
|
||||
|
||||
for memory in &memories {
|
||||
if used_tokens >= budget {
|
||||
break;
|
||||
}
|
||||
|
||||
let display_content = if let Some(ref overview) = memory.overview {
|
||||
if !overview.is_empty() {
|
||||
overview.clone()
|
||||
} else {
|
||||
truncate_for_l1(&memory.content)
|
||||
}
|
||||
} else {
|
||||
truncate_for_l1(&memory.content)
|
||||
};
|
||||
|
||||
let item_tokens = estimate_tokens_text(&display_content);
|
||||
if used_tokens + item_tokens > budget {
|
||||
continue;
|
||||
}
|
||||
|
||||
items.push(format!("- [{}] {}", memory.memory_type, display_content));
|
||||
used_tokens += item_tokens;
|
||||
memories_used += 1;
|
||||
}
|
||||
|
||||
let system_prompt_addition = if items.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("## 相关记忆\n{}", items.join("\n"))
|
||||
};
|
||||
|
||||
Ok(BuildContextResult {
|
||||
system_prompt_addition,
|
||||
total_tokens: used_tokens,
|
||||
memories_used,
|
||||
})
|
||||
}
|
||||
|
||||
/// Result of building layered memory context
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct BuildContextResult {
|
||||
pub system_prompt_addition: String,
|
||||
pub total_tokens: usize,
|
||||
pub memories_used: usize,
|
||||
}
|
||||
|
||||
/// Truncate content for L1 overview display (~50 tokens)
fn truncate_for_l1(content: &str) -> String {
    const CHAR_LIMIT: usize = 100; // ~50 tokens for mixed CJK/ASCII

    let mut chars = content.chars();
    let head: String = chars.by_ref().take(CHAR_LIMIT).collect();
    if chars.next().is_none() {
        // Fits entirely within the limit — return as-is.
        head
    } else {
        format!("{}...", head)
    }
}
|
||||
|
||||
/// Estimate token count for text (CJK ideographs weigh ~1.5 tokens, other chars ~0.4)
fn estimate_tokens_text(text: &str) -> usize {
    let (mut cjk, mut other) = (0usize, 0usize);
    for c in text.chars() {
        // Basic CJK Unified Ideographs block.
        if ('\u{4E00}'..='\u{9FFF}').contains(&c) {
            cjk += 1;
        } else {
            other += 1;
        }
    }
    (cjk as f32 * 1.5 + other as f32 * 0.4).ceil() as usize
}
|
||||
|
||||
133
desktop/src-tauri/src/summarizer_adapter.rs
Normal file
133
desktop/src-tauri/src/summarizer_adapter.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
//! Summarizer Adapter - Bridges zclaw_growth::SummaryLlmDriver with Tauri LLM Client
|
||||
//!
|
||||
//! Implements the SummaryLlmDriver trait using the local LlmClient,
|
||||
//! enabling L0/L1 summary generation via the user's configured LLM.
|
||||
|
||||
use zclaw_growth::{MemoryEntry, SummaryLlmDriver, summarizer::{overview_prompt, abstract_prompt}};
|
||||
|
||||
/// Tauri-side implementation of SummaryLlmDriver, calling an OpenAI-style
/// chat-completions endpoint directly over HTTP.
pub struct TauriSummaryDriver {
    // Base URL of the API (the "/chat/completions" suffix is appended per call).
    endpoint: String,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Model name; defaults to "glm-4-flash" when None.
    model: Option<String>,
}
|
||||
|
||||
impl TauriSummaryDriver {
|
||||
/// Create a new Tauri summary driver
|
||||
pub fn new(endpoint: String, api_key: String, model: Option<String>) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
api_key,
|
||||
model,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the driver is configured (has endpoint and api_key)
|
||||
pub fn is_configured(&self) -> bool {
|
||||
!self.endpoint.is_empty() && !self.api_key.is_empty()
|
||||
}
|
||||
|
||||
/// Call the LLM API with a simple prompt
|
||||
async fn call_llm(&self, prompt: String) -> Result<String, String> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let model = self.model.clone().unwrap_or_else(|| "glm-4-flash".to_string());
|
||||
|
||||
let request = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": [
|
||||
{ "role": "user", "content": prompt }
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 200,
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(format!("{}/chat/completions", self.endpoint))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Summary LLM request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(format!("Summary LLM error {}: {}", status, body));
|
||||
}
|
||||
|
||||
let json: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse summary response: {}", e))?;
|
||||
|
||||
json.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("message"))
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| "Invalid summary LLM response format".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SummaryLlmDriver for TauriSummaryDriver {
|
||||
async fn generate_overview(&self, entry: &MemoryEntry) -> Result<String, String> {
|
||||
let prompt = overview_prompt(entry);
|
||||
self.call_llm(prompt).await
|
||||
}
|
||||
|
||||
async fn generate_abstract(&self, entry: &MemoryEntry) -> Result<String, String> {
|
||||
let prompt = abstract_prompt(entry);
|
||||
self.call_llm(prompt).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Global summary driver instance (lazy-initialized)
|
||||
static SUMMARY_DRIVER: tokio::sync::OnceCell<std::sync::Arc<TauriSummaryDriver>> =
|
||||
tokio::sync::OnceCell::const_new();
|
||||
|
||||
/// Configure the global summary driver
|
||||
pub fn configure_summary_driver(driver: TauriSummaryDriver) {
|
||||
let _ = SUMMARY_DRIVER.set(std::sync::Arc::new(driver));
|
||||
tracing::info!("[SummarizerAdapter] Summary driver configured");
|
||||
}
|
||||
|
||||
/// Check if summary driver is available
|
||||
pub fn is_summary_driver_configured() -> bool {
|
||||
SUMMARY_DRIVER
|
||||
.get()
|
||||
.map(|d| d.is_configured())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get the global summary driver
|
||||
pub fn get_summary_driver() -> Option<std::sync::Arc<TauriSummaryDriver>> {
|
||||
SUMMARY_DRIVER.get().cloned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use zclaw_growth::MemoryType;

    #[test]
    fn test_summary_driver_not_configured_by_default() {
        // The global OnceCell starts unset, so the driver must report unconfigured.
        assert!(!is_summary_driver_configured());
    }

    #[test]
    fn test_summary_driver_configure_and_check() {
        let configured = TauriSummaryDriver::new(
            "https://example.com/v1".to_string(),
            "test-key".to_string(),
            None,
        );
        assert!(configured.is_configured());

        // Missing endpoint and key must count as unconfigured.
        let unconfigured = TauriSummaryDriver::new(String::new(), String::new(), None);
        assert!(!unconfigured.is_configured());
    }
}
|
||||
@@ -67,6 +67,13 @@ pub struct VikingAddResult {
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EmbeddingConfigResult {
|
||||
pub provider: String,
|
||||
pub configured: bool,
|
||||
}
|
||||
|
||||
// === Global Storage Instance ===
|
||||
|
||||
/// Global storage instance
|
||||
@@ -100,12 +107,20 @@ pub async fn init_storage() -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the storage instance (public for use by other modules)
|
||||
/// Get the storage instance, initializing on first access if needed
|
||||
pub async fn get_storage() -> Result<Arc<SqliteStorage>, String> {
|
||||
if let Some(storage) = STORAGE.get() {
|
||||
return Ok(storage.clone());
|
||||
}
|
||||
|
||||
// Attempt lazy initialization
|
||||
tracing::info!("[VikingCommands] Storage not yet initialized, attempting lazy init...");
|
||||
init_storage().await?;
|
||||
|
||||
STORAGE
|
||||
.get()
|
||||
.cloned()
|
||||
.ok_or_else(|| "Storage not initialized. Call init_storage() first.".to_string())
|
||||
.ok_or_else(|| "Storage initialization failed. Check logs for details.".to_string())
|
||||
}
|
||||
|
||||
/// Get storage directory for status
|
||||
@@ -217,12 +232,24 @@ pub async fn viking_find(
|
||||
Ok(entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, entry)| VikingFindResult {
|
||||
uri: entry.uri,
|
||||
score: 1.0 - (i as f64 * 0.1), // Simple scoring based on rank
|
||||
content: entry.content,
|
||||
level: "L1".to_string(),
|
||||
overview: None,
|
||||
.map(|(i, entry)| {
|
||||
// Use overview (L1) when available, full content otherwise (L2)
|
||||
let (content, level, overview) = if let Some(ref ov) = entry.overview {
|
||||
if !ov.is_empty() {
|
||||
(ov.clone(), "L1".to_string(), None)
|
||||
} else {
|
||||
(entry.content.clone(), "L2".to_string(), None)
|
||||
}
|
||||
} else {
|
||||
(entry.content.clone(), "L2".to_string(), None)
|
||||
};
|
||||
VikingFindResult {
|
||||
uri: entry.uri,
|
||||
score: 1.0 - (i as f64 * 0.1), // Simple scoring based on rank
|
||||
content,
|
||||
level,
|
||||
overview,
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
@@ -309,7 +336,7 @@ pub async fn viking_ls(path: String) -> Result<Vec<VikingResource>, String> {
|
||||
|
||||
/// Read memory content
|
||||
#[tauri::command]
|
||||
pub async fn viking_read(uri: String, _level: Option<String>) -> Result<String, String> {
|
||||
pub async fn viking_read(uri: String, level: Option<String>) -> Result<String, String> {
|
||||
let storage = get_storage().await?;
|
||||
|
||||
let entry = storage
|
||||
@@ -318,7 +345,34 @@ pub async fn viking_read(uri: String, _level: Option<String>) -> Result<String,
|
||||
.map_err(|e| format!("Failed to read memory: {}", e))?;
|
||||
|
||||
match entry {
|
||||
Some(e) => Ok(e.content),
|
||||
Some(e) => {
|
||||
// Support level-based content retrieval
|
||||
let content = match level.as_deref() {
|
||||
Some("L0") | Some("l0") => {
|
||||
// L0: abstract_summary (keywords)
|
||||
e.abstract_summary
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(|| {
|
||||
// Fallback: first 50 chars of overview
|
||||
e.overview
|
||||
.as_ref()
|
||||
.map(|ov| ov.chars().take(50).collect())
|
||||
.unwrap_or_else(|| e.content.chars().take(50).collect())
|
||||
})
|
||||
}
|
||||
Some("L1") | Some("l1") => {
|
||||
// L1: overview (1-2 sentence summary)
|
||||
e.overview
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(|| truncate_text(&e.content, 200))
|
||||
}
|
||||
_ => {
|
||||
// L2 or default: full content
|
||||
e.content
|
||||
}
|
||||
};
|
||||
Ok(content)
|
||||
}
|
||||
None => Err(format!("Memory not found: {}", uri)),
|
||||
}
|
||||
}
|
||||
@@ -442,6 +496,16 @@ pub async fn viking_inject_prompt(
|
||||
|
||||
// === Helper Functions ===
|
||||
|
||||
/// Truncate text to approximately max_chars characters
fn truncate_text(text: &str, max_chars: usize) -> String {
    let mut chars = text.chars();
    let head: String = chars.by_ref().take(max_chars).collect();
    if chars.next().is_none() {
        // Entire text fits within the limit.
        head
    } else {
        format!("{}...", head)
    }
}
|
||||
|
||||
/// Parse URI to extract components
|
||||
fn parse_uri(uri: &str) -> Result<(String, MemoryType, String), String> {
|
||||
// Expected format: agent://{agent_id}/{type}/{category}
|
||||
@@ -462,6 +526,136 @@ fn parse_uri(uri: &str) -> Result<(String, MemoryType, String), String> {
|
||||
Ok((agent_id, memory_type, category))
|
||||
}
|
||||
|
||||
/// Configure embedding for semantic memory search
|
||||
/// Configures both SqliteStorage (VikingPanel) and PersistentMemoryStore (chat flow)
|
||||
#[tauri::command]
|
||||
pub async fn viking_configure_embedding(
|
||||
provider: String,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
) -> Result<EmbeddingConfigResult, String> {
|
||||
let storage = get_storage().await?;
|
||||
|
||||
// 1. Configure SqliteStorage (VikingPanel / VikingCommands)
|
||||
let config_viking = crate::llm::EmbeddingConfig {
|
||||
provider: provider.clone(),
|
||||
api_key: api_key.clone(),
|
||||
endpoint: endpoint.clone(),
|
||||
model: model.clone(),
|
||||
};
|
||||
|
||||
let client_viking = crate::llm::EmbeddingClient::new(config_viking);
|
||||
let adapter = crate::embedding_adapter::TauriEmbeddingAdapter::new(client_viking);
|
||||
|
||||
storage
|
||||
.configure_embedding(std::sync::Arc::new(adapter))
|
||||
.await
|
||||
.map_err(|e| format!("Failed to configure embedding: {}", e))?;
|
||||
|
||||
// 2. Configure PersistentMemoryStore (chat flow)
|
||||
let config_memory = crate::llm::EmbeddingConfig {
|
||||
provider: provider.clone(),
|
||||
api_key,
|
||||
endpoint,
|
||||
model,
|
||||
};
|
||||
let client_memory = std::sync::Arc::new(crate::llm::EmbeddingClient::new(config_memory));
|
||||
|
||||
let embed_fn: crate::memory::EmbedFn = {
|
||||
let client_arc = client_memory.clone();
|
||||
std::sync::Arc::new(move |text: &str| {
|
||||
let client = client_arc.clone();
|
||||
let text = text.to_string();
|
||||
Box::pin(async move {
|
||||
let response = client.embed(&text).await?;
|
||||
Ok(response.embedding)
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
crate::memory::configure_embedding_client(embed_fn);
|
||||
|
||||
tracing::info!("[VikingCommands] Embedding configured with provider: {} (both storage systems)", provider);
|
||||
|
||||
Ok(EmbeddingConfigResult {
|
||||
provider,
|
||||
configured: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure summary driver for L0/L1 auto-generation
|
||||
#[tauri::command]
|
||||
pub async fn viking_configure_summary_driver(
|
||||
endpoint: String,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
) -> Result<bool, String> {
|
||||
let driver = crate::summarizer_adapter::TauriSummaryDriver::new(endpoint, api_key, model);
|
||||
crate::summarizer_adapter::configure_summary_driver(driver);
|
||||
|
||||
tracing::info!("[VikingCommands] Summary driver configured");
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Store a memory and optionally generate L0/L1 summaries in the background
|
||||
#[tauri::command]
|
||||
pub async fn viking_store_with_summaries(
|
||||
uri: String,
|
||||
content: String,
|
||||
) -> Result<VikingAddResult, String> {
|
||||
let storage = get_storage().await?;
|
||||
let (agent_id, memory_type, category) = parse_uri(&uri)?;
|
||||
|
||||
let entry = MemoryEntry::new(&agent_id, memory_type, &category, content);
|
||||
|
||||
// Store the entry immediately (L2 full content)
|
||||
storage
|
||||
.store(&entry)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to store memory: {}", e))?;
|
||||
|
||||
// Background: generate L0/L1 summaries if driver is configured
|
||||
if crate::summarizer_adapter::is_summary_driver_configured() {
|
||||
let entry_uri = entry.uri.clone();
|
||||
let storage_clone = storage.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(driver) = crate::summarizer_adapter::get_summary_driver() {
|
||||
let (overview, abstract_summary) =
|
||||
zclaw_growth::summarizer::generate_summaries(driver.as_ref(), &entry).await;
|
||||
|
||||
if overview.is_some() || abstract_summary.is_some() {
|
||||
// Update the entry with summaries
|
||||
let updated = MemoryEntry {
|
||||
overview,
|
||||
abstract_summary,
|
||||
..entry
|
||||
};
|
||||
|
||||
if let Err(e) = storage_clone.store(&updated).await {
|
||||
tracing::debug!(
|
||||
"[VikingCommands] Failed to update summaries for {}: {}",
|
||||
entry_uri,
|
||||
e
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"[VikingCommands] Updated L0/L1 summaries for {}",
|
||||
entry_uri
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(VikingAddResult {
|
||||
uri,
|
||||
status: "added".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
// === Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user