feat(hermes): implement intelligence pipeline — 4 chunks, 684 tests passing
Hermes Intelligence Pipeline closes breakpoints in ZCLAW's existing intelligence components with 4 self-contained modules:

Chunk 1 — Self-improvement Loop:
- ExperienceStore (zclaw-growth): FTS5+TF-IDF wrapper with scope prefix
- ExperienceExtractor (desktop/intelligence): template-based extraction from successful proposals with implicit keyword detection

Chunk 2 — User Modeling:
- UserProfileStore (zclaw-memory): SQLite-backed structured profiles with industry/role/expertise/comm_style/recent_topics/pain_points
- UserProfiler (desktop/intelligence): fact classification by category (Preference/Knowledge/Behavior) with profile summary formatting

Chunk 3 — NL Cron Chinese Time Parser:
- NlScheduleParser (zclaw-runtime): 6 pattern matchers for Chinese time expressions (每天/每周/工作日/间隔/每月/一次性) producing cron expressions
- Period-aware hour adjustment (下午3点→15, 晚上8点→20)
- Schedule intent detection + task description extraction

Chunk 4 — Trajectory Compression:
- TrajectoryStore (zclaw-memory): trajectory_events + compressed_trajectories
- TrajectoryRecorderMiddleware (zclaw-runtime/middleware): priority 650, async non-blocking event recording via tokio::spawn
- TrajectoryCompressor (desktop/intelligence): dedup, request classification, satisfaction detection, execution chain JSON

Schema migrations: v2→v3 (user_profiles), v3→v4 (trajectory tables)
This commit is contained in:
356
crates/zclaw-growth/src/experience_store.rs
Normal file
356
crates/zclaw-growth/src/experience_store.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
//! ExperienceStore — CRUD wrapper over VikingStorage for agent experiences.
|
||||
//!
|
||||
//! Stores structured experiences extracted from successful solution proposals
|
||||
//! using the scope prefix `agent://{agent_id}/experience/{pattern_hash}`.
|
||||
//! Leverages existing FTS5 + TF-IDF + embedding retrieval via VikingAdapter.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::types::{MemoryEntry, MemoryType};
|
||||
use crate::viking_adapter::{FindOptions, VikingAdapter};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Experience data model
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A structured experience record representing a solved pain point.
///
/// Stored as JSON content inside a VikingStorage `MemoryEntry` with
/// `memory_type = Experience`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
    /// Unique experience identifier.
    pub id: String,
    /// Owning agent.
    pub agent_id: String,
    /// Short pattern describing the pain that was solved
    /// (e.g. "logistics export packaging"); also keys the storage URI,
    /// so records sharing a pattern overwrite each other.
    pub pain_pattern: String,
    /// Context in which the problem occurred.
    pub context: String,
    /// Ordered steps that resolved the problem.
    pub solution_steps: Vec<String>,
    /// Verbal outcome reported by the user.
    pub outcome: String,
    /// How many times this experience has been reused as a reference.
    pub reuse_count: u32,
    /// Timestamp of initial creation.
    pub created_at: DateTime<Utc>,
    /// Timestamp of most recent reuse or update.
    pub updated_at: DateTime<Utc>,
}
|
||||
|
||||
impl Experience {
|
||||
/// Create a new experience with the given fields.
|
||||
pub fn new(
|
||||
agent_id: &str,
|
||||
pain_pattern: &str,
|
||||
context: &str,
|
||||
solution_steps: Vec<String>,
|
||||
outcome: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
agent_id: agent_id.to_string(),
|
||||
pain_pattern: pain_pattern.to_string(),
|
||||
context: context.to_string(),
|
||||
solution_steps,
|
||||
outcome: outcome.to_string(),
|
||||
reuse_count: 0,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Deterministic URI for this experience, keyed on a stable hash of the
|
||||
/// pain pattern so duplicate patterns overwrite the same entry.
|
||||
pub fn uri(&self) -> String {
|
||||
let hash = simple_hash(&self.pain_pattern);
|
||||
format!("agent://{}/experience/{}", self.agent_id, hash)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stable 8-hex-char hash using the 32-bit FNV-1a parameters
/// (offset basis 0x811c9dc5, prime 0x01000193). Good enough for
/// deduplication; collisions are acceptable because the full
/// `pain_pattern` is still stored.
fn simple_hash(s: &str) -> String {
    let digest = s
        .bytes()
        .fold(0x811c_9dc5_u32, |acc, byte| {
            (acc ^ u32::from(byte)).wrapping_mul(0x0100_0193)
        });
    format!("{digest:08x}")
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExperienceStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// CRUD wrapper that persists [`Experience`] records through [`VikingAdapter`].
///
/// The adapter is shared behind an `Arc` so the same storage backend can be
/// used by other components concurrently.
pub struct ExperienceStore {
    // Shared storage backend; provides the FTS5/TF-IDF retrieval described
    // in the module docs.
    viking: Arc<VikingAdapter>,
}
|
||||
|
||||
impl ExperienceStore {
|
||||
/// Create a new store backed by the given VikingAdapter.
|
||||
pub fn new(viking: Arc<VikingAdapter>) -> Self {
|
||||
Self { viking }
|
||||
}
|
||||
|
||||
/// Store (or overwrite) an experience. The URI is derived from
|
||||
/// `agent_id + pain_pattern`, ensuring one experience per pattern.
|
||||
pub async fn store_experience(&self, exp: &Experience) -> zclaw_types::Result<()> {
|
||||
let uri = exp.uri();
|
||||
let content = serde_json::to_string(exp)?;
|
||||
let mut keywords = vec![exp.pain_pattern.clone()];
|
||||
keywords.extend(exp.solution_steps.iter().take(3).cloned());
|
||||
|
||||
let entry = MemoryEntry {
|
||||
uri,
|
||||
memory_type: MemoryType::Experience,
|
||||
content,
|
||||
keywords,
|
||||
importance: 8,
|
||||
access_count: 0,
|
||||
created_at: exp.created_at,
|
||||
last_accessed: exp.updated_at,
|
||||
overview: Some(exp.pain_pattern.clone()),
|
||||
abstract_summary: Some(exp.outcome.clone()),
|
||||
};
|
||||
|
||||
self.viking.store(&entry).await?;
|
||||
debug!("[ExperienceStore] Stored experience {} for agent {}", exp.id, exp.agent_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find experiences whose pain pattern matches the given query.
|
||||
pub async fn find_by_pattern(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
pattern_query: &str,
|
||||
) -> zclaw_types::Result<Vec<Experience>> {
|
||||
let scope = format!("agent://{}/experience/", agent_id);
|
||||
let opts = FindOptions {
|
||||
scope: Some(scope),
|
||||
limit: Some(10),
|
||||
min_similarity: None,
|
||||
};
|
||||
let entries = self.viking.find(pattern_query, opts).await?;
|
||||
let mut results = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
match serde_json::from_str::<Experience>(&entry.content) {
|
||||
Ok(exp) => results.push(exp),
|
||||
Err(e) => warn!("[ExperienceStore] Failed to deserialize experience at {}: {}", entry.uri, e),
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Return all experiences for a given agent.
|
||||
pub async fn find_by_agent(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
) -> zclaw_types::Result<Vec<Experience>> {
|
||||
let prefix = format!("agent://{}/experience/", agent_id);
|
||||
let entries = self.viking.find_by_prefix(&prefix).await?;
|
||||
let mut results = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
match serde_json::from_str::<Experience>(&entry.content) {
|
||||
Ok(exp) => results.push(exp),
|
||||
Err(e) => warn!("[ExperienceStore] Failed to deserialize experience at {}: {}", entry.uri, e),
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Increment the reuse counter for an existing experience.
|
||||
/// On failure, logs a warning but does **not** propagate the error so
|
||||
/// callers are never blocked.
|
||||
pub async fn increment_reuse(&self, exp: &Experience) {
|
||||
let mut updated = exp.clone();
|
||||
updated.reuse_count += 1;
|
||||
updated.updated_at = Utc::now();
|
||||
if let Err(e) = self.store_experience(&updated).await {
|
||||
warn!("[ExperienceStore] Failed to increment reuse for {}: {}", exp.id, e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a single experience by its URI.
|
||||
pub async fn delete(&self, exp: &Experience) -> zclaw_types::Result<()> {
|
||||
let uri = exp.uri();
|
||||
self.viking.delete(&uri).await?;
|
||||
debug!("[ExperienceStore] Deleted experience {} for agent {}", exp.id, exp.agent_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a store backed by a fresh in-memory VikingAdapter.
    fn memory_store() -> ExperienceStore {
        ExperienceStore::new(Arc::new(VikingAdapter::in_memory()))
    }

    /// Shorthand builder for test experiences.
    fn exp(agent: &str, pattern: &str, ctx: &str, steps: &[&str], outcome: &str) -> Experience {
        let steps = steps.iter().map(|s| (*s).to_string()).collect();
        Experience::new(agent, pattern, ctx, steps, outcome)
    }

    #[test]
    fn test_experience_new() {
        let e = exp(
            "agent-1",
            "logistics export packaging",
            "export packaging rejected by customs",
            &["check regulations", "use approved materials"],
            "packaging passed customs",
        );
        assert!(!e.id.is_empty());
        assert_eq!(e.agent_id, "agent-1");
        assert_eq!(e.solution_steps.len(), 2);
        assert_eq!(e.reuse_count, 0);
    }

    #[test]
    fn test_uri_deterministic() {
        let first = exp("agent-1", "packaging issue", "ctx", &["step1"], "ok");
        // A distinct record with the same agent + pattern must map to the same URI.
        let mut second = first.clone();
        second.id = "different-id".to_string();
        assert_eq!(first.uri(), second.uri());
    }

    #[test]
    fn test_uri_differs_for_different_patterns() {
        let a = exp("agent-1", "packaging issue", "ctx", &["step1"], "ok");
        let b = exp("agent-1", "compliance gap", "ctx", &["step1"], "ok");
        assert_ne!(a.uri(), b.uri());
    }

    #[test]
    fn test_simple_hash_stability() {
        let first = simple_hash("hello world");
        assert_eq!(first, simple_hash("hello world"));
        assert_eq!(first.len(), 8);
    }

    #[tokio::test]
    async fn test_store_and_find_by_agent() {
        let store = memory_store();
        let record = exp(
            "agent-42",
            "export document errors",
            "recurring mistakes in export docs",
            &["use template", "auto-validate"],
            "no more errors",
        );
        store.store_experience(&record).await.unwrap();

        let found = store.find_by_agent("agent-42").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].pain_pattern, "export document errors");
        assert_eq!(found[0].solution_steps.len(), 2);
    }

    #[tokio::test]
    async fn test_store_overwrites_same_pattern() {
        let store = memory_store();
        store
            .store_experience(&exp("agent-1", "packaging", "v1", &["old step"], "ok"))
            .await
            .unwrap();
        // Same agent + pattern → same URI, so this write replaces the first.
        store
            .store_experience(&exp("agent-1", "packaging", "v2 updated", &["new step"], "better"))
            .await
            .unwrap();

        let found = store.find_by_agent("agent-1").await.unwrap();
        // Overwritten, not duplicated.
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].context, "v2 updated");
    }

    #[tokio::test]
    async fn test_find_by_pattern() {
        let store = memory_store();
        store
            .store_experience(&exp(
                "agent-1",
                "logistics packaging compliance",
                "export compliance issues",
                &["check regulations"],
                "passed audit",
            ))
            .await
            .unwrap();

        let found = store.find_by_pattern("agent-1", "packaging").await.unwrap();
        assert_eq!(found.len(), 1);
    }

    #[tokio::test]
    async fn test_increment_reuse() {
        let store = memory_store();
        let record = exp("agent-1", "packaging", "ctx", &["step"], "ok");
        store.store_experience(&record).await.unwrap();
        store.increment_reuse(&record).await;

        let found = store.find_by_agent("agent-1").await.unwrap();
        assert_eq!(found[0].reuse_count, 1);
    }

    #[tokio::test]
    async fn test_delete_experience() {
        let store = memory_store();
        let record = exp("agent-1", "packaging", "ctx", &["step"], "ok");
        store.store_experience(&record).await.unwrap();
        store.delete(&record).await.unwrap();

        assert!(store.find_by_agent("agent-1").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_find_by_agent_filters_other_agents() {
        let store = memory_store();
        store
            .store_experience(&exp("agent-a", "packaging", "ctx", &["s"], "ok"))
            .await
            .unwrap();
        store
            .store_experience(&exp("agent-b", "compliance", "ctx", &["s"], "ok"))
            .await
            .unwrap();

        let found = store.find_by_agent("agent-a").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].pain_pattern, "packaging");
    }
}
|
||||
@@ -64,6 +64,7 @@ pub mod viking_adapter;
|
||||
pub mod storage;
|
||||
pub mod retrieval;
|
||||
pub mod summarizer;
|
||||
pub mod experience_store;
|
||||
|
||||
// Re-export main types for convenience
|
||||
pub use types::{
|
||||
@@ -85,6 +86,7 @@ pub use injector::{InjectionFormat, PromptInjector};
|
||||
pub use tracker::{AgentMetadata, GrowthTracker, LearningEvent};
|
||||
pub use viking_adapter::{FindOptions, VikingAdapter, VikingLevel, VikingStorage};
|
||||
pub use storage::SqliteStorage;
|
||||
pub use experience_store::{Experience, ExperienceStore};
|
||||
pub use retrieval::{EmbeddingClient, MemoryCache, QueryAnalyzer, SemanticScorer};
|
||||
pub use summarizer::SummaryLlmDriver;
|
||||
|
||||
|
||||
@@ -6,8 +6,15 @@ mod store;
|
||||
mod session;
|
||||
mod schema;
|
||||
pub mod fact;
|
||||
pub mod user_profile_store;
|
||||
pub mod trajectory_store;
|
||||
|
||||
pub use store::*;
|
||||
pub use session::*;
|
||||
pub use schema::*;
|
||||
pub use fact::{Fact, FactCategory, ExtractedFactBatch};
|
||||
pub use user_profile_store::{UserProfileStore, UserProfile, Level, CommStyle};
|
||||
pub use trajectory_store::{
|
||||
TrajectoryEvent, TrajectoryStore, TrajectoryStepType,
|
||||
CompressedTrajectory, CompletionStatus, SatisfactionSignal,
|
||||
};
|
||||
|
||||
@@ -93,4 +93,47 @@ pub const MIGRATIONS: &[&str] = &[
|
||||
// v1→v2: persist runtime state and message count
|
||||
"ALTER TABLE agents ADD COLUMN state TEXT NOT NULL DEFAULT 'running'",
|
||||
"ALTER TABLE agents ADD COLUMN message_count INTEGER NOT NULL DEFAULT 0",
|
||||
// v2→v3: user profiles for structured user modeling
|
||||
"CREATE TABLE IF NOT EXISTS user_profiles (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
industry TEXT,
|
||||
role TEXT,
|
||||
expertise_level TEXT,
|
||||
communication_style TEXT,
|
||||
preferred_language TEXT DEFAULT 'zh-CN',
|
||||
recent_topics TEXT DEFAULT '[]',
|
||||
active_pain_points TEXT DEFAULT '[]',
|
||||
preferred_tools TEXT DEFAULT '[]',
|
||||
confidence REAL DEFAULT 0.0,
|
||||
updated_at TEXT NOT NULL
|
||||
)",
|
||||
// v3→v4: trajectory recording for tool-call chain analysis
|
||||
"CREATE TABLE IF NOT EXISTS trajectory_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
step_index INTEGER NOT NULL,
|
||||
step_type TEXT NOT NULL,
|
||||
input_summary TEXT,
|
||||
output_summary TEXT,
|
||||
duration_ms INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_trajectory_session ON trajectory_events(session_id)",
|
||||
"CREATE TABLE IF NOT EXISTS compressed_trajectories (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
request_type TEXT NOT NULL,
|
||||
tools_used TEXT,
|
||||
outcome TEXT NOT NULL,
|
||||
total_steps INTEGER DEFAULT 0,
|
||||
total_duration_ms INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
execution_chain TEXT NOT NULL,
|
||||
satisfaction_signal TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_ct_request_type ON compressed_trajectories(request_type)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_ct_outcome ON compressed_trajectories(outcome)",
|
||||
];
|
||||
|
||||
563
crates/zclaw-memory/src/trajectory_store.rs
Normal file
563
crates/zclaw-memory/src/trajectory_store.rs
Normal file
@@ -0,0 +1,563 @@
|
||||
//! Trajectory Store -- record and compress tool-call chains for analysis.
|
||||
//!
|
||||
//! Stores raw trajectory events (user requests, tool calls, LLM generations)
|
||||
//! and compressed trajectory summaries. Used by the Hermes Intelligence Pipeline
|
||||
//! to analyze agent behaviour patterns and improve routing over time.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::SqlitePool;
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Step type in a trajectory.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TrajectoryStepType {
|
||||
UserRequest,
|
||||
IntentClassification,
|
||||
SkillSelection,
|
||||
ToolExecution,
|
||||
LlmGeneration,
|
||||
UserFeedback,
|
||||
}
|
||||
|
||||
impl TrajectoryStepType {
|
||||
/// Serialize to the string stored in SQLite.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::UserRequest => "user_request",
|
||||
Self::IntentClassification => "intent_classification",
|
||||
Self::SkillSelection => "skill_selection",
|
||||
Self::ToolExecution => "tool_execution",
|
||||
Self::LlmGeneration => "llm_generation",
|
||||
Self::UserFeedback => "user_feedback",
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize from the SQLite string representation.
|
||||
pub fn from_str_lossy(s: &str) -> Self {
|
||||
match s {
|
||||
"user_request" => Self::UserRequest,
|
||||
"intent_classification" => Self::IntentClassification,
|
||||
"skill_selection" => Self::SkillSelection,
|
||||
"tool_execution" => Self::ToolExecution,
|
||||
"llm_generation" => Self::LlmGeneration,
|
||||
"user_feedback" => Self::UserFeedback,
|
||||
_ => Self::UserRequest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Single trajectory event.
///
/// One row in the `trajectory_events` table: a step in an agent session,
/// ordered by `step_index` within a `session_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryEvent {
    // Unique event id (primary key in SQLite).
    pub id: String,
    // Session this step belongs to.
    pub session_id: String,
    // Agent that produced the step.
    pub agent_id: String,
    // Position of this step within the session (ordering key for retrieval).
    pub step_index: usize,
    // Kind of step (tool call, LLM generation, feedback, ...).
    pub step_type: TrajectoryStepType,
    /// Summarised input (max 200 chars).
    // NOTE(review): the 200-char cap is not enforced here — presumably
    // truncated by the producer; confirm at the recording site.
    pub input_summary: String,
    /// Summarised output (max 200 chars).
    // NOTE(review): same as above — cap not validated in this module.
    pub output_summary: String,
    // Wall-clock duration of the step in milliseconds.
    pub duration_ms: u64,
    // When the step occurred (persisted as RFC 3339 text in SQLite).
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Satisfaction signal inferred from user feedback.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SatisfactionSignal {
|
||||
Positive,
|
||||
Negative,
|
||||
Neutral,
|
||||
}
|
||||
|
||||
impl SatisfactionSignal {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Positive => "positive",
|
||||
Self::Negative => "negative",
|
||||
Self::Neutral => "neutral",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str_lossy(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"positive" => Some(Self::Positive),
|
||||
"negative" => Some(Self::Negative),
|
||||
"neutral" => Some(Self::Neutral),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Completion status of a compressed trajectory.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum CompletionStatus {
|
||||
Success,
|
||||
Partial,
|
||||
Failed,
|
||||
Abandoned,
|
||||
}
|
||||
|
||||
impl CompletionStatus {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Success => "success",
|
||||
Self::Partial => "partial",
|
||||
Self::Failed => "failed",
|
||||
Self::Abandoned => "abandoned",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str_lossy(s: &str) -> Self {
|
||||
match s {
|
||||
"success" => Self::Success,
|
||||
"partial" => Self::Partial,
|
||||
"failed" => Self::Failed,
|
||||
"abandoned" => Self::Abandoned,
|
||||
_ => Self::Success,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compressed trajectory (generated at session end).
///
/// One row in the `compressed_trajectories` table, summarising a whole
/// session's raw [`TrajectoryEvent`] stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedTrajectory {
    // Primary key in SQLite.
    pub id: String,
    // Session this summary was compressed from.
    pub session_id: String,
    // Agent that ran the session.
    pub agent_id: String,
    // Classified request category (free-form string, indexed for queries).
    pub request_type: String,
    // Tool names used during the session; persisted as a JSON array string.
    pub tools_used: Vec<String>,
    // How the session ended.
    pub outcome: CompletionStatus,
    // Number of raw steps in the session.
    pub total_steps: usize,
    // Summed step duration in milliseconds.
    pub total_duration_ms: u64,
    // Token usage across the session.
    pub total_tokens: u32,
    /// Serialised JSON execution chain for analysis.
    pub execution_chain: String,
    // Inferred user satisfaction, if any feedback was observed
    // (stored as a nullable TEXT column).
    pub satisfaction_signal: Option<SatisfactionSignal>,
    // Creation time (persisted as RFC 3339 text in SQLite).
    pub created_at: DateTime<Utc>,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Store
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Persistent store for trajectory events and compressed trajectories.
///
/// Tables are created by `initialize_schema`, which is safe to call on
/// every startup.
pub struct TrajectoryStore {
    // SQLite connection pool all queries run against.
    pool: SqlitePool,
}
|
||||
|
||||
impl TrajectoryStore {
|
||||
/// Create a new `TrajectoryStore` backed by the given SQLite pool.
|
||||
pub fn new(pool: SqlitePool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
/// Create the required tables. Idempotent -- safe to call on startup.
|
||||
pub async fn initialize_schema(&self) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS trajectory_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
step_index INTEGER NOT NULL,
|
||||
step_type TEXT NOT NULL,
|
||||
input_summary TEXT,
|
||||
output_summary TEXT,
|
||||
duration_ms INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_trajectory_session ON trajectory_events(session_id);
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS compressed_trajectories (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
request_type TEXT NOT NULL,
|
||||
tools_used TEXT,
|
||||
outcome TEXT NOT NULL,
|
||||
total_steps INTEGER DEFAULT 0,
|
||||
total_duration_ms INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
execution_chain TEXT NOT NULL,
|
||||
satisfaction_signal TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_ct_request_type ON compressed_trajectories(request_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_ct_outcome ON compressed_trajectories(outcome);
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Insert a raw trajectory event.
|
||||
pub async fn insert_event(&self, event: &TrajectoryEvent) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO trajectory_events
|
||||
(id, session_id, agent_id, step_index, step_type,
|
||||
input_summary, output_summary, duration_ms, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&event.id)
|
||||
.bind(&event.session_id)
|
||||
.bind(&event.agent_id)
|
||||
.bind(event.step_index as i64)
|
||||
.bind(event.step_type.as_str())
|
||||
.bind(&event.input_summary)
|
||||
.bind(&event.output_summary)
|
||||
.bind(event.duration_ms as i64)
|
||||
.bind(event.timestamp.to_rfc3339())
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("[TrajectoryStore] insert_event failed: {}", e);
|
||||
ZclawError::StorageError(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve all raw events for a session, ordered by step_index.
|
||||
pub async fn get_events_by_session(&self, session_id: &str) -> Result<Vec<TrajectoryEvent>> {
|
||||
let rows = sqlx::query_as::<_, (String, String, String, i64, String, Option<String>, Option<String>, Option<i64>, String)>(
|
||||
r#"
|
||||
SELECT id, session_id, agent_id, step_index, step_type,
|
||||
input_summary, output_summary, duration_ms, timestamp
|
||||
FROM trajectory_events
|
||||
WHERE session_id = ?
|
||||
ORDER BY step_index ASC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
let mut events = Vec::with_capacity(rows.len());
|
||||
for (id, sid, aid, step_idx, stype, input_s, output_s, dur_ms, ts) in rows {
|
||||
let timestamp = DateTime::parse_from_rfc3339(&ts)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now());
|
||||
|
||||
events.push(TrajectoryEvent {
|
||||
id,
|
||||
session_id: sid,
|
||||
agent_id: aid,
|
||||
step_index: step_idx as usize,
|
||||
step_type: TrajectoryStepType::from_str_lossy(&stype),
|
||||
input_summary: input_s.unwrap_or_default(),
|
||||
output_summary: output_s.unwrap_or_default(),
|
||||
duration_ms: dur_ms.unwrap_or(0) as u64,
|
||||
timestamp,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Insert a compressed trajectory.
|
||||
pub async fn insert_compressed(&self, trajectory: &CompressedTrajectory) -> Result<()> {
|
||||
let tools_json = serde_json::to_string(&trajectory.tools_used)
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO compressed_trajectories
|
||||
(id, session_id, agent_id, request_type, tools_used,
|
||||
outcome, total_steps, total_duration_ms, total_tokens,
|
||||
execution_chain, satisfaction_signal, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&trajectory.id)
|
||||
.bind(&trajectory.session_id)
|
||||
.bind(&trajectory.agent_id)
|
||||
.bind(&trajectory.request_type)
|
||||
.bind(&tools_json)
|
||||
.bind(trajectory.outcome.as_str())
|
||||
.bind(trajectory.total_steps as i64)
|
||||
.bind(trajectory.total_duration_ms as i64)
|
||||
.bind(trajectory.total_tokens as i64)
|
||||
.bind(&trajectory.execution_chain)
|
||||
.bind(trajectory.satisfaction_signal.map(|s| s.as_str()))
|
||||
.bind(trajectory.created_at.to_rfc3339())
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("[TrajectoryStore] insert_compressed failed: {}", e);
|
||||
ZclawError::StorageError(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve the compressed trajectory for a session, if any.
|
||||
pub async fn get_compressed_by_session(&self, session_id: &str) -> Result<Option<CompressedTrajectory>> {
|
||||
let row = sqlx::query_as::<_, (
|
||||
String, String, String, String, Option<String>,
|
||||
String, i64, i64, i64, String, Option<String>, String,
|
||||
)>(
|
||||
r#"
|
||||
SELECT id, session_id, agent_id, request_type, tools_used,
|
||||
outcome, total_steps, total_duration_ms, total_tokens,
|
||||
execution_chain, satisfaction_signal, created_at
|
||||
FROM compressed_trajectories
|
||||
WHERE session_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
match row {
|
||||
Some((id, sid, aid, req_type, tools_json, outcome_str, steps, dur_ms, tokens, chain, sat, created)) => {
|
||||
let tools_used: Vec<String> = tools_json
|
||||
.as_deref()
|
||||
.and_then(|j| serde_json::from_str(j).ok())
|
||||
.unwrap_or_default();
|
||||
|
||||
let timestamp = DateTime::parse_from_rfc3339(&created)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now());
|
||||
|
||||
Ok(Some(CompressedTrajectory {
|
||||
id,
|
||||
session_id: sid,
|
||||
agent_id: aid,
|
||||
request_type: req_type,
|
||||
tools_used,
|
||||
outcome: CompletionStatus::from_str_lossy(&outcome_str),
|
||||
total_steps: steps as usize,
|
||||
total_duration_ms: dur_ms as u64,
|
||||
total_tokens: tokens as u32,
|
||||
execution_chain: chain,
|
||||
satisfaction_signal: sat.as_deref().and_then(SatisfactionSignal::from_str_lossy),
|
||||
created_at: timestamp,
|
||||
}))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete raw trajectory events older than `days` days. Returns count deleted.
|
||||
pub async fn delete_events_older_than(&self, days: i64) -> Result<u64> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM trajectory_events
|
||||
WHERE timestamp < datetime('now', ?)
|
||||
"#,
|
||||
)
|
||||
.bind(format!("-{} days", days))
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("[TrajectoryStore] delete_events_older_than failed: {}", e);
|
||||
ZclawError::StorageError(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Delete compressed trajectories older than `days` days. Returns count deleted.
|
||||
pub async fn delete_compressed_older_than(&self, days: i64) -> Result<u64> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM compressed_trajectories
|
||||
WHERE created_at < datetime('now', ?)
|
||||
"#,
|
||||
)
|
||||
.bind(format!("-{} days", days))
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("[TrajectoryStore] delete_compressed_older_than failed: {}", e);
|
||||
ZclawError::StorageError(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build a `TrajectoryStore` over a fresh in-memory SQLite
    /// database with the schema already applied.
    async fn test_store() -> TrajectoryStore {
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("in-memory pool");
        let store = TrajectoryStore::new(pool);
        store.initialize_schema().await.expect("schema init");
        store
    }

    /// Canonical ToolExecution event for session "sess-1"; `index` drives
    /// both the id suffix and `step_index` so tests can mint ordered events.
    fn sample_event(index: usize) -> TrajectoryEvent {
        TrajectoryEvent {
            id: format!("evt-{}", index),
            session_id: "sess-1".to_string(),
            agent_id: "agent-1".to_string(),
            step_index: index,
            step_type: TrajectoryStepType::ToolExecution,
            input_summary: "search query".to_string(),
            output_summary: "3 results found".to_string(),
            duration_ms: 150,
            timestamp: Utc::now(),
        }
    }

    /// Events round-trip through insert and are returned ordered by
    /// step_index, with their step types intact.
    #[tokio::test]
    async fn test_insert_and_get_events() {
        let store = test_store().await;

        let e1 = sample_event(0);
        let e2 = TrajectoryEvent {
            id: "evt-1".to_string(),
            step_index: 1,
            step_type: TrajectoryStepType::LlmGeneration,
            ..sample_event(0)
        };

        store.insert_event(&e1).await.unwrap();
        store.insert_event(&e2).await.unwrap();

        let events = store.get_events_by_session("sess-1").await.unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].step_index, 0);
        assert_eq!(events[1].step_index, 1);
        assert_eq!(events[0].step_type, TrajectoryStepType::ToolExecution);
        assert_eq!(events[1].step_type, TrajectoryStepType::LlmGeneration);
    }

    /// An unknown session yields an empty vec, not an error.
    #[tokio::test]
    async fn test_get_events_empty_session() {
        let store = test_store().await;
        let events = store.get_events_by_session("nonexistent").await.unwrap();
        assert!(events.is_empty());
    }

    /// A compressed trajectory round-trips through insert/get, including the
    /// JSON-encoded `tools_used` list and the enum-valued columns.
    #[tokio::test]
    async fn test_insert_and_get_compressed() {
        let store = test_store().await;

        let ct = CompressedTrajectory {
            id: "ct-1".to_string(),
            session_id: "sess-1".to_string(),
            agent_id: "agent-1".to_string(),
            request_type: "data_query".to_string(),
            tools_used: vec!["search".to_string(), "calculate".to_string()],
            outcome: CompletionStatus::Success,
            total_steps: 5,
            total_duration_ms: 1200,
            total_tokens: 350,
            execution_chain: r#"[{"step":0,"type":"tool_execution"}]"#.to_string(),
            satisfaction_signal: Some(SatisfactionSignal::Positive),
            created_at: Utc::now(),
        };

        store.insert_compressed(&ct).await.unwrap();

        let loaded = store.get_compressed_by_session("sess-1").await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, "ct-1");
        assert_eq!(loaded.request_type, "data_query");
        assert_eq!(loaded.tools_used.len(), 2);
        assert_eq!(loaded.outcome, CompletionStatus::Success);
        assert_eq!(loaded.satisfaction_signal, Some(SatisfactionSignal::Positive));
    }

    /// Missing compressed trajectory maps to Ok(None), not an error.
    #[tokio::test]
    async fn test_get_compressed_nonexistent() {
        let store = test_store().await;
        let result = store.get_compressed_by_session("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    /// Every step-type variant survives as_str → from_str_lossy.
    #[tokio::test]
    async fn test_step_type_roundtrip() {
        let all_types = [
            TrajectoryStepType::UserRequest,
            TrajectoryStepType::IntentClassification,
            TrajectoryStepType::SkillSelection,
            TrajectoryStepType::ToolExecution,
            TrajectoryStepType::LlmGeneration,
            TrajectoryStepType::UserFeedback,
        ];

        for st in all_types {
            assert_eq!(TrajectoryStepType::from_str_lossy(st.as_str()), st);
        }
    }

    /// Satisfaction signals round-trip; unknown strings map to None.
    #[tokio::test]
    async fn test_satisfaction_signal_roundtrip() {
        let signals = [SatisfactionSignal::Positive, SatisfactionSignal::Negative, SatisfactionSignal::Neutral];
        for sig in signals {
            assert_eq!(SatisfactionSignal::from_str_lossy(sig.as_str()), Some(sig));
        }
        assert_eq!(SatisfactionSignal::from_str_lossy("bogus"), None);
    }

    /// Every completion-status variant survives as_str → from_str_lossy.
    #[tokio::test]
    async fn test_completion_status_roundtrip() {
        let statuses = [CompletionStatus::Success, CompletionStatus::Partial, CompletionStatus::Failed, CompletionStatus::Abandoned];
        for s in statuses {
            assert_eq!(CompletionStatus::from_str_lossy(s.as_str()), s);
        }
    }

    /// Retention deletes only events past the cutoff and reports the count.
    #[tokio::test]
    async fn test_delete_events_older_than() {
        let store = test_store().await;

        // Insert an event with a timestamp far in the past
        let old_event = TrajectoryEvent {
            id: "old-evt".to_string(),
            timestamp: Utc::now() - chrono::Duration::days(100),
            ..sample_event(0)
        };
        store.insert_event(&old_event).await.unwrap();

        // Insert a recent event
        let recent_event = TrajectoryEvent {
            id: "recent-evt".to_string(),
            step_index: 1,
            ..sample_event(0)
        };
        store.insert_event(&recent_event).await.unwrap();

        let deleted = store.delete_events_older_than(30).await.unwrap();
        assert_eq!(deleted, 1);

        let remaining = store.get_events_by_session("sess-1").await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "recent-evt");
    }
}
|
||||
592
crates/zclaw-memory/src/user_profile_store.rs
Normal file
592
crates/zclaw-memory/src/user_profile_store.rs
Normal file
@@ -0,0 +1,592 @@
|
||||
//! User Profile Store — structured user modeling from conversation patterns.
|
||||
//!
|
||||
//! Maintains a single `UserProfile` per user (desktop uses "default_user")
|
||||
//! in a dedicated SQLite table. Vec fields (recent_topics, pain points,
|
||||
//! preferred_tools) are stored as JSON arrays and transparently
|
||||
//! (de)serialised on read/write.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Row;
|
||||
use sqlx::SqlitePool;
|
||||
use zclaw_types::Result;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Expertise level inferred from conversation patterns.
///
/// Persisted in SQLite as lowercase TEXT (see `as_str`) and serialised with
/// `rename_all = "lowercase"` so JSON and DB representations agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Level {
    /// Needs guided, step-by-step help.
    Beginner,
    /// Comfortable with common workflows.
    Intermediate,
    /// Assumes deep familiarity; terse answers preferred.
    Expert,
}
|
||||
|
||||
impl Level {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Level::Beginner => "beginner",
|
||||
Level::Intermediate => "intermediate",
|
||||
Level::Expert => "expert",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str_lossy(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"beginner" => Some(Level::Beginner),
|
||||
"intermediate" => Some(Level::Intermediate),
|
||||
"expert" => Some(Level::Expert),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Communication style preference.
///
/// Persisted in SQLite as lowercase TEXT (see `as_str`) and serialised with
/// `rename_all = "lowercase"` so JSON and DB representations agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CommStyle {
    /// Short, to-the-point answers.
    Concise,
    /// Thorough explanations with context.
    Detailed,
    /// Professional register.
    Formal,
    /// Conversational register.
    Casual,
}
|
||||
|
||||
impl CommStyle {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
CommStyle::Concise => "concise",
|
||||
CommStyle::Detailed => "detailed",
|
||||
CommStyle::Formal => "formal",
|
||||
CommStyle::Casual => "casual",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str_lossy(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"concise" => Some(CommStyle::Concise),
|
||||
"detailed" => Some(CommStyle::Detailed),
|
||||
"formal" => Some(CommStyle::Formal),
|
||||
"casual" => Some(CommStyle::Casual),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Structured user profile (one record per user).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
    /// Primary key. Desktop single-user mode uses "default_user".
    pub user_id: String,
    /// Inferred industry (free text); `None` until learned.
    pub industry: Option<String>,
    /// Inferred job role (free text); `None` until learned.
    pub role: Option<String>,
    /// Inferred expertise; `None` until learned.
    pub expertise_level: Option<Level>,
    /// Preferred communication style; `None` until learned.
    pub communication_style: Option<CommStyle>,
    /// BCP-47-style tag; defaults to "zh-CN" (see `blank`).
    pub preferred_language: String,
    /// Most-recent-first topic list; persisted as a JSON array.
    pub recent_topics: Vec<String>,
    /// Most-recent-first pain points; persisted as a JSON array.
    pub active_pain_points: Vec<String>,
    /// Most-recent-first tool names; persisted as a JSON array.
    pub preferred_tools: Vec<String>,
    /// Overall confidence in the inferred fields, 0.0-1.0.
    pub confidence: f32,
    /// Last modification time; stored as RFC 3339 TEXT.
    pub updated_at: DateTime<Utc>,
}
|
||||
|
||||
impl UserProfile {
|
||||
/// Create a blank profile for the given user.
|
||||
pub fn blank(user_id: &str) -> Self {
|
||||
Self {
|
||||
user_id: user_id.to_string(),
|
||||
industry: None,
|
||||
role: None,
|
||||
expertise_level: None,
|
||||
communication_style: None,
|
||||
preferred_language: "zh-CN".to_string(),
|
||||
recent_topics: Vec::new(),
|
||||
active_pain_points: Vec::new(),
|
||||
preferred_tools: Vec::new(),
|
||||
confidence: 0.0,
|
||||
updated_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Default profile for single-user desktop mode ("default_user").
|
||||
pub fn default_profile() -> Self {
|
||||
Self::blank("default_user")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DDL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const PROFILE_DDL: &str = r#"
|
||||
CREATE TABLE IF NOT EXISTS user_profiles (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
industry TEXT,
|
||||
role TEXT,
|
||||
expertise_level TEXT,
|
||||
communication_style TEXT,
|
||||
preferred_language TEXT DEFAULT 'zh-CN',
|
||||
recent_topics TEXT DEFAULT '[]',
|
||||
active_pain_points TEXT DEFAULT '[]',
|
||||
preferred_tools TEXT DEFAULT '[]',
|
||||
confidence REAL DEFAULT 0.0,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Row mapping
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn row_to_profile(row: &sqlx::sqlite::SqliteRow) -> Result<UserProfile> {
|
||||
let recent_topics_json: String = row.try_get("recent_topics").unwrap_or_else(|_| "[]".to_string());
|
||||
let pain_json: String = row.try_get("active_pain_points").unwrap_or_else(|_| "[]".to_string());
|
||||
let tools_json: String = row.try_get("preferred_tools").unwrap_or_else(|_| "[]".to_string());
|
||||
|
||||
let recent_topics: Vec<String> = serde_json::from_str(&recent_topics_json)?;
|
||||
let active_pain_points: Vec<String> = serde_json::from_str(&pain_json)?;
|
||||
let preferred_tools: Vec<String> = serde_json::from_str(&tools_json)?;
|
||||
|
||||
let expertise_str: Option<String> = row.try_get("expertise_level").unwrap_or(None);
|
||||
let comm_str: Option<String> = row.try_get("communication_style").unwrap_or(None);
|
||||
|
||||
let updated_at_str: String = row.try_get("updated_at").unwrap_or_else(|_| Utc::now().to_rfc3339());
|
||||
let updated_at = DateTime::parse_from_rfc3339(&updated_at_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now());
|
||||
|
||||
Ok(UserProfile {
|
||||
user_id: row.try_get("user_id").unwrap_or_default(),
|
||||
industry: row.try_get("industry").unwrap_or(None),
|
||||
role: row.try_get("role").unwrap_or(None),
|
||||
expertise_level: expertise_str.as_deref().and_then(Level::from_str_lossy),
|
||||
communication_style: comm_str.as_deref().and_then(CommStyle::from_str_lossy),
|
||||
preferred_language: row.try_get("preferred_language").unwrap_or_else(|_| "zh-CN".to_string()),
|
||||
recent_topics,
|
||||
active_pain_points,
|
||||
preferred_tools,
|
||||
confidence: row.try_get("confidence").unwrap_or(0.0),
|
||||
updated_at,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// UserProfileStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// SQLite-backed store for user profiles.
///
/// Thin CRUD layer over the `user_profiles` table; cheap to construct and
/// clonable work is delegated to the shared `SqlitePool`.
pub struct UserProfileStore {
    // Shared connection pool; all queries borrow it.
    pool: SqlitePool,
}
|
||||
|
||||
impl UserProfileStore {
|
||||
/// Create a new store backed by the given connection pool.
|
||||
pub fn new(pool: SqlitePool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
/// Create tables. Idempotent — safe to call on every startup.
|
||||
pub async fn initialize_schema(&self) -> Result<()> {
|
||||
sqlx::query(PROFILE_DDL)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch the profile for a user. Returns `None` when no row exists.
|
||||
pub async fn get(&self, user_id: &str) -> Result<Option<UserProfile>> {
|
||||
let row = sqlx::query(
|
||||
"SELECT user_id, industry, role, expertise_level, communication_style, \
|
||||
preferred_language, recent_topics, active_pain_points, preferred_tools, \
|
||||
confidence, updated_at \
|
||||
FROM user_profiles WHERE user_id = ?",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
match row {
|
||||
Some(r) => Ok(Some(row_to_profile(&r)?)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert or replace the full profile.
|
||||
pub async fn upsert(&self, profile: &UserProfile) -> Result<()> {
|
||||
let topics = serde_json::to_string(&profile.recent_topics)?;
|
||||
let pains = serde_json::to_string(&profile.active_pain_points)?;
|
||||
let tools = serde_json::to_string(&profile.preferred_tools)?;
|
||||
let expertise = profile.expertise_level.map(|l| l.as_str());
|
||||
let comm = profile.communication_style.map(|c| c.as_str());
|
||||
let updated = profile.updated_at.to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT OR REPLACE INTO user_profiles \
|
||||
(user_id, industry, role, expertise_level, communication_style, \
|
||||
preferred_language, recent_topics, active_pain_points, preferred_tools, \
|
||||
confidence, updated_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
)
|
||||
.bind(&profile.user_id)
|
||||
.bind(&profile.industry)
|
||||
.bind(&profile.role)
|
||||
.bind(expertise)
|
||||
.bind(comm)
|
||||
.bind(&profile.preferred_language)
|
||||
.bind(&topics)
|
||||
.bind(&pains)
|
||||
.bind(&tools)
|
||||
.bind(profile.confidence)
|
||||
.bind(&updated)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update a single scalar field by name.
|
||||
///
|
||||
/// `field` must be one of: industry, role, expertise_level,
|
||||
/// communication_style, preferred_language, confidence.
|
||||
/// Returns error for unrecognised field names (prevents SQL injection).
|
||||
pub async fn update_field(&self, user_id: &str, field: &str, value: &str) -> Result<()> {
|
||||
let sql = match field {
|
||||
"industry" => "UPDATE user_profiles SET industry = ?, updated_at = ? WHERE user_id = ?",
|
||||
"role" => "UPDATE user_profiles SET role = ?, updated_at = ? WHERE user_id = ?",
|
||||
"expertise_level" => {
|
||||
"UPDATE user_profiles SET expertise_level = ?, updated_at = ? WHERE user_id = ?"
|
||||
}
|
||||
"communication_style" => {
|
||||
"UPDATE user_profiles SET communication_style = ?, updated_at = ? WHERE user_id = ?"
|
||||
}
|
||||
"preferred_language" => {
|
||||
"UPDATE user_profiles SET preferred_language = ?, updated_at = ? WHERE user_id = ?"
|
||||
}
|
||||
"confidence" => {
|
||||
"UPDATE user_profiles SET confidence = ?, updated_at = ? WHERE user_id = ?"
|
||||
}
|
||||
_ => {
|
||||
return Err(zclaw_types::ZclawError::InvalidInput(format!(
|
||||
"Unknown profile field: {}",
|
||||
field
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
// confidence is REAL; parse the value string.
|
||||
if field == "confidence" {
|
||||
let f: f32 = value.parse().map_err(|_| {
|
||||
zclaw_types::ZclawError::InvalidInput(format!("Invalid confidence: {}", value))
|
||||
})?;
|
||||
sqlx::query(sql)
|
||||
.bind(f)
|
||||
.bind(&now)
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
|
||||
} else {
|
||||
sqlx::query(sql)
|
||||
.bind(value)
|
||||
.bind(&now)
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append a topic to `recent_topics`, trimming to `max_topics`.
|
||||
/// Creates a default profile row if none exists.
|
||||
pub async fn add_recent_topic(
|
||||
&self,
|
||||
user_id: &str,
|
||||
topic: &str,
|
||||
max_topics: usize,
|
||||
) -> Result<()> {
|
||||
let mut profile = self
|
||||
.get(user_id)
|
||||
.await?
|
||||
.unwrap_or_else(|| UserProfile::blank(user_id));
|
||||
|
||||
// Deduplicate: remove if already present, then push to front.
|
||||
profile.recent_topics.retain(|t| t != topic);
|
||||
profile.recent_topics.insert(0, topic.to_string());
|
||||
profile.recent_topics.truncate(max_topics);
|
||||
profile.updated_at = Utc::now();
|
||||
|
||||
self.upsert(&profile).await
|
||||
}
|
||||
|
||||
/// Append a pain point, trimming to `max_pains`.
|
||||
/// Creates a default profile row if none exists.
|
||||
pub async fn add_pain_point(
|
||||
&self,
|
||||
user_id: &str,
|
||||
pain: &str,
|
||||
max_pains: usize,
|
||||
) -> Result<()> {
|
||||
let mut profile = self
|
||||
.get(user_id)
|
||||
.await?
|
||||
.unwrap_or_else(|| UserProfile::blank(user_id));
|
||||
|
||||
profile.active_pain_points.retain(|p| p != pain);
|
||||
profile.active_pain_points.insert(0, pain.to_string());
|
||||
profile.active_pain_points.truncate(max_pains);
|
||||
profile.updated_at = Utc::now();
|
||||
|
||||
self.upsert(&profile).await
|
||||
}
|
||||
|
||||
/// Append a preferred tool, trimming to `max_tools`.
|
||||
/// Creates a default profile row if none exists.
|
||||
pub async fn add_preferred_tool(
|
||||
&self,
|
||||
user_id: &str,
|
||||
tool: &str,
|
||||
max_tools: usize,
|
||||
) -> Result<()> {
|
||||
let mut profile = self
|
||||
.get(user_id)
|
||||
.await?
|
||||
.unwrap_or_else(|| UserProfile::blank(user_id));
|
||||
|
||||
profile.preferred_tools.retain(|t| t != tool);
|
||||
profile.preferred_tools.insert(0, tool.to_string());
|
||||
profile.preferred_tools.truncate(max_tools);
|
||||
profile.updated_at = Utc::now();
|
||||
|
||||
self.upsert(&profile).await
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: create an in-memory store with schema.
    async fn test_store() -> UserProfileStore {
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("in-memory pool");
        let store = UserProfileStore::new(pool);
        store.initialize_schema().await.expect("schema init");
        store
    }

    /// `CREATE TABLE IF NOT EXISTS` must make repeated schema init a no-op.
    #[tokio::test]
    async fn test_initialize_schema_idempotent() {
        let store = test_store().await;
        // Second call should succeed without error.
        store.initialize_schema().await.unwrap();
        store.initialize_schema().await.unwrap();
    }

    /// Missing user maps to Ok(None), not an error.
    #[tokio::test]
    async fn test_get_returns_none_for_missing() {
        let store = test_store().await;
        let profile = store.get("nonexistent").await.unwrap();
        assert!(profile.is_none());
    }

    /// A fully-populated profile round-trips through upsert/get, including
    /// the JSON list columns and both enum columns.
    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = test_store().await;
        let mut profile = UserProfile::blank("default_user");
        profile.industry = Some("healthcare".to_string());
        profile.role = Some("admin".to_string());
        profile.expertise_level = Some(Level::Intermediate);
        profile.communication_style = Some(CommStyle::Concise);
        profile.recent_topics = vec!["reporting".to_string(), "compliance".to_string()];
        profile.confidence = 0.65;

        store.upsert(&profile).await.unwrap();

        let loaded = store.get("default_user").await.unwrap().unwrap();
        assert_eq!(loaded.user_id, "default_user");
        assert_eq!(loaded.industry.as_deref(), Some("healthcare"));
        assert_eq!(loaded.role.as_deref(), Some("admin"));
        assert_eq!(loaded.expertise_level, Some(Level::Intermediate));
        assert_eq!(loaded.communication_style, Some(CommStyle::Concise));
        assert_eq!(loaded.recent_topics, vec!["reporting", "compliance"]);
        assert!((loaded.confidence - 0.65).abs() < f32::EPSILON);
    }

    /// INSERT OR REPLACE semantics: a second upsert overwrites the row.
    #[tokio::test]
    async fn test_upsert_replaces_existing() {
        let store = test_store().await;
        let mut profile = UserProfile::blank("user1");
        profile.industry = Some("tech".to_string());
        store.upsert(&profile).await.unwrap();

        profile.industry = Some("finance".to_string());
        store.upsert(&profile).await.unwrap();

        let loaded = store.get("user1").await.unwrap().unwrap();
        assert_eq!(loaded.industry.as_deref(), Some("finance"));
    }

    /// Whitelisted TEXT columns can be updated individually.
    #[tokio::test]
    async fn test_update_field_scalar() {
        let store = test_store().await;
        let profile = UserProfile::blank("user2");
        store.upsert(&profile).await.unwrap();

        store
            .update_field("user2", "industry", "education")
            .await
            .unwrap();
        store
            .update_field("user2", "role", "teacher")
            .await
            .unwrap();

        let loaded = store.get("user2").await.unwrap().unwrap();
        assert_eq!(loaded.industry.as_deref(), Some("education"));
        assert_eq!(loaded.role.as_deref(), Some("teacher"));
    }

    /// confidence takes the REAL code path: the string value is parsed to f32.
    #[tokio::test]
    async fn test_update_field_confidence() {
        let store = test_store().await;
        let profile = UserProfile::blank("user3");
        store.upsert(&profile).await.unwrap();

        store
            .update_field("user3", "confidence", "0.88")
            .await
            .unwrap();

        let loaded = store.get("user3").await.unwrap().unwrap();
        assert!((loaded.confidence - 0.88).abs() < f32::EPSILON);
    }

    /// Non-whitelisted column names are rejected (SQL-injection guard).
    #[tokio::test]
    async fn test_update_field_rejects_unknown() {
        let store = test_store().await;
        let result = store.update_field("user", "evil_column", "oops").await;
        assert!(result.is_err());
    }

    /// add_recent_topic creates a blank profile row when none exists yet.
    #[tokio::test]
    async fn test_add_recent_topic_auto_creates_profile() {
        let store = test_store().await;

        // No profile exists yet.
        store
            .add_recent_topic("new_user", "data analysis", 5)
            .await
            .unwrap();

        let loaded = store.get("new_user").await.unwrap().unwrap();
        assert_eq!(loaded.recent_topics, vec!["data analysis"]);
    }

    /// Re-adding an existing topic moves it to the front instead of
    /// duplicating; the list stays within the max length.
    #[tokio::test]
    async fn test_add_recent_topic_dedup_and_trim() {
        let store = test_store().await;
        let profile = UserProfile::blank("user");
        store.upsert(&profile).await.unwrap();

        store.add_recent_topic("user", "topic_a", 3).await.unwrap();
        store.add_recent_topic("user", "topic_b", 3).await.unwrap();
        store.add_recent_topic("user", "topic_c", 3).await.unwrap();
        // Duplicate — should move to front, not add.
        store.add_recent_topic("user", "topic_a", 3).await.unwrap();

        let loaded = store.get("user").await.unwrap().unwrap();
        assert_eq!(
            loaded.recent_topics,
            vec!["topic_a", "topic_c", "topic_b"]
        );
    }

    /// Pain-point list is capped at the given max, keeping the newest first.
    #[tokio::test]
    async fn test_add_pain_point_trim() {
        let store = test_store().await;

        for i in 0..5 {
            store
                .add_pain_point("user", &format!("pain_{}", i), 3)
                .await
                .unwrap();
        }

        let loaded = store.get("user").await.unwrap().unwrap();
        assert_eq!(loaded.active_pain_points.len(), 3);
        // Most recent first.
        assert_eq!(loaded.active_pain_points[0], "pain_4");
    }

    /// Preferred-tool list deduplicates and keeps most-recent-first order.
    #[tokio::test]
    async fn test_add_preferred_tool_trim() {
        let store = test_store().await;

        store
            .add_preferred_tool("user", "python", 5)
            .await
            .unwrap();
        store
            .add_preferred_tool("user", "rust", 5)
            .await
            .unwrap();
        // Duplicate — moved to front.
        store
            .add_preferred_tool("user", "python", 5)
            .await
            .unwrap();

        let loaded = store.get("user").await.unwrap().unwrap();
        assert_eq!(loaded.preferred_tools, vec!["python", "rust"]);
    }

    /// Every Level variant survives as_str → from_str_lossy; unknown → None.
    #[test]
    fn test_level_round_trip() {
        for level in [Level::Beginner, Level::Intermediate, Level::Expert] {
            assert_eq!(Level::from_str_lossy(level.as_str()), Some(level));
        }
        assert_eq!(Level::from_str_lossy("unknown"), None);
    }

    /// Every CommStyle variant survives as_str → from_str_lossy.
    #[test]
    fn test_comm_style_round_trip() {
        for style in [
            CommStyle::Concise,
            CommStyle::Detailed,
            CommStyle::Formal,
            CommStyle::Casual,
        ] {
            assert_eq!(CommStyle::from_str_lossy(style.as_str()), Some(style));
        }
        assert_eq!(CommStyle::from_str_lossy("unknown"), None);
    }

    /// UserProfile round-trips through serde JSON, including enum fields.
    #[test]
    fn test_profile_serialization() {
        let mut p = UserProfile::blank("test_user");
        p.industry = Some("logistics".into());
        p.expertise_level = Some(Level::Expert);
        p.communication_style = Some(CommStyle::Detailed);
        p.recent_topics = vec!["exports".into(), "customs".into()];

        let json = serde_json::to_string(&p).unwrap();
        let decoded: UserProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.user_id, "test_user");
        assert_eq!(decoded.industry.as_deref(), Some("logistics"));
        assert_eq!(decoded.expertise_level, Some(Level::Expert));
        assert_eq!(decoded.communication_style, Some(CommStyle::Detailed));
        assert_eq!(decoded.recent_topics, vec!["exports", "customs"]);
    }
}
|
||||
@@ -17,6 +17,7 @@ pub mod growth;
|
||||
pub mod compaction;
|
||||
pub mod middleware;
|
||||
pub mod prompt;
|
||||
pub mod nl_schedule;
|
||||
|
||||
// Re-export main types
|
||||
pub use driver::{
|
||||
|
||||
@@ -278,3 +278,4 @@ pub mod title;
|
||||
pub mod token_calibration;
|
||||
pub mod tool_error;
|
||||
pub mod tool_output_guard;
|
||||
pub mod trajectory_recorder;
|
||||
|
||||
231
crates/zclaw-runtime/src/middleware/trajectory_recorder.rs
Normal file
231
crates/zclaw-runtime/src/middleware/trajectory_recorder.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
//! Trajectory Recorder Middleware — records tool-call chains for analysis.
|
||||
//!
|
||||
//! Priority 650 (telemetry range: after business middleware at 400-599,
|
||||
//! before token_calibration at 700). Records events asynchronously via
|
||||
//! `tokio::spawn` so the main conversation flow is never blocked.
|
||||
|
||||
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use async_trait::async_trait;
use tokio::sync::RwLock;
use zclaw_memory::trajectory_store::{
    TrajectoryEvent, TrajectoryStepType, TrajectoryStore,
};
use zclaw_types::{Result, SessionId};

use crate::driver::ContentBlock;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Step counter per session
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Tracks step indices per session so events are ordered correctly.
|
||||
struct StepCounter {
|
||||
counters: RwLock<Vec<(String, Arc<AtomicU64>)>>,
|
||||
}
|
||||
|
||||
impl StepCounter {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
counters: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn next(&self, session_id: &str) -> usize {
|
||||
let map = self.counters.read().await;
|
||||
for (sid, counter) in map.iter() {
|
||||
if sid == session_id {
|
||||
return counter.fetch_add(1, Ordering::Relaxed) as usize;
|
||||
}
|
||||
}
|
||||
drop(map);
|
||||
|
||||
let mut map = self.counters.write().await;
|
||||
// Double-check after acquiring write lock
|
||||
for (sid, counter) in map.iter() {
|
||||
if sid == session_id {
|
||||
return counter.fetch_add(1, Ordering::Relaxed) as usize;
|
||||
}
|
||||
}
|
||||
let counter = Arc::new(AtomicU64::new(1));
|
||||
map.push((session_id.to_string(), counter.clone()));
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrajectoryRecorderMiddleware
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Middleware that records agent loop events into `TrajectoryStore`.
///
/// Hooks:
/// - `before_completion` → records UserRequest step
/// - `after_tool_call` → records ToolExecution step
/// - `after_completion` → records LlmGeneration step
///
/// All writes are spawned fire-and-forget (see `spawn_write`), so recording
/// never blocks or fails the conversation flow.
pub struct TrajectoryRecorderMiddleware {
    /// Destination for recorded events; shared so spawned tasks can own a clone.
    store: Arc<TrajectoryStore>,
    /// Per-session step sequencer (indices start at 0 per session).
    step_counter: StepCounter,
}
|
||||
|
||||
impl TrajectoryRecorderMiddleware {
|
||||
pub fn new(store: Arc<TrajectoryStore>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
step_counter: StepCounter::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn an async write — fire-and-forget, non-blocking.
|
||||
fn spawn_write(&self, event: TrajectoryEvent) {
|
||||
let store = self.store.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = store.insert_event(&event).await {
|
||||
tracing::warn!(
|
||||
"[TrajectoryRecorder] Async write failed (non-fatal): {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn truncate(s: &str, max: usize) -> String {
|
||||
if s.len() <= max {
|
||||
s.to_string()
|
||||
} else {
|
||||
s.chars().take(max).collect::<String>() + "…"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl AgentMiddleware for TrajectoryRecorderMiddleware {
    fn name(&self) -> &str {
        "trajectory_recorder"
    }

    /// 650 sits in the telemetry range: after business middleware (400-599),
    /// before token_calibration (700).
    fn priority(&self) -> i32 {
        650
    }

    /// Record the incoming user request as a `UserRequest` step.
    ///
    /// Empty input is skipped. Always returns `Continue` — recording never
    /// vetoes the pipeline.
    async fn before_completion(
        &self,
        ctx: &mut MiddlewareContext,
    ) -> Result<MiddlewareDecision> {
        if ctx.user_input.is_empty() {
            return Ok(MiddlewareDecision::Continue);
        }

        // Claim the step index before spawning so ordering reflects hook
        // invocation order, not task scheduling order.
        let step = self.step_counter.next(&ctx.session_id.to_string()).await;
        let event = TrajectoryEvent {
            id: uuid::Uuid::new_v4().to_string(),
            session_id: ctx.session_id.to_string(),
            agent_id: ctx.agent_id.to_string(),
            step_index: step,
            step_type: TrajectoryStepType::UserRequest,
            // Summaries are capped at 200 chars to bound row size.
            input_summary: Self::truncate(&ctx.user_input, 200),
            output_summary: String::new(),
            duration_ms: 0, // not measured at this hook
            timestamp: chrono::Utc::now(),
        };

        self.spawn_write(event);
        Ok(MiddlewareDecision::Continue)
    }

    /// Record a completed tool invocation as a `ToolExecution` step.
    /// The tool name lands in input_summary; its result in output_summary.
    async fn after_tool_call(
        &self,
        ctx: &mut MiddlewareContext,
        tool_name: &str,
        result: &serde_json::Value,
    ) -> Result<()> {
        let step = self.step_counter.next(&ctx.session_id.to_string()).await;
        // Summarise the tool result: raw strings pass through unquoted,
        // objects are re-serialised, anything else uses its JSON text form.
        let result_summary = match result {
            serde_json::Value::String(s) => Self::truncate(s, 200),
            serde_json::Value::Object(_) => {
                let s = serde_json::to_string(result).unwrap_or_default();
                Self::truncate(&s, 200)
            }
            other => Self::truncate(&other.to_string(), 200),
        };

        let event = TrajectoryEvent {
            id: uuid::Uuid::new_v4().to_string(),
            session_id: ctx.session_id.to_string(),
            agent_id: ctx.agent_id.to_string(),
            step_index: step,
            step_type: TrajectoryStepType::ToolExecution,
            input_summary: Self::truncate(tool_name, 200),
            output_summary: result_summary,
            duration_ms: 0, // not measured at this hook
            timestamp: chrono::Utc::now(),
        };

        self.spawn_write(event);
        Ok(())
    }

    /// Record the final assistant response as an `LlmGeneration` step.
    ///
    /// Only `Text` content blocks contribute to the summary; other block
    /// kinds (tool use, etc.) are skipped.
    async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
        let step = self.step_counter.next(&ctx.session_id.to_string()).await;
        let output_summary = ctx.response_content.iter()
            .filter_map(|b| match b {
                ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join(" ");

        let event = TrajectoryEvent {
            id: uuid::Uuid::new_v4().to_string(),
            session_id: ctx.session_id.to_string(),
            agent_id: ctx.agent_id.to_string(),
            step_index: step,
            step_type: TrajectoryStepType::LlmGeneration,
            input_summary: String::new(),
            output_summary: Self::truncate(&output_summary, 200),
            duration_ms: 0, // not measured at this hook
            timestamp: chrono::Utc::now(),
        };

        self.spawn_write(event);
        Ok(())
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Steps within one session must be numbered 0, 1, 2, …
    #[tokio::test]
    async fn test_step_counter_sequential() {
        let counter = StepCounter::new();
        assert_eq!(counter.next("sess-1").await, 0);
        assert_eq!(counter.next("sess-1").await, 1);
        assert_eq!(counter.next("sess-1").await, 2);
    }

    // Counters are tracked per session id — interleaved sessions do not
    // share a sequence.
    #[tokio::test]
    async fn test_step_counter_different_sessions() {
        let counter = StepCounter::new();
        assert_eq!(counter.next("sess-1").await, 0);
        assert_eq!(counter.next("sess-2").await, 0);
        assert_eq!(counter.next("sess-1").await, 1);
        assert_eq!(counter.next("sess-2").await, 1);
    }

    // Input at or under the limit passes through unchanged.
    #[test]
    fn test_truncate_short() {
        assert_eq!(TrajectoryRecorderMiddleware::truncate("hello", 10), "hello");
    }

    // Multibyte input: result is capped at `max` chars plus the ellipsis.
    #[test]
    fn test_truncate_long() {
        let long: String = "中".repeat(300);
        let truncated = TrajectoryRecorderMiddleware::truncate(&long, 200);
        assert!(truncated.chars().count() <= 201); // 200 + …
    }
}
|
||||
593
crates/zclaw-runtime/src/nl_schedule.rs
Normal file
593
crates/zclaw-runtime/src/nl_schedule.rs
Normal file
@@ -0,0 +1,593 @@
|
||||
//! Natural Language Schedule Parser — transforms Chinese time expressions into cron.
|
||||
//!
|
||||
//! Three-layer fallback strategy:
|
||||
//! 1. Regex pattern matching (covers ~80% of common expressions)
|
||||
//! 2. LLM-assisted parsing (for ambiguous/complex expressions) — TODO: wire when Haiku driver available
|
||||
//! 3. Interactive clarification (return `Unclear`)
|
||||
//!
|
||||
//! Lives in `zclaw-runtime` because it's a pure text→cron utility with no kernel dependency.
|
||||
|
||||
use chrono::{Datelike, Timelike};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zclaw_types::AgentId;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Data structures
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Result of parsing a natural language schedule expression.
///
/// Produced by [`parse_nl_schedule`]; all fields are serde-serializable so
/// the result can cross IPC/persistence boundaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedSchedule {
    /// Cron expression, e.g. "0 9 * * *"
    // NOTE(review): one-shot schedules (明天/后天) store an RFC3339
    // timestamp in this field instead of a cron string — see `try_one_shot`.
    // Consumers must branch on the format.
    pub cron_expression: String,
    /// Human-readable description of the schedule
    pub natural_description: String,
    /// Confidence of the parse (0.0–1.0)
    pub confidence: f32,
    /// What the task does (extracted from user input)
    pub task_description: String,
    /// What to trigger when the schedule fires
    pub task_target: TaskTarget,
}
|
||||
|
||||
/// Target to trigger on schedule.
///
/// Adjacently tagged for serde: serialized as `{"type": "...", "id": "..."}`
/// (the `Reminder` variant carries no id).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "id")]
pub enum TaskTarget {
    /// Trigger a specific agent
    Agent(String),
    /// Trigger a specific hand
    Hand(String),
    /// Trigger a specific workflow
    Workflow(String),
    /// Generic reminder (no specific target)
    Reminder,
}
|
||||
|
||||
/// Outcome of NL schedule parsing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScheduleParseResult {
    /// High-confidence single parse
    Exact(ParsedSchedule),
    /// Multiple possible interpretations
    // NOTE(review): none of the regex matchers in this module produce
    // `Ambiguous`; presumably reserved for the LLM-assisted layer mentioned
    // in the module docs — confirm before relying on it.
    Ambiguous(Vec<ParsedSchedule>),
    /// Unable to parse — needs user clarification
    Unclear,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Regex pattern library
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single pattern for matching Chinese time expressions.
// NOTE(review): this struct is not referenced anywhere in this module —
// the matchers below are hand-written functions rather than table-driven.
// Presumably reserved for a future table-driven refactor; confirm intent
// before deleting.
struct SchedulePattern {
    /// Regex pattern string
    regex: &'static str,
    /// Cron template — use {h} for hour, {m} for minute, {dow} for day-of-week, {dom} for day-of-month
    cron_template: &'static str,
    /// Human description template
    description: &'static str,
    /// Base confidence for this pattern
    confidence: f32,
}
|
||||
|
||||
/// Chinese time period keywords → default hour mapping.
///
/// Gives the hour to use when the user names a time-of-day word without an
/// explicit hour (e.g. "早上" → 09:00). Unknown words yield `None`.
fn period_to_hour(period: &str) -> Option<u32> {
    // Synonym groups paired with their conventional 24h default hour.
    const DEFAULT_HOURS: &[(&[&str], u32)] = &[
        (&["凌晨"], 0),
        (&["早上", "早晨", "上午"], 9),
        (&["中午"], 12),
        (&["下午", "午后"], 15),
        (&["傍晚", "黄昏"], 18),
        (&["晚上", "晚间", "夜里", "夜晚"], 21),
        (&["半夜", "午夜"], 0),
    ];

    DEFAULT_HOURS
        .iter()
        .find(|(words, _)| words.contains(&period))
        .map(|&(_, hour)| hour)
}
|
||||
|
||||
/// Chinese weekday names → cron day-of-week digit.
///
/// Accepts the bare numeral ("一") as well as 周/星期/礼拜-prefixed forms.
/// Sunday maps to "0" per cron convention; unknown names yield `None`.
fn weekday_to_cron(day: &str) -> Option<&'static str> {
    // Every accepted spelling for each weekday, paired with its cron digit.
    const WEEKDAYS: &[(&[&str], &str)] = &[
        (&["一", "周一", "星期一", "礼拜一"], "1"),
        (&["二", "周二", "星期二", "礼拜二"], "2"),
        (&["三", "周三", "星期三", "礼拜三"], "3"),
        (&["四", "周四", "星期四", "礼拜四"], "4"),
        (&["五", "周五", "星期五", "礼拜五"], "5"),
        (&["六", "周六", "星期六", "礼拜六"], "6"),
        (&["日", "周日", "星期日", "礼拜日", "天", "周天", "星期天", "礼拜天"], "0"),
    ];

    WEEKDAYS
        .iter()
        .find(|(names, _)| names.contains(&day))
        .map(|&(_, dow)| dow)
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Parser implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Parse a natural language schedule expression into a cron expression.
|
||||
///
|
||||
/// Uses a series of regex-based pattern matchers covering common Chinese
|
||||
/// time expressions. Returns `Unclear` if no pattern matches.
|
||||
pub fn parse_nl_schedule(input: &str, default_agent_id: &AgentId) -> ScheduleParseResult {
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
return ScheduleParseResult::Unclear;
|
||||
}
|
||||
|
||||
// Extract task description (everything after keywords like "提醒我", "帮我")
|
||||
let task_description = extract_task_description(input);
|
||||
|
||||
// --- Pattern 1: 每天 + 时间 ---
|
||||
if let Some(result) = try_every_day(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Pattern 2: 每周N + 时间 ---
|
||||
if let Some(result) = try_every_week(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Pattern 3: 工作日 + 时间 ---
|
||||
if let Some(result) = try_workday(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Pattern 4: 每N小时/分钟 ---
|
||||
if let Some(result) = try_interval(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Pattern 5: 每月N号 ---
|
||||
if let Some(result) = try_monthly(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Pattern 6: 明天/后天 + 时间 (one-shot) ---
|
||||
if let Some(result) = try_one_shot(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
ScheduleParseResult::Unclear
|
||||
}
|
||||
|
||||
/// Extract task description from input, stripping schedule-related keywords.
///
/// Strategy: repeatedly peel schedule keywords (每天/提醒我/…) and leading
/// time expressions (e.g. "早上9点30分") off the FRONT of the string until
/// nothing changes, then return the remainder trimmed. Only leading text is
/// stripped — schedule words embedded mid-sentence are left intact.
fn extract_task_description(input: &str) -> String {
    // Order matters: longer keywords ("每天") are listed before their own
    // prefixes ("每") so the specific form is removed first within a pass.
    let strip_prefixes = [
        "每天", "每日", "每周", "工作日", "每个工作日",
        "每月", "每", "定时", "定期",
        "提醒我", "提醒", "帮我", "帮", "请",
        "明天", "后天", "大后天",
    ];

    let mut desc = input.to_string();

    // Strip prefixes + time expressions in alternating passes until stable
    // Anchored regex: optional period word, then "H点/时[M][分]" at the start.
    // The `unwrap_or_else` fallback compiles an empty (match-nothing-useful)
    // regex so a pattern error degrades to a no-op instead of panicking.
    let time_re = regex::Regex::new(
        r"^(?:凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)?\d{1,2}[点时::]\d{0,2}分?"
    ).unwrap_or_else(|_| regex::Regex::new("").unwrap());

    // Bounded at 3 rounds to guarantee termination.
    for _ in 0..3 {
        // Pass 1: strip prefixes
        loop {
            let mut stripped = false;
            for prefix in &strip_prefixes {
                if desc.starts_with(prefix) {
                    // Byte-index slice is safe: `starts_with` guarantees
                    // `prefix.len()` falls on a char boundary.
                    desc = desc[prefix.len()..].to_string();
                    stripped = true;
                }
            }
            if !stripped { break; }
        }
        // Pass 2: strip time expressions
        let new_desc = time_re.replace(&desc, "").to_string();
        if new_desc == desc { break; }
        desc = new_desc;
    }

    desc.trim().to_string()
}
|
||||
|
||||
// -- Pattern matchers --
|
||||
|
||||
/// Adjust hour based on time-of-day period. Chinese 12-hour convention:
/// 下午3点 = 15, 晚上8点 = 20, etc. Morning hours stay as-is.
///
/// Afternoon/evening words (including 中午: 中午1点 means 13:00) promote a
/// 12-hour value below 12 to its 24-hour equivalent; 半夜/午夜 maps "12"
/// to midnight (0). Everything else — morning words, unknown words, no
/// period at all, or an hour already ≥ 12 — passes through unchanged.
fn adjust_hour_for_period(hour: u32, period: Option<&str>) -> u32 {
    match period {
        // All "second half of the day" words behave identically; the guard
        // leaves already-24h values (e.g. 下午15点) alone.
        Some("下午" | "午后" | "晚上" | "晚间" | "夜里" | "夜晚" | "傍晚" | "黄昏" | "中午")
            if hour < 12 => hour + 12,
        // 半夜/午夜 12点 is midnight; other small hours (半夜1点) are kept.
        Some("半夜" | "午夜") if hour == 12 => 0,
        _ => hour,
    }
}
|
||||
|
||||
/// Optional time-of-day capture group shared by the matchers below; the
/// trailing `?` makes both the group and its match optional.
const PERIOD_PATTERN: &str = "(凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)?";
|
||||
|
||||
/// Matcher for daily schedules: "每天/每日 [period] H点[M分]" and the
/// hour-less "每天 + period word" form.
///
/// Emits `M H * * *`. Explicit-hour parses get confidence 0.95; period-only
/// forms ("每天早上") fall back to the period's default hour at 0.85.
fn try_every_day(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
    // Capture groups: 1 = optional period word (from PERIOD_PATTERN),
    // 2 = hour, 3 = optional minute.
    let re = regex::Regex::new(
        &format!(r"(?:每天|每日)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?", PERIOD_PATTERN)
    ).ok()?;
    if let Some(caps) = re.captures(input) {
        let period = caps.get(1).map(|m| m.as_str());
        let raw_hour: u32 = caps.get(2)?.as_str().parse().ok()?;
        let minute: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
        let hour = adjust_hour_for_period(raw_hour, period);
        // Reject out-of-range values rather than emit an invalid cron.
        if hour > 23 || minute > 59 {
            return None;
        }
        return Some(ScheduleParseResult::Exact(ParsedSchedule {
            cron_expression: format!("{} {} * * *", minute, hour),
            natural_description: format!("每天{:02}:{:02}", hour, minute),
            confidence: 0.95,
            task_description: task_desc.to_string(),
            task_target: TaskTarget::Agent(agent_id.to_string()),
        }));
    }

    // "每天早上/下午..." without explicit hour
    let re2 = regex::Regex::new(r"(?:每天|每日)(?:的)?(凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)").ok()?;
    if let Some(caps) = re2.captures(input) {
        let period = caps.get(1)?.as_str();
        if let Some(hour) = period_to_hour(period) {
            return Some(ScheduleParseResult::Exact(ParsedSchedule {
                cron_expression: format!("0 {} * * *", hour),
                natural_description: format!("每天{}", period),
                confidence: 0.85,
                task_description: task_desc.to_string(),
                task_target: TaskTarget::Agent(agent_id.to_string()),
            }));
        }
    }

    None
}
|
||||
|
||||
/// Matcher for weekly schedules: "每周N / 每(个)星期N / 每(个)礼拜N
/// [period] H点[M分]".
///
/// Emits `M H * * dow` (Sunday = 0 per cron convention), confidence 0.92.
/// Unlike `try_every_day`, there is no hour-less fallback: an explicit
/// hour is required.
fn try_every_week(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
    // Capture groups: 1 = weekday name, 2 = optional period word,
    // 3 = hour, 4 = optional minute.
    let re = regex::Regex::new(
        &format!(r"(?:每周|每个?星期|每个?礼拜)(一|二|三|四|五|六|日|天|周一|周二|周三|周四|周五|周六|周日|周天|星期一|星期二|星期三|星期四|星期五|星期六|星期日|星期天|礼拜一|礼拜二|礼拜三|礼拜四|礼拜五|礼拜六|礼拜日|礼拜天)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?", PERIOD_PATTERN)
    ).ok()?;

    let caps = re.captures(input)?;
    let day_str = caps.get(1)?.as_str();
    let dow = weekday_to_cron(day_str)?;
    let period = caps.get(2).map(|m| m.as_str());
    let raw_hour: u32 = caps.get(3)?.as_str().parse().ok()?;
    let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
    let hour = adjust_hour_for_period(raw_hour, period);
    // Reject out-of-range values rather than emit an invalid cron.
    if hour > 23 || minute > 59 {
        return None;
    }

    Some(ScheduleParseResult::Exact(ParsedSchedule {
        cron_expression: format!("{} {} * * {}", minute, hour, dow),
        natural_description: format!("每周{} {:02}:{:02}", day_str, hour, minute),
        confidence: 0.92,
        task_description: task_desc.to_string(),
        task_target: TaskTarget::Agent(agent_id.to_string()),
    }))
}
|
||||
|
||||
/// Matcher for weekday-only schedules: "工作日 [period] H点[M分]" and the
/// hour-less "工作日 + period word" form.
///
/// Emits `M H * * 1-5` (Monday–Friday). Explicit hour → confidence 0.90;
/// period-only falls back to the period's default hour at 0.85.
fn try_workday(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
    // Capture groups: 1 = optional period word, 2 = hour, 3 = optional minute.
    let re = regex::Regex::new(
        &format!(r"(?:工作日|每个?工作日|工作日(?:的)?){}(\d{{1,2}})[点时::](\d{{1,2}})?", PERIOD_PATTERN)
    ).ok()?;

    if let Some(caps) = re.captures(input) {
        let period = caps.get(1).map(|m| m.as_str());
        let raw_hour: u32 = caps.get(2)?.as_str().parse().ok()?;
        let minute: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
        let hour = adjust_hour_for_period(raw_hour, period);
        // Reject out-of-range values rather than emit an invalid cron.
        if hour > 23 || minute > 59 {
            return None;
        }
        return Some(ScheduleParseResult::Exact(ParsedSchedule {
            cron_expression: format!("{} {} * * 1-5", minute, hour),
            natural_description: format!("工作日{:02}:{:02}", hour, minute),
            confidence: 0.90,
            task_description: task_desc.to_string(),
            task_target: TaskTarget::Agent(agent_id.to_string()),
        }));
    }

    // "工作日下午3点" style
    // (period word with no explicit hour — falls back to the default hour)
    let re2 = regex::Regex::new(
        r"(?:工作日|每个?工作日)(?:的)?(凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)"
    ).ok()?;
    if let Some(caps) = re2.captures(input) {
        let period = caps.get(1)?.as_str();
        if let Some(hour) = period_to_hour(period) {
            return Some(ScheduleParseResult::Exact(ParsedSchedule {
                cron_expression: format!("0 {} * * 1-5", hour),
                natural_description: format!("工作日{}", period),
                confidence: 0.85,
                task_description: task_desc.to_string(),
                task_target: TaskTarget::Agent(agent_id.to_string()),
            }));
        }
    }

    None
}
|
||||
|
||||
/// Matcher for fixed-interval schedules: "每N小时" / "每N分钟" (N = 1–99,
/// two digits max).
///
/// Emits `0 */N * * *` for hours or `*/N * * * *` for minutes, at
/// confidence 0.90. N == 0 is rejected.
fn try_interval(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
    // "每2小时", "每30分钟", "每N小时/分钟"
    let re = regex::Regex::new(r"每(\d{1,2})(小时|分钟|分|钟|个小时)").ok()?;
    if let Some(caps) = re.captures(input) {
        let n: u32 = caps.get(1)?.as_str().parse().ok()?;
        // "Every 0 units" is nonsensical and would produce `*/0` in cron.
        if n == 0 {
            return None;
        }
        let unit = caps.get(2)?.as_str();
        // "小" only occurs in the hour units (小时/个小时); every other
        // accepted unit (分钟/分/钟) means minutes.
        let (cron, desc) = if unit.contains("小") {
            (format!("0 */{} * * *", n), format!("每{}小时", n))
        } else {
            (format!("*/{} * * * *", n), format!("每{}分钟", n))
        };
        return Some(ScheduleParseResult::Exact(ParsedSchedule {
            cron_expression: cron,
            natural_description: desc,
            confidence: 0.90,
            task_description: task_desc.to_string(),
            task_target: TaskTarget::Agent(agent_id.to_string()),
        }));
    }

    None
}
|
||||
|
||||
fn try_monthly(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
|
||||
let re = regex::Regex::new(
|
||||
&format!(r"(?:每月|每个月)(?:的)?(\d{{1,2}})[号日](?:的)?{}(\d{{1,2}})?[点时::]?(\d{{1,2}})?", PERIOD_PATTERN)
|
||||
).ok()?;
|
||||
|
||||
if let Some(caps) = re.captures(input) {
|
||||
let day: u32 = caps.get(1)?.as_str().parse().ok()?;
|
||||
let period = caps.get(2).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(9)).unwrap_or(9);
|
||||
let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if day > 31 || hour > 23 || minute > 59 {
|
||||
return None;
|
||||
}
|
||||
return Some(ScheduleParseResult::Exact(ParsedSchedule {
|
||||
cron_expression: format!("{} {} {} * *", minute, hour, day),
|
||||
natural_description: format!("每月{}号 {:02}:{:02}", day, hour, minute),
|
||||
confidence: 0.90,
|
||||
task_description: task_desc.to_string(),
|
||||
task_target: TaskTarget::Agent(agent_id.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Matcher for one-shot expressions: "明天/后天/大后天 [period] H点[M分]".
///
/// NOTE(review): for one-shot schedules `cron_expression` carries an
/// RFC3339 UTC timestamp rather than a cron string (the unit test asserts
/// `contains('T')`), so consumers must branch on the format.
// assumes the user's wall-clock time is intended to be interpreted as UTC
// (`Utc::now()` is used throughout) — TODO confirm against the scheduler.
fn try_one_shot(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
    // Capture groups: 1 = day word, 2 = optional period word, 3 = hour,
    // 4 = optional minute.
    let re = regex::Regex::new(
        &format!(r"(明天|后天|大后天)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?", PERIOD_PATTERN)
    ).ok()?;

    let caps = re.captures(input)?;
    let day_offset = match caps.get(1)?.as_str() {
        "明天" => 1,
        "后天" => 2,
        "大后天" => 3,
        _ => return None,
    };
    let period = caps.get(2).map(|m| m.as_str());
    let raw_hour: u32 = caps.get(3)?.as_str().parse().ok()?;
    let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
    let hour = adjust_hour_for_period(raw_hour, period);
    // Reject out-of-range values before building the timestamp.
    if hour > 23 || minute > 59 {
        return None;
    }

    // Build the absolute fire time. The `unwrap_or_else(now)` fallbacks can
    // only trigger on out-of-range components, which the check above
    // excludes — they exist purely to avoid a panic path.
    let target = chrono::Utc::now()
        .checked_add_signed(chrono::Duration::days(day_offset))
        .unwrap_or_else(chrono::Utc::now)
        .with_hour(hour)
        .unwrap_or_else(|| chrono::Utc::now())
        .with_minute(minute)
        .unwrap_or_else(|| chrono::Utc::now())
        .with_second(0)
        .unwrap_or_else(|| chrono::Utc::now());

    Some(ScheduleParseResult::Exact(ParsedSchedule {
        cron_expression: target.to_rfc3339(),
        natural_description: format!("{} {:02}:{:02}", caps.get(1)?.as_str(), hour, minute),
        confidence: 0.88,
        task_description: task_desc.to_string(),
        task_target: TaskTarget::Agent(agent_id.to_string()),
    }))
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Schedule intent detection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Keywords indicating the user wants to set a scheduled task.
const SCHEDULE_INTENT_KEYWORDS: &[&str] = &[
    "提醒我", "提醒", "定时", "每天", "每日", "每周", "每月",
    "工作日", "每隔", "每", "定期", "到时候", "准时",
    "闹钟", "闹铃", "日程", "日历",
];

/// Check if user input contains schedule intent.
///
/// Lowercases the input first (a no-op for the CJK keywords themselves,
/// but normalizes any Latin text around them) and reports whether any
/// intent keyword occurs anywhere in it.
pub fn has_schedule_intent(input: &str) -> bool {
    let haystack = input.to_lowercase();
    for keyword in SCHEDULE_INTENT_KEYWORDS {
        if haystack.contains(keyword) {
            return true;
        }
    }
    false
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A throwaway agent id; only its string form ends up in the result.
    fn default_agent() -> AgentId {
        AgentId::new()
    }

    // 每天 + explicit hour → daily cron at that hour.
    #[test]
    fn test_every_day_explicit_time() {
        let result = parse_nl_schedule("每天早上9点提醒我查房", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 9 * * *");
                assert!(s.confidence >= 0.9);
            }
            _ => panic!("Expected Exact, got {:?}", result),
        }
    }

    // 下午3点30分 → 15:30 (period-adjusted hour plus explicit minute).
    #[test]
    fn test_every_day_with_minute() {
        let result = parse_nl_schedule("每天下午3点30分提醒我", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "30 15 * * *");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // Period word without an hour falls back to the period's default (早上=9).
    #[test]
    fn test_every_day_period_only() {
        let result = parse_nl_schedule("每天早上提醒我看看报告", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 9 * * *");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // 每周一 → cron day-of-week 1.
    #[test]
    fn test_every_week_monday() {
        let result = parse_nl_schedule("每周一上午10点提醒我开会", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 10 * * 1");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // 每个星期五下午2点 → Friday 14:00.
    #[test]
    fn test_every_week_friday() {
        let result = parse_nl_schedule("每个星期五下午2点", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 14 * * 5");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // 工作日 → Monday–Friday range "1-5".
    #[test]
    fn test_workday() {
        let result = parse_nl_schedule("工作日下午3点提醒我写周报", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 15 * * 1-5");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // Hour interval uses the */N step on the hour field.
    #[test]
    fn test_interval_hours() {
        let result = parse_nl_schedule("每2小时提醒我喝水", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 */2 * * *");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // Minute interval uses the */N step on the minute field.
    #[test]
    fn test_interval_minutes() {
        let result = parse_nl_schedule("每30分钟检查一次", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "*/30 * * * *");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // Monthly schedule puts the day in the day-of-month field.
    #[test]
    fn test_monthly() {
        let result = parse_nl_schedule("每月1号早上9点提醒我", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert_eq!(s.cron_expression, "0 9 1 * *");
            }
            _ => panic!("Expected Exact"),
        }
    }

    // One-shot: cron_expression holds an RFC3339 timestamp, not a cron
    // string (hence the 'T' separator check).
    #[test]
    fn test_one_shot_tomorrow() {
        let result = parse_nl_schedule("明天下午3点提醒我开会", &default_agent());
        match result {
            ScheduleParseResult::Exact(s) => {
                assert!(s.cron_expression.contains('T'));
                assert!(s.natural_description.contains("明天"));
            }
            _ => panic!("Expected Exact"),
        }
    }

    // Non-schedule text must not match any pattern.
    #[test]
    fn test_unclear_input() {
        let result = parse_nl_schedule("今天天气怎么样", &default_agent());
        assert!(matches!(result, ScheduleParseResult::Unclear));
    }

    // Empty (or whitespace-only) input short-circuits to Unclear.
    #[test]
    fn test_empty_input() {
        let result = parse_nl_schedule("", &default_agent());
        assert!(matches!(result, ScheduleParseResult::Unclear));
    }

    // Intent detection is keyword-based, independent of parseability.
    #[test]
    fn test_schedule_intent_detection() {
        assert!(has_schedule_intent("每天早上9点提醒我查房"));
        assert!(has_schedule_intent("帮我设个定时任务"));
        assert!(has_schedule_intent("工作日提醒我打卡"));
        assert!(!has_schedule_intent("今天天气怎么样"));
        assert!(!has_schedule_intent("帮我写个报告"));
    }

    // Default-hour table for period words.
    #[test]
    fn test_period_to_hour_mapping() {
        assert_eq!(period_to_hour("凌晨"), Some(0));
        assert_eq!(period_to_hour("早上"), Some(9));
        assert_eq!(period_to_hour("中午"), Some(12));
        assert_eq!(period_to_hour("下午"), Some(15));
        assert_eq!(period_to_hour("晚上"), Some(21));
        assert_eq!(period_to_hour("不知道"), None);
    }

    // Weekday-name to cron-digit table (Sunday = "0").
    #[test]
    fn test_weekday_to_cron_mapping() {
        assert_eq!(weekday_to_cron("一"), Some("1"));
        assert_eq!(weekday_to_cron("五"), Some("5"));
        assert_eq!(weekday_to_cron("日"), Some("0"));
        assert_eq!(weekday_to_cron("星期三"), Some("3"));
        assert_eq!(weekday_to_cron("礼拜天"), Some("0"));
        assert_eq!(weekday_to_cron("未知"), None);
    }

    // Schedule keywords and the leading time expression are stripped,
    // leaving only the task text.
    #[test]
    fn test_task_description_extraction() {
        assert_eq!(extract_task_description("每天早上9点提醒我查房"), "查房");
    }
}
|
||||
Reference in New Issue
Block a user