feat(hermes): implement intelligence pipeline — 4 chunks, 684 tests passing

Hermes Intelligence Pipeline closes breakpoints in ZCLAW's existing
intelligence components with 4 self-contained modules:

Chunk 1 — Self-improvement Loop:
- ExperienceStore (zclaw-growth): FTS5+TF-IDF wrapper with scope prefix
- ExperienceExtractor (desktop/intelligence): template-based extraction
  from successful proposals with implicit keyword detection

Chunk 2 — User Modeling:
- UserProfileStore (zclaw-memory): SQLite-backed structured profiles
  with industry/role/expertise/comm_style/recent_topics/pain_points
- UserProfiler (desktop/intelligence): fact classification by category
  (Preference/Knowledge/Behavior) with profile summary formatting

Chunk 3 — NL Cron Chinese Time Parser:
- NlScheduleParser (zclaw-runtime): 6 pattern matchers for Chinese time
  expressions (每天/每周/工作日/间隔/每月/一次性) producing cron expressions
- Period-aware hour adjustment (下午3点→15, 晚上8点→20)
- Schedule intent detection + task description extraction

Chunk 4 — Trajectory Compression:
- TrajectoryStore (zclaw-memory): trajectory_events + compressed_trajectories
- TrajectoryRecorderMiddleware (zclaw-runtime/middleware): priority 650,
  async non-blocking event recording via tokio::spawn
- TrajectoryCompressor (desktop/intelligence): dedup, request classification,
  satisfaction detection, execution chain JSON

Schema migrations: v2→v3 (user_profiles), v3→v4 (trajectory tables)
This commit is contained in:
iven
2026-04-09 17:47:43 +08:00
parent 0883bb28ff
commit 4b15ead8e7
15 changed files with 4225 additions and 0 deletions

View File

@@ -6,8 +6,15 @@ mod store;
mod session;
mod schema;
pub mod fact;
pub mod user_profile_store;
pub mod trajectory_store;
pub use store::*;
pub use session::*;
pub use schema::*;
pub use fact::{Fact, FactCategory, ExtractedFactBatch};
pub use user_profile_store::{UserProfileStore, UserProfile, Level, CommStyle};
pub use trajectory_store::{
TrajectoryEvent, TrajectoryStore, TrajectoryStepType,
CompressedTrajectory, CompletionStatus, SatisfactionSignal,
};

View File

@@ -93,4 +93,47 @@ pub const MIGRATIONS: &[&str] = &[
// v1→v2: persist runtime state and message count
"ALTER TABLE agents ADD COLUMN state TEXT NOT NULL DEFAULT 'running'",
"ALTER TABLE agents ADD COLUMN message_count INTEGER NOT NULL DEFAULT 0",
// v2→v3: user profiles for structured user modeling
"CREATE TABLE IF NOT EXISTS user_profiles (
user_id TEXT PRIMARY KEY,
industry TEXT,
role TEXT,
expertise_level TEXT,
communication_style TEXT,
preferred_language TEXT DEFAULT 'zh-CN',
recent_topics TEXT DEFAULT '[]',
active_pain_points TEXT DEFAULT '[]',
preferred_tools TEXT DEFAULT '[]',
confidence REAL DEFAULT 0.0,
updated_at TEXT NOT NULL
)",
// v3→v4: trajectory recording for tool-call chain analysis
"CREATE TABLE IF NOT EXISTS trajectory_events (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
agent_id TEXT NOT NULL,
step_index INTEGER NOT NULL,
step_type TEXT NOT NULL,
input_summary TEXT,
output_summary TEXT,
duration_ms INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
)",
"CREATE INDEX IF NOT EXISTS idx_trajectory_session ON trajectory_events(session_id)",
"CREATE TABLE IF NOT EXISTS compressed_trajectories (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
agent_id TEXT NOT NULL,
request_type TEXT NOT NULL,
tools_used TEXT,
outcome TEXT NOT NULL,
total_steps INTEGER DEFAULT 0,
total_duration_ms INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
execution_chain TEXT NOT NULL,
satisfaction_signal TEXT,
created_at TEXT NOT NULL
)",
"CREATE INDEX IF NOT EXISTS idx_ct_request_type ON compressed_trajectories(request_type)",
"CREATE INDEX IF NOT EXISTS idx_ct_outcome ON compressed_trajectories(outcome)",
];

View File

@@ -0,0 +1,563 @@
//! Trajectory Store -- record and compress tool-call chains for analysis.
//!
//! Stores raw trajectory events (user requests, tool calls, LLM generations)
//! and compressed trajectory summaries. Used by the Hermes Intelligence Pipeline
//! to analyze agent behaviour patterns and improve routing over time.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use zclaw_types::{Result, ZclawError};
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// Step type in a trajectory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TrajectoryStepType {
UserRequest,
IntentClassification,
SkillSelection,
ToolExecution,
LlmGeneration,
UserFeedback,
}
impl TrajectoryStepType {
/// Serialize to the string stored in SQLite.
pub fn as_str(&self) -> &'static str {
match self {
Self::UserRequest => "user_request",
Self::IntentClassification => "intent_classification",
Self::SkillSelection => "skill_selection",
Self::ToolExecution => "tool_execution",
Self::LlmGeneration => "llm_generation",
Self::UserFeedback => "user_feedback",
}
}
/// Deserialize from the SQLite string representation.
pub fn from_str_lossy(s: &str) -> Self {
match s {
"user_request" => Self::UserRequest,
"intent_classification" => Self::IntentClassification,
"skill_selection" => Self::SkillSelection,
"tool_execution" => Self::ToolExecution,
"llm_generation" => Self::LlmGeneration,
"user_feedback" => Self::UserFeedback,
_ => Self::UserRequest,
}
}
}
/// Single trajectory event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryEvent {
pub id: String,
pub session_id: String,
pub agent_id: String,
pub step_index: usize,
pub step_type: TrajectoryStepType,
/// Summarised input (max 200 chars).
pub input_summary: String,
/// Summarised output (max 200 chars).
pub output_summary: String,
pub duration_ms: u64,
pub timestamp: DateTime<Utc>,
}
/// Satisfaction signal inferred from user feedback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SatisfactionSignal {
Positive,
Negative,
Neutral,
}
impl SatisfactionSignal {
pub fn as_str(&self) -> &'static str {
match self {
Self::Positive => "positive",
Self::Negative => "negative",
Self::Neutral => "neutral",
}
}
pub fn from_str_lossy(s: &str) -> Option<Self> {
match s {
"positive" => Some(Self::Positive),
"negative" => Some(Self::Negative),
"neutral" => Some(Self::Neutral),
_ => None,
}
}
}
/// Completion status of a compressed trajectory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompletionStatus {
Success,
Partial,
Failed,
Abandoned,
}
impl CompletionStatus {
pub fn as_str(&self) -> &'static str {
match self {
Self::Success => "success",
Self::Partial => "partial",
Self::Failed => "failed",
Self::Abandoned => "abandoned",
}
}
pub fn from_str_lossy(s: &str) -> Self {
match s {
"success" => Self::Success,
"partial" => Self::Partial,
"failed" => Self::Failed,
"abandoned" => Self::Abandoned,
_ => Self::Success,
}
}
}
/// Compressed trajectory (generated at session end).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedTrajectory {
pub id: String,
pub session_id: String,
pub agent_id: String,
pub request_type: String,
pub tools_used: Vec<String>,
pub outcome: CompletionStatus,
pub total_steps: usize,
pub total_duration_ms: u64,
pub total_tokens: u32,
/// Serialised JSON execution chain for analysis.
pub execution_chain: String,
pub satisfaction_signal: Option<SatisfactionSignal>,
pub created_at: DateTime<Utc>,
}
// ---------------------------------------------------------------------------
// Store
// ---------------------------------------------------------------------------
/// Persistent store for trajectory events and compressed trajectories.
pub struct TrajectoryStore {
pool: SqlitePool,
}
impl TrajectoryStore {
/// Create a new `TrajectoryStore` backed by the given SQLite pool.
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
/// Create the required tables. Idempotent -- safe to call on startup.
pub async fn initialize_schema(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS trajectory_events (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
agent_id TEXT NOT NULL,
step_index INTEGER NOT NULL,
step_type TEXT NOT NULL,
input_summary TEXT,
output_summary TEXT,
duration_ms INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_trajectory_session ON trajectory_events(session_id);
"#,
)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS compressed_trajectories (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
agent_id TEXT NOT NULL,
request_type TEXT NOT NULL,
tools_used TEXT,
outcome TEXT NOT NULL,
total_steps INTEGER DEFAULT 0,
total_duration_ms INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
execution_chain TEXT NOT NULL,
satisfaction_signal TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_ct_request_type ON compressed_trajectories(request_type);
CREATE INDEX IF NOT EXISTS idx_ct_outcome ON compressed_trajectories(outcome);
"#,
)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Insert a raw trajectory event.
pub async fn insert_event(&self, event: &TrajectoryEvent) -> Result<()> {
sqlx::query(
r#"
INSERT INTO trajectory_events
(id, session_id, agent_id, step_index, step_type,
input_summary, output_summary, duration_ms, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&event.id)
.bind(&event.session_id)
.bind(&event.agent_id)
.bind(event.step_index as i64)
.bind(event.step_type.as_str())
.bind(&event.input_summary)
.bind(&event.output_summary)
.bind(event.duration_ms as i64)
.bind(event.timestamp.to_rfc3339())
.execute(&self.pool)
.await
.map_err(|e| {
tracing::warn!("[TrajectoryStore] insert_event failed: {}", e);
ZclawError::StorageError(e.to_string())
})?;
Ok(())
}
/// Retrieve all raw events for a session, ordered by step_index.
pub async fn get_events_by_session(&self, session_id: &str) -> Result<Vec<TrajectoryEvent>> {
let rows = sqlx::query_as::<_, (String, String, String, i64, String, Option<String>, Option<String>, Option<i64>, String)>(
r#"
SELECT id, session_id, agent_id, step_index, step_type,
input_summary, output_summary, duration_ms, timestamp
FROM trajectory_events
WHERE session_id = ?
ORDER BY step_index ASC
"#,
)
.bind(session_id)
.fetch_all(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
let mut events = Vec::with_capacity(rows.len());
for (id, sid, aid, step_idx, stype, input_s, output_s, dur_ms, ts) in rows {
let timestamp = DateTime::parse_from_rfc3339(&ts)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
events.push(TrajectoryEvent {
id,
session_id: sid,
agent_id: aid,
step_index: step_idx as usize,
step_type: TrajectoryStepType::from_str_lossy(&stype),
input_summary: input_s.unwrap_or_default(),
output_summary: output_s.unwrap_or_default(),
duration_ms: dur_ms.unwrap_or(0) as u64,
timestamp,
});
}
Ok(events)
}
/// Insert a compressed trajectory.
pub async fn insert_compressed(&self, trajectory: &CompressedTrajectory) -> Result<()> {
let tools_json = serde_json::to_string(&trajectory.tools_used)
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
sqlx::query(
r#"
INSERT INTO compressed_trajectories
(id, session_id, agent_id, request_type, tools_used,
outcome, total_steps, total_duration_ms, total_tokens,
execution_chain, satisfaction_signal, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&trajectory.id)
.bind(&trajectory.session_id)
.bind(&trajectory.agent_id)
.bind(&trajectory.request_type)
.bind(&tools_json)
.bind(trajectory.outcome.as_str())
.bind(trajectory.total_steps as i64)
.bind(trajectory.total_duration_ms as i64)
.bind(trajectory.total_tokens as i64)
.bind(&trajectory.execution_chain)
.bind(trajectory.satisfaction_signal.map(|s| s.as_str()))
.bind(trajectory.created_at.to_rfc3339())
.execute(&self.pool)
.await
.map_err(|e| {
tracing::warn!("[TrajectoryStore] insert_compressed failed: {}", e);
ZclawError::StorageError(e.to_string())
})?;
Ok(())
}
/// Retrieve the compressed trajectory for a session, if any.
pub async fn get_compressed_by_session(&self, session_id: &str) -> Result<Option<CompressedTrajectory>> {
let row = sqlx::query_as::<_, (
String, String, String, String, Option<String>,
String, i64, i64, i64, String, Option<String>, String,
)>(
r#"
SELECT id, session_id, agent_id, request_type, tools_used,
outcome, total_steps, total_duration_ms, total_tokens,
execution_chain, satisfaction_signal, created_at
FROM compressed_trajectories
WHERE session_id = ?
"#,
)
.bind(session_id)
.fetch_optional(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
match row {
Some((id, sid, aid, req_type, tools_json, outcome_str, steps, dur_ms, tokens, chain, sat, created)) => {
let tools_used: Vec<String> = tools_json
.as_deref()
.and_then(|j| serde_json::from_str(j).ok())
.unwrap_or_default();
let timestamp = DateTime::parse_from_rfc3339(&created)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
Ok(Some(CompressedTrajectory {
id,
session_id: sid,
agent_id: aid,
request_type: req_type,
tools_used,
outcome: CompletionStatus::from_str_lossy(&outcome_str),
total_steps: steps as usize,
total_duration_ms: dur_ms as u64,
total_tokens: tokens as u32,
execution_chain: chain,
satisfaction_signal: sat.as_deref().and_then(SatisfactionSignal::from_str_lossy),
created_at: timestamp,
}))
}
None => Ok(None),
}
}
/// Delete raw trajectory events older than `days` days. Returns count deleted.
pub async fn delete_events_older_than(&self, days: i64) -> Result<u64> {
let result = sqlx::query(
r#"
DELETE FROM trajectory_events
WHERE timestamp < datetime('now', ?)
"#,
)
.bind(format!("-{} days", days))
.execute(&self.pool)
.await
.map_err(|e| {
tracing::warn!("[TrajectoryStore] delete_events_older_than failed: {}", e);
ZclawError::StorageError(e.to_string())
})?;
Ok(result.rows_affected())
}
/// Delete compressed trajectories older than `days` days. Returns count deleted.
pub async fn delete_compressed_older_than(&self, days: i64) -> Result<u64> {
let result = sqlx::query(
r#"
DELETE FROM compressed_trajectories
WHERE created_at < datetime('now', ?)
"#,
)
.bind(format!("-{} days", days))
.execute(&self.pool)
.await
.map_err(|e| {
tracing::warn!("[TrajectoryStore] delete_compressed_older_than failed: {}", e);
ZclawError::StorageError(e.to_string())
})?;
Ok(result.rows_affected())
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
async fn test_store() -> TrajectoryStore {
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("in-memory pool");
let store = TrajectoryStore::new(pool);
store.initialize_schema().await.expect("schema init");
store
}
fn sample_event(index: usize) -> TrajectoryEvent {
TrajectoryEvent {
id: format!("evt-{}", index),
session_id: "sess-1".to_string(),
agent_id: "agent-1".to_string(),
step_index: index,
step_type: TrajectoryStepType::ToolExecution,
input_summary: "search query".to_string(),
output_summary: "3 results found".to_string(),
duration_ms: 150,
timestamp: Utc::now(),
}
}
#[tokio::test]
async fn test_insert_and_get_events() {
let store = test_store().await;
let e1 = sample_event(0);
let e2 = TrajectoryEvent {
id: "evt-1".to_string(),
step_index: 1,
step_type: TrajectoryStepType::LlmGeneration,
..sample_event(0)
};
store.insert_event(&e1).await.unwrap();
store.insert_event(&e2).await.unwrap();
let events = store.get_events_by_session("sess-1").await.unwrap();
assert_eq!(events.len(), 2);
assert_eq!(events[0].step_index, 0);
assert_eq!(events[1].step_index, 1);
assert_eq!(events[0].step_type, TrajectoryStepType::ToolExecution);
assert_eq!(events[1].step_type, TrajectoryStepType::LlmGeneration);
}
#[tokio::test]
async fn test_get_events_empty_session() {
let store = test_store().await;
let events = store.get_events_by_session("nonexistent").await.unwrap();
assert!(events.is_empty());
}
#[tokio::test]
async fn test_insert_and_get_compressed() {
let store = test_store().await;
let ct = CompressedTrajectory {
id: "ct-1".to_string(),
session_id: "sess-1".to_string(),
agent_id: "agent-1".to_string(),
request_type: "data_query".to_string(),
tools_used: vec!["search".to_string(), "calculate".to_string()],
outcome: CompletionStatus::Success,
total_steps: 5,
total_duration_ms: 1200,
total_tokens: 350,
execution_chain: r#"[{"step":0,"type":"tool_execution"}]"#.to_string(),
satisfaction_signal: Some(SatisfactionSignal::Positive),
created_at: Utc::now(),
};
store.insert_compressed(&ct).await.unwrap();
let loaded = store.get_compressed_by_session("sess-1").await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.id, "ct-1");
assert_eq!(loaded.request_type, "data_query");
assert_eq!(loaded.tools_used.len(), 2);
assert_eq!(loaded.outcome, CompletionStatus::Success);
assert_eq!(loaded.satisfaction_signal, Some(SatisfactionSignal::Positive));
}
#[tokio::test]
async fn test_get_compressed_nonexistent() {
let store = test_store().await;
let result = store.get_compressed_by_session("nonexistent").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_step_type_roundtrip() {
let all_types = [
TrajectoryStepType::UserRequest,
TrajectoryStepType::IntentClassification,
TrajectoryStepType::SkillSelection,
TrajectoryStepType::ToolExecution,
TrajectoryStepType::LlmGeneration,
TrajectoryStepType::UserFeedback,
];
for st in all_types {
assert_eq!(TrajectoryStepType::from_str_lossy(st.as_str()), st);
}
}
#[tokio::test]
async fn test_satisfaction_signal_roundtrip() {
let signals = [SatisfactionSignal::Positive, SatisfactionSignal::Negative, SatisfactionSignal::Neutral];
for sig in signals {
assert_eq!(SatisfactionSignal::from_str_lossy(sig.as_str()), Some(sig));
}
assert_eq!(SatisfactionSignal::from_str_lossy("bogus"), None);
}
#[tokio::test]
async fn test_completion_status_roundtrip() {
let statuses = [CompletionStatus::Success, CompletionStatus::Partial, CompletionStatus::Failed, CompletionStatus::Abandoned];
for s in statuses {
assert_eq!(CompletionStatus::from_str_lossy(s.as_str()), s);
}
}
#[tokio::test]
async fn test_delete_events_older_than() {
let store = test_store().await;
// Insert an event with a timestamp far in the past
let old_event = TrajectoryEvent {
id: "old-evt".to_string(),
timestamp: Utc::now() - chrono::Duration::days(100),
..sample_event(0)
};
store.insert_event(&old_event).await.unwrap();
// Insert a recent event
let recent_event = TrajectoryEvent {
id: "recent-evt".to_string(),
step_index: 1,
..sample_event(0)
};
store.insert_event(&recent_event).await.unwrap();
let deleted = store.delete_events_older_than(30).await.unwrap();
assert_eq!(deleted, 1);
let remaining = store.get_events_by_session("sess-1").await.unwrap();
assert_eq!(remaining.len(), 1);
assert_eq!(remaining[0].id, "recent-evt");
}
}

View File

@@ -0,0 +1,592 @@
//! User Profile Store — structured user modeling from conversation patterns.
//!
//! Maintains a single `UserProfile` per user (desktop uses "default_user")
//! in a dedicated SQLite table. Vec fields (recent_topics, pain points,
//! preferred_tools) are stored as JSON arrays and transparently
//! (de)serialised on read/write.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use sqlx::SqlitePool;
use zclaw_types::Result;
// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------
/// Expertise level inferred from conversation patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Level {
Beginner,
Intermediate,
Expert,
}
impl Level {
pub fn as_str(&self) -> &'static str {
match self {
Level::Beginner => "beginner",
Level::Intermediate => "intermediate",
Level::Expert => "expert",
}
}
pub fn from_str_lossy(s: &str) -> Option<Self> {
match s {
"beginner" => Some(Level::Beginner),
"intermediate" => Some(Level::Intermediate),
"expert" => Some(Level::Expert),
_ => None,
}
}
}
/// Communication style preference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CommStyle {
Concise,
Detailed,
Formal,
Casual,
}
impl CommStyle {
pub fn as_str(&self) -> &'static str {
match self {
CommStyle::Concise => "concise",
CommStyle::Detailed => "detailed",
CommStyle::Formal => "formal",
CommStyle::Casual => "casual",
}
}
pub fn from_str_lossy(s: &str) -> Option<Self> {
match s {
"concise" => Some(CommStyle::Concise),
"detailed" => Some(CommStyle::Detailed),
"formal" => Some(CommStyle::Formal),
"casual" => Some(CommStyle::Casual),
_ => None,
}
}
}
/// Structured user profile (one record per user).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
pub user_id: String,
pub industry: Option<String>,
pub role: Option<String>,
pub expertise_level: Option<Level>,
pub communication_style: Option<CommStyle>,
pub preferred_language: String,
pub recent_topics: Vec<String>,
pub active_pain_points: Vec<String>,
pub preferred_tools: Vec<String>,
pub confidence: f32,
pub updated_at: DateTime<Utc>,
}
impl UserProfile {
/// Create a blank profile for the given user.
pub fn blank(user_id: &str) -> Self {
Self {
user_id: user_id.to_string(),
industry: None,
role: None,
expertise_level: None,
communication_style: None,
preferred_language: "zh-CN".to_string(),
recent_topics: Vec::new(),
active_pain_points: Vec::new(),
preferred_tools: Vec::new(),
confidence: 0.0,
updated_at: Utc::now(),
}
}
/// Default profile for single-user desktop mode ("default_user").
pub fn default_profile() -> Self {
Self::blank("default_user")
}
}
// ---------------------------------------------------------------------------
// DDL
// ---------------------------------------------------------------------------
const PROFILE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS user_profiles (
user_id TEXT PRIMARY KEY,
industry TEXT,
role TEXT,
expertise_level TEXT,
communication_style TEXT,
preferred_language TEXT DEFAULT 'zh-CN',
recent_topics TEXT DEFAULT '[]',
active_pain_points TEXT DEFAULT '[]',
preferred_tools TEXT DEFAULT '[]',
confidence REAL DEFAULT 0.0,
updated_at TEXT NOT NULL
)
"#;
// ---------------------------------------------------------------------------
// Row mapping
// ---------------------------------------------------------------------------
fn row_to_profile(row: &sqlx::sqlite::SqliteRow) -> Result<UserProfile> {
let recent_topics_json: String = row.try_get("recent_topics").unwrap_or_else(|_| "[]".to_string());
let pain_json: String = row.try_get("active_pain_points").unwrap_or_else(|_| "[]".to_string());
let tools_json: String = row.try_get("preferred_tools").unwrap_or_else(|_| "[]".to_string());
let recent_topics: Vec<String> = serde_json::from_str(&recent_topics_json)?;
let active_pain_points: Vec<String> = serde_json::from_str(&pain_json)?;
let preferred_tools: Vec<String> = serde_json::from_str(&tools_json)?;
let expertise_str: Option<String> = row.try_get("expertise_level").unwrap_or(None);
let comm_str: Option<String> = row.try_get("communication_style").unwrap_or(None);
let updated_at_str: String = row.try_get("updated_at").unwrap_or_else(|_| Utc::now().to_rfc3339());
let updated_at = DateTime::parse_from_rfc3339(&updated_at_str)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
Ok(UserProfile {
user_id: row.try_get("user_id").unwrap_or_default(),
industry: row.try_get("industry").unwrap_or(None),
role: row.try_get("role").unwrap_or(None),
expertise_level: expertise_str.as_deref().and_then(Level::from_str_lossy),
communication_style: comm_str.as_deref().and_then(CommStyle::from_str_lossy),
preferred_language: row.try_get("preferred_language").unwrap_or_else(|_| "zh-CN".to_string()),
recent_topics,
active_pain_points,
preferred_tools,
confidence: row.try_get("confidence").unwrap_or(0.0),
updated_at,
})
}
// ---------------------------------------------------------------------------
// UserProfileStore
// ---------------------------------------------------------------------------
/// SQLite-backed store for user profiles.
pub struct UserProfileStore {
pool: SqlitePool,
}
impl UserProfileStore {
/// Create a new store backed by the given connection pool.
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
/// Create tables. Idempotent — safe to call on every startup.
pub async fn initialize_schema(&self) -> Result<()> {
sqlx::query(PROFILE_DDL)
.execute(&self.pool)
.await
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Fetch the profile for a user. Returns `None` when no row exists.
pub async fn get(&self, user_id: &str) -> Result<Option<UserProfile>> {
let row = sqlx::query(
"SELECT user_id, industry, role, expertise_level, communication_style, \
preferred_language, recent_topics, active_pain_points, preferred_tools, \
confidence, updated_at \
FROM user_profiles WHERE user_id = ?",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
match row {
Some(r) => Ok(Some(row_to_profile(&r)?)),
None => Ok(None),
}
}
/// Insert or replace the full profile.
pub async fn upsert(&self, profile: &UserProfile) -> Result<()> {
let topics = serde_json::to_string(&profile.recent_topics)?;
let pains = serde_json::to_string(&profile.active_pain_points)?;
let tools = serde_json::to_string(&profile.preferred_tools)?;
let expertise = profile.expertise_level.map(|l| l.as_str());
let comm = profile.communication_style.map(|c| c.as_str());
let updated = profile.updated_at.to_rfc3339();
sqlx::query(
"INSERT OR REPLACE INTO user_profiles \
(user_id, industry, role, expertise_level, communication_style, \
preferred_language, recent_topics, active_pain_points, preferred_tools, \
confidence, updated_at) \
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
)
.bind(&profile.user_id)
.bind(&profile.industry)
.bind(&profile.role)
.bind(expertise)
.bind(comm)
.bind(&profile.preferred_language)
.bind(&topics)
.bind(&pains)
.bind(&tools)
.bind(profile.confidence)
.bind(&updated)
.execute(&self.pool)
.await
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Update a single scalar field by name.
///
/// `field` must be one of: industry, role, expertise_level,
/// communication_style, preferred_language, confidence.
/// Returns error for unrecognised field names (prevents SQL injection).
pub async fn update_field(&self, user_id: &str, field: &str, value: &str) -> Result<()> {
let sql = match field {
"industry" => "UPDATE user_profiles SET industry = ?, updated_at = ? WHERE user_id = ?",
"role" => "UPDATE user_profiles SET role = ?, updated_at = ? WHERE user_id = ?",
"expertise_level" => {
"UPDATE user_profiles SET expertise_level = ?, updated_at = ? WHERE user_id = ?"
}
"communication_style" => {
"UPDATE user_profiles SET communication_style = ?, updated_at = ? WHERE user_id = ?"
}
"preferred_language" => {
"UPDATE user_profiles SET preferred_language = ?, updated_at = ? WHERE user_id = ?"
}
"confidence" => {
"UPDATE user_profiles SET confidence = ?, updated_at = ? WHERE user_id = ?"
}
_ => {
return Err(zclaw_types::ZclawError::InvalidInput(format!(
"Unknown profile field: {}",
field
)));
}
};
let now = Utc::now().to_rfc3339();
// confidence is REAL; parse the value string.
if field == "confidence" {
let f: f32 = value.parse().map_err(|_| {
zclaw_types::ZclawError::InvalidInput(format!("Invalid confidence: {}", value))
})?;
sqlx::query(sql)
.bind(f)
.bind(&now)
.bind(user_id)
.execute(&self.pool)
.await
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
} else {
sqlx::query(sql)
.bind(value)
.bind(&now)
.bind(user_id)
.execute(&self.pool)
.await
.map_err(|e| zclaw_types::ZclawError::StorageError(e.to_string()))?;
}
Ok(())
}
/// Append a topic to `recent_topics`, trimming to `max_topics`.
/// Creates a default profile row if none exists.
pub async fn add_recent_topic(
&self,
user_id: &str,
topic: &str,
max_topics: usize,
) -> Result<()> {
let mut profile = self
.get(user_id)
.await?
.unwrap_or_else(|| UserProfile::blank(user_id));
// Deduplicate: remove if already present, then push to front.
profile.recent_topics.retain(|t| t != topic);
profile.recent_topics.insert(0, topic.to_string());
profile.recent_topics.truncate(max_topics);
profile.updated_at = Utc::now();
self.upsert(&profile).await
}
/// Append a pain point, trimming to `max_pains`.
/// Creates a default profile row if none exists.
pub async fn add_pain_point(
&self,
user_id: &str,
pain: &str,
max_pains: usize,
) -> Result<()> {
let mut profile = self
.get(user_id)
.await?
.unwrap_or_else(|| UserProfile::blank(user_id));
profile.active_pain_points.retain(|p| p != pain);
profile.active_pain_points.insert(0, pain.to_string());
profile.active_pain_points.truncate(max_pains);
profile.updated_at = Utc::now();
self.upsert(&profile).await
}
/// Append a preferred tool, trimming to `max_tools`.
/// Creates a default profile row if none exists.
pub async fn add_preferred_tool(
&self,
user_id: &str,
tool: &str,
max_tools: usize,
) -> Result<()> {
let mut profile = self
.get(user_id)
.await?
.unwrap_or_else(|| UserProfile::blank(user_id));
profile.preferred_tools.retain(|t| t != tool);
profile.preferred_tools.insert(0, tool.to_string());
profile.preferred_tools.truncate(max_tools);
profile.updated_at = Utc::now();
self.upsert(&profile).await
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
/// Helper: create an in-memory store with schema.
async fn test_store() -> UserProfileStore {
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("in-memory pool");
let store = UserProfileStore::new(pool);
store.initialize_schema().await.expect("schema init");
store
}
#[tokio::test]
async fn test_initialize_schema_idempotent() {
let store = test_store().await;
// Second call should succeed without error.
store.initialize_schema().await.unwrap();
store.initialize_schema().await.unwrap();
}
#[tokio::test]
async fn test_get_returns_none_for_missing() {
let store = test_store().await;
let profile = store.get("nonexistent").await.unwrap();
assert!(profile.is_none());
}
#[tokio::test]
async fn test_upsert_and_get() {
let store = test_store().await;
let mut profile = UserProfile::blank("default_user");
profile.industry = Some("healthcare".to_string());
profile.role = Some("admin".to_string());
profile.expertise_level = Some(Level::Intermediate);
profile.communication_style = Some(CommStyle::Concise);
profile.recent_topics = vec!["reporting".to_string(), "compliance".to_string()];
profile.confidence = 0.65;
store.upsert(&profile).await.unwrap();
let loaded = store.get("default_user").await.unwrap().unwrap();
assert_eq!(loaded.user_id, "default_user");
assert_eq!(loaded.industry.as_deref(), Some("healthcare"));
assert_eq!(loaded.role.as_deref(), Some("admin"));
assert_eq!(loaded.expertise_level, Some(Level::Intermediate));
assert_eq!(loaded.communication_style, Some(CommStyle::Concise));
assert_eq!(loaded.recent_topics, vec!["reporting", "compliance"]);
assert!((loaded.confidence - 0.65).abs() < f32::EPSILON);
}
#[tokio::test]
async fn test_upsert_replaces_existing() {
let store = test_store().await;
let mut profile = UserProfile::blank("user1");
profile.industry = Some("tech".to_string());
store.upsert(&profile).await.unwrap();
profile.industry = Some("finance".to_string());
store.upsert(&profile).await.unwrap();
let loaded = store.get("user1").await.unwrap().unwrap();
assert_eq!(loaded.industry.as_deref(), Some("finance"));
}
#[tokio::test]
async fn test_update_field_scalar() {
let store = test_store().await;
let profile = UserProfile::blank("user2");
store.upsert(&profile).await.unwrap();
store
.update_field("user2", "industry", "education")
.await
.unwrap();
store
.update_field("user2", "role", "teacher")
.await
.unwrap();
let loaded = store.get("user2").await.unwrap().unwrap();
assert_eq!(loaded.industry.as_deref(), Some("education"));
assert_eq!(loaded.role.as_deref(), Some("teacher"));
}
#[tokio::test]
async fn test_update_field_confidence() {
let store = test_store().await;
let profile = UserProfile::blank("user3");
store.upsert(&profile).await.unwrap();
store
.update_field("user3", "confidence", "0.88")
.await
.unwrap();
let loaded = store.get("user3").await.unwrap().unwrap();
assert!((loaded.confidence - 0.88).abs() < f32::EPSILON);
}
#[tokio::test]
async fn test_update_field_rejects_unknown() {
let store = test_store().await;
let result = store.update_field("user", "evil_column", "oops").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_add_recent_topic_auto_creates_profile() {
let store = test_store().await;
// No profile exists yet.
store
.add_recent_topic("new_user", "data analysis", 5)
.await
.unwrap();
let loaded = store.get("new_user").await.unwrap().unwrap();
assert_eq!(loaded.recent_topics, vec!["data analysis"]);
}
#[tokio::test]
async fn test_add_recent_topic_dedup_and_trim() {
let store = test_store().await;
let profile = UserProfile::blank("user");
store.upsert(&profile).await.unwrap();
store.add_recent_topic("user", "topic_a", 3).await.unwrap();
store.add_recent_topic("user", "topic_b", 3).await.unwrap();
store.add_recent_topic("user", "topic_c", 3).await.unwrap();
// Duplicate — should move to front, not add.
store.add_recent_topic("user", "topic_a", 3).await.unwrap();
let loaded = store.get("user").await.unwrap().unwrap();
assert_eq!(
loaded.recent_topics,
vec!["topic_a", "topic_c", "topic_b"]
);
}
#[tokio::test]
async fn test_add_pain_point_trim() {
let store = test_store().await;
for i in 0..5 {
store
.add_pain_point("user", &format!("pain_{}", i), 3)
.await
.unwrap();
}
let loaded = store.get("user").await.unwrap().unwrap();
assert_eq!(loaded.active_pain_points.len(), 3);
// Most recent first.
assert_eq!(loaded.active_pain_points[0], "pain_4");
}
#[tokio::test]
async fn test_add_preferred_tool_trim() {
let store = test_store().await;
store
.add_preferred_tool("user", "python", 5)
.await
.unwrap();
store
.add_preferred_tool("user", "rust", 5)
.await
.unwrap();
// Duplicate — moved to front.
store
.add_preferred_tool("user", "python", 5)
.await
.unwrap();
let loaded = store.get("user").await.unwrap().unwrap();
assert_eq!(loaded.preferred_tools, vec!["python", "rust"]);
}
#[test]
fn test_level_round_trip() {
for level in [Level::Beginner, Level::Intermediate, Level::Expert] {
assert_eq!(Level::from_str_lossy(level.as_str()), Some(level));
}
assert_eq!(Level::from_str_lossy("unknown"), None);
}
#[test]
fn test_comm_style_round_trip() {
for style in [
CommStyle::Concise,
CommStyle::Detailed,
CommStyle::Formal,
CommStyle::Casual,
] {
assert_eq!(CommStyle::from_str_lossy(style.as_str()), Some(style));
}
assert_eq!(CommStyle::from_str_lossy("unknown"), None);
}
#[test]
fn test_profile_serialization() {
let mut p = UserProfile::blank("test_user");
p.industry = Some("logistics".into());
p.expertise_level = Some(Level::Expert);
p.communication_style = Some(CommStyle::Detailed);
p.recent_topics = vec!["exports".into(), "customs".into()];
let json = serde_json::to_string(&p).unwrap();
let decoded: UserProfile = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.user_id, "test_user");
assert_eq!(decoded.industry.as_deref(), Some("logistics"));
assert_eq!(decoded.expertise_level, Some(Level::Expert));
assert_eq!(decoded.communication_style, Some(CommStyle::Detailed));
assert_eq!(decoded.recent_topics, vec!["exports", "customs"]);
}
}