FIX-01: TrajectoryRecorderMiddleware 注册到 create_middleware_chain() (@650优先级) FIX-02: industryStore 接入 ButlerPanel 行业专长展示 + 自动拉取 FIX-03: 桌面端知识库搜索 saas-knowledge mixin + VikingPanel SaaS KB UI FIX-04: webhook 迁移标注 deprecated + 添加 down migration 注释 FIX-05: Admin Knowledge 添加结构化数据 Tab (CRUD + 行浏览) FIX-06: PersistentMemoryStore 精化 dead_code 标注 (完整迁移留后续) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1029 lines
35 KiB
Rust
1029 lines
35 KiB
Rust
//! Memory store implementation
|
|
|
|
use sqlx::SqlitePool;
|
|
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError, HandRun, HandRunId, HandRunStatus, HandRunFilter};
|
|
|
|
/// Memory store for persisting ZCLAW data
|
|
pub struct MemoryStore {
|
|
pool: SqlitePool,
|
|
}
|
|
|
|
impl MemoryStore {
|
|
/// Create a new memory store with the given database path
|
|
pub async fn new(database_url: &str) -> Result<Self> {
|
|
// Ensure parent directory exists for file-based SQLite databases
|
|
Self::ensure_database_dir(database_url)?;
|
|
|
|
let pool = SqlitePool::connect(database_url).await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
let store = Self { pool };
|
|
store.run_migrations().await?;
|
|
Ok(store)
|
|
}
|
|
|
|
/// Get a clone of the underlying SQLite pool.
|
|
///
|
|
/// Used by subsystems (e.g. `TrajectoryStore`) that need to share the
|
|
/// same database connection pool for their own tables.
|
|
pub fn pool(&self) -> SqlitePool {
|
|
self.pool.clone()
|
|
}
|
|
|
|
/// Ensure the parent directory for the database file exists
|
|
fn ensure_database_dir(database_url: &str) -> Result<()> {
|
|
// Parse SQLite URL to extract file path
|
|
// Format: sqlite:/path/to/db or sqlite://path/to/db
|
|
if database_url.starts_with("sqlite:") {
|
|
let path_part = database_url.strip_prefix("sqlite:")
|
|
.ok_or_else(|| ZclawError::StorageError(
|
|
format!("Invalid database URL format: {}", database_url)
|
|
))?;
|
|
|
|
// Skip in-memory databases
|
|
if path_part == ":memory:" {
|
|
return Ok(());
|
|
}
|
|
|
|
// Remove query parameters (e.g., ?mode=rwc)
|
|
let path_without_query = path_part.split('?').next()
|
|
.ok_or_else(|| ZclawError::StorageError(
|
|
format!("Invalid database URL path: {}", path_part)
|
|
))?;
|
|
|
|
// Handle both absolute and relative paths
|
|
let path = std::path::Path::new(path_without_query);
|
|
|
|
// Get parent directory
|
|
if let Some(parent) = path.parent() {
|
|
if !parent.exists() {
|
|
std::fs::create_dir_all(parent)
|
|
.map_err(|e| ZclawError::StorageError(
|
|
format!("Failed to create database directory {}: {}", parent.display(), e)
|
|
))?;
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Create an in-memory database (for testing)
|
|
pub async fn in_memory() -> Result<Self> {
|
|
Self::new("sqlite::memory:").await
|
|
}
|
|
|
|
/// Run database migrations
|
|
async fn run_migrations(&self) -> Result<()> {
|
|
sqlx::query(crate::schema::CREATE_SCHEMA)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
// Run incremental migrations (ALTER … ADD COLUMN is idempotent with error suppression)
|
|
for migration in crate::schema::MIGRATIONS {
|
|
if let Err(e) = sqlx::query(migration)
|
|
.execute(&self.pool)
|
|
.await
|
|
{
|
|
// Column already exists — expected on repeated runs
|
|
tracing::debug!("[MemoryStore] Migration skipped (already applied): {}", e);
|
|
}
|
|
}
|
|
|
|
// Persist current schema version
|
|
let version = crate::schema::SCHEMA_VERSION;
|
|
sqlx::query("INSERT OR REPLACE INTO schema_version (version) VALUES (?)")
|
|
.bind(version)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// === Agent CRUD ===
|
|
|
|
/// Save an agent configuration
|
|
pub async fn save_agent(&self, agent: &AgentConfig) -> Result<()> {
|
|
let config_json = serde_json::to_string(agent)?;
|
|
let id = agent.id.to_string();
|
|
let name = &agent.name;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO agents (id, name, config, created_at, updated_at)
|
|
VALUES (?, ?, ?, datetime('now'), datetime('now'))
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
name = excluded.name,
|
|
config = excluded.config,
|
|
updated_at = datetime('now')
|
|
"#,
|
|
)
|
|
.bind(&id)
|
|
.bind(name)
|
|
.bind(&config_json)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Load an agent by ID
|
|
pub async fn load_agent(&self, id: &AgentId) -> Result<Option<AgentConfig>> {
|
|
let id_str = id.to_string();
|
|
|
|
let row = sqlx::query_as::<_, (String,)>(
|
|
"SELECT config FROM agents WHERE id = ?"
|
|
)
|
|
.bind(&id_str)
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
match row {
|
|
Some((config,)) => {
|
|
let agent: AgentConfig = serde_json::from_str(&config)?;
|
|
Ok(Some(agent))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
/// List all agents
|
|
pub async fn list_agents(&self) -> Result<Vec<AgentConfig>> {
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT config FROM agents"
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let agents = rows
|
|
.into_iter()
|
|
.filter_map(|(config,)| serde_json::from_str(&config).ok())
|
|
.collect();
|
|
Ok(agents)
|
|
}
|
|
|
|
/// Delete an agent
|
|
pub async fn delete_agent(&self, id: &AgentId) -> Result<()> {
|
|
let id_str = id.to_string();
|
|
|
|
sqlx::query("DELETE FROM agents WHERE id = ?")
|
|
.bind(&id_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Persist runtime state and message count for an agent
|
|
pub async fn update_agent_runtime(
|
|
&self,
|
|
id: &AgentId,
|
|
state: &str,
|
|
message_count: u64,
|
|
) -> Result<()> {
|
|
let id_str = id.to_string();
|
|
sqlx::query(
|
|
"UPDATE agents SET state = ?, message_count = ?, updated_at = datetime('now') WHERE id = ?",
|
|
)
|
|
.bind(state)
|
|
.bind(message_count as i64)
|
|
.bind(&id_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
/// List all agents with their persisted runtime state
|
|
/// Returns (AgentConfig, state_string, message_count)
|
|
pub async fn list_agents_with_runtime(&self) -> Result<Vec<(AgentConfig, String, u64)>> {
|
|
let rows = sqlx::query_as::<_, (String, String, i64)>(
|
|
"SELECT config, state, message_count FROM agents",
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let agents = rows
|
|
.into_iter()
|
|
.filter_map(|(config, state, mc)| {
|
|
let agent: AgentConfig = serde_json::from_str(&config).ok()?;
|
|
Some((agent, state, mc as u64))
|
|
})
|
|
.collect();
|
|
Ok(agents)
|
|
}
|
|
|
|
// === Session Management ===
|
|
|
|
/// Create a new session for an agent
|
|
pub async fn create_session(&self, agent_id: &AgentId) -> Result<SessionId> {
|
|
let session_id = SessionId::new();
|
|
let session_str = session_id.to_string();
|
|
let agent_str = agent_id.to_string();
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO sessions (id, agent_id, created_at, updated_at)
|
|
VALUES (?, ?, datetime('now'), datetime('now'))
|
|
"#,
|
|
)
|
|
.bind(&session_str)
|
|
.bind(&agent_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(session_id)
|
|
}
|
|
|
|
/// Get an existing session or create it with the given ID.
|
|
///
|
|
/// This is the critical bridge between frontend session IDs and the database.
|
|
/// The frontend generates a UUID (`sessionKey`) and sends it with each message.
|
|
/// Without this method, the kernel would create a *different* session ID on
|
|
/// every call, so conversation history would never be found.
|
|
pub async fn get_or_create_session(
|
|
&self,
|
|
session_id: &SessionId,
|
|
agent_id: &AgentId,
|
|
) -> Result<SessionId> {
|
|
let session_str = session_id.to_string();
|
|
let agent_str = agent_id.to_string();
|
|
|
|
// Check if session already exists
|
|
let exists: bool = sqlx::query_scalar(
|
|
"SELECT COUNT(*) > 0 FROM sessions WHERE id = ?",
|
|
)
|
|
.bind(&session_str)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.unwrap_or(false);
|
|
|
|
if exists {
|
|
return Ok(session_id.clone());
|
|
}
|
|
|
|
// Create session with the frontend-provided ID
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO sessions (id, agent_id, created_at, updated_at)
|
|
VALUES (?, ?, datetime('now'), datetime('now'))
|
|
"#,
|
|
)
|
|
.bind(&session_str)
|
|
.bind(&agent_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(session_id.clone())
|
|
}
|
|
|
|
/// Append a message to a session
|
|
pub async fn append_message(&self, session_id: &SessionId, message: &Message) -> Result<()> {
|
|
let session_str = session_id.to_string();
|
|
let message_json = serde_json::to_string(message)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO messages (session_id, seq, content, created_at)
|
|
SELECT ?, COALESCE(MAX(seq), 0) + 1, ?, datetime('now')
|
|
FROM messages WHERE session_id = ?
|
|
"#,
|
|
)
|
|
.bind(&session_str)
|
|
.bind(&message_json)
|
|
.bind(&session_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
// Update session updated_at
|
|
sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?")
|
|
.bind(&session_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get all messages for a session
|
|
pub async fn get_messages(&self, session_id: &SessionId) -> Result<Vec<Message>> {
|
|
let session_str = session_id.to_string();
|
|
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT content FROM messages WHERE session_id = ? ORDER BY seq"
|
|
)
|
|
.bind(&session_str)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.filter_map(|(content,)| serde_json::from_str(&content).ok())
|
|
.collect();
|
|
Ok(messages)
|
|
}
|
|
|
|
/// Get messages for a session with pagination
|
|
pub async fn get_messages_paginated(
|
|
&self,
|
|
session_id: &SessionId,
|
|
limit: u32,
|
|
offset: u32,
|
|
) -> Result<Vec<Message>> {
|
|
let session_str = session_id.to_string();
|
|
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT content FROM messages WHERE session_id = ? ORDER BY seq LIMIT ? OFFSET ?"
|
|
)
|
|
.bind(&session_str)
|
|
.bind(limit)
|
|
.bind(offset)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.filter_map(|(content,)| serde_json::from_str(&content).ok())
|
|
.collect();
|
|
Ok(messages)
|
|
}
|
|
|
|
/// Count messages in a session
|
|
pub async fn count_messages(&self, session_id: &SessionId) -> Result<u64> {
|
|
let session_str = session_id.to_string();
|
|
let count: i64 = sqlx::query_scalar(
|
|
"SELECT COUNT(*) FROM messages WHERE session_id = ?"
|
|
)
|
|
.bind(&session_str)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(count as u64)
|
|
}
|
|
|
|
// === KV Store ===
|
|
|
|
/// Store a key-value pair for an agent
|
|
pub async fn kv_store(&self, agent_id: &AgentId, key: &str, value: &serde_json::Value) -> Result<()> {
|
|
let agent_str = agent_id.to_string();
|
|
let value_json = serde_json::to_string(value)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO kv_store (agent_id, key, value, updated_at)
|
|
VALUES (?, ?, ?, datetime('now'))
|
|
ON CONFLICT(agent_id, key) DO UPDATE SET
|
|
value = excluded.value,
|
|
updated_at = datetime('now')
|
|
"#,
|
|
)
|
|
.bind(&agent_str)
|
|
.bind(key)
|
|
.bind(&value_json)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Recall a value by key
|
|
pub async fn kv_recall(&self, agent_id: &AgentId, key: &str) -> Result<Option<serde_json::Value>> {
|
|
let agent_str = agent_id.to_string();
|
|
|
|
let row = sqlx::query_as::<_, (String,)>(
|
|
"SELECT value FROM kv_store WHERE agent_id = ? AND key = ?"
|
|
)
|
|
.bind(&agent_str)
|
|
.bind(key)
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
match row {
|
|
Some((value,)) => {
|
|
let v: serde_json::Value = serde_json::from_str(&value)?;
|
|
Ok(Some(v))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
/// List all keys for an agent
|
|
pub async fn kv_list(&self, agent_id: &AgentId) -> Result<Vec<String>> {
|
|
let agent_str = agent_id.to_string();
|
|
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT key FROM kv_store WHERE agent_id = ?"
|
|
)
|
|
.bind(&agent_str)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(rows.into_iter().map(|(key,)| key).collect())
|
|
}
|
|
|
|
// === Hand Run Tracking ===
|
|
|
|
/// Save a new hand run record
|
|
pub async fn save_hand_run(&self, run: &HandRun) -> Result<()> {
|
|
let id = run.id.to_string();
|
|
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
|
let params = serde_json::to_string(&run.params)?;
|
|
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
|
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO hand_runs (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
"#,
|
|
)
|
|
.bind(&id)
|
|
.bind(&run.hand_name)
|
|
.bind(&trigger_source)
|
|
.bind(¶ms)
|
|
.bind(run.status.to_string())
|
|
.bind(result.as_deref())
|
|
.bind(error.as_deref())
|
|
.bind(run.duration_ms.map(|d| d as i64))
|
|
.bind(&run.created_at)
|
|
.bind(run.started_at.as_deref())
|
|
.bind(run.completed_at.as_deref())
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Update an existing hand run record
|
|
pub async fn update_hand_run(&self, run: &HandRun) -> Result<()> {
|
|
let id = run.id.to_string();
|
|
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
|
let params = serde_json::to_string(&run.params)?;
|
|
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
|
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE hand_runs SET
|
|
hand_name = ?, trigger_source = ?, params = ?, status = ?,
|
|
result = ?, error = ?, duration_ms = ?,
|
|
started_at = ?, completed_at = ?
|
|
WHERE id = ?
|
|
"#,
|
|
)
|
|
.bind(&run.hand_name)
|
|
.bind(&trigger_source)
|
|
.bind(¶ms)
|
|
.bind(run.status.to_string())
|
|
.bind(result.as_deref())
|
|
.bind(error.as_deref())
|
|
.bind(run.duration_ms.map(|d| d as i64))
|
|
.bind(run.started_at.as_deref())
|
|
.bind(run.completed_at.as_deref())
|
|
.bind(&id)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get a hand run by ID
|
|
pub async fn get_hand_run(&self, id: &HandRunId) -> Result<Option<HandRun>> {
|
|
let id_str = id.to_string();
|
|
|
|
let row = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(
|
|
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE id = ?"
|
|
)
|
|
.bind(&id_str)
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
match row {
|
|
Some(r) => Ok(Some(Self::row_to_hand_run(r)?)),
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
/// List hand runs with optional filter
|
|
pub async fn list_hand_runs(&self, filter: &HandRunFilter) -> Result<Vec<HandRun>> {
|
|
let mut query = String::from(
|
|
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE 1=1"
|
|
);
|
|
let mut bind_values: Vec<String> = Vec::new();
|
|
|
|
if let Some(ref hand_name) = filter.hand_name {
|
|
query.push_str(" AND hand_name = ?");
|
|
bind_values.push(hand_name.clone());
|
|
}
|
|
|
|
if let Some(ref status) = filter.status {
|
|
query.push_str(" AND status = ?");
|
|
bind_values.push(status.to_string());
|
|
}
|
|
|
|
query.push_str(" ORDER BY created_at DESC");
|
|
|
|
if let Some(limit) = filter.limit {
|
|
query.push_str(&format!(" LIMIT {}", limit));
|
|
}
|
|
if let Some(offset) = filter.offset {
|
|
query.push_str(&format!(" OFFSET {}", offset));
|
|
}
|
|
|
|
let mut sql_query = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(&query);
|
|
|
|
for val in &bind_values {
|
|
sql_query = sql_query.bind(val);
|
|
}
|
|
|
|
let rows = sql_query
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
rows.into_iter()
|
|
.map(|r| Self::row_to_hand_run(r))
|
|
.collect()
|
|
}
|
|
|
|
/// Count hand runs matching filter
|
|
pub async fn count_hand_runs(&self, filter: &HandRunFilter) -> Result<u32> {
|
|
let mut query = String::from("SELECT COUNT(*) FROM hand_runs WHERE 1=1");
|
|
let mut bind_values: Vec<String> = Vec::new();
|
|
|
|
if let Some(ref hand_name) = filter.hand_name {
|
|
query.push_str(" AND hand_name = ?");
|
|
bind_values.push(hand_name.clone());
|
|
}
|
|
if let Some(ref status) = filter.status {
|
|
query.push_str(" AND status = ?");
|
|
bind_values.push(status.to_string());
|
|
}
|
|
|
|
let mut sql_query = sqlx::query_scalar::<_, i64>(&query);
|
|
for val in &bind_values {
|
|
sql_query = sql_query.bind(val);
|
|
}
|
|
|
|
let count = sql_query
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(count as u32)
|
|
}
|
|
|
|
// === Fact CRUD ===
|
|
|
|
/// Store extracted facts for an agent (upsert by id).
|
|
pub async fn store_facts(&self, agent_id: &str, facts: &[crate::fact::Fact]) -> Result<()> {
|
|
for fact in facts {
|
|
let category_str = serde_json::to_string(&fact.category)
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
// Trim the JSON quotes from serialized enum variant
|
|
let category_clean = category_str.trim_matches('"');
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO facts (id, agent_id, content, category, confidence, source_session, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
content = excluded.content,
|
|
category = excluded.category,
|
|
confidence = excluded.confidence,
|
|
source_session = excluded.source_session
|
|
"#,
|
|
)
|
|
.bind(&fact.id)
|
|
.bind(agent_id)
|
|
.bind(&fact.content)
|
|
.bind(category_clean)
|
|
.bind(fact.confidence)
|
|
.bind(&fact.source)
|
|
.bind(fact.created_at as i64)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Get top facts for an agent, ordered by confidence descending.
|
|
pub async fn get_top_facts(&self, agent_id: &str, limit: usize) -> Result<Vec<crate::fact::Fact>> {
|
|
let rows = sqlx::query_as::<_, (String, String, String, f64, Option<String>, i64)>(
|
|
r#"
|
|
SELECT id, content, category, confidence, source_session, created_at
|
|
FROM facts
|
|
WHERE agent_id = ?
|
|
ORDER BY confidence DESC
|
|
LIMIT ?
|
|
"#,
|
|
)
|
|
.bind(agent_id)
|
|
.bind(limit as i64)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let mut facts = Vec::with_capacity(rows.len());
|
|
for (id, content, category_str, confidence, source, created_at) in rows {
|
|
let category: crate::fact::FactCategory = serde_json::from_value(
|
|
serde_json::Value::String(category_str)
|
|
).map_err(|e| ZclawError::StorageError(format!("Invalid category: {}", e)))?;
|
|
|
|
facts.push(crate::fact::Fact {
|
|
id,
|
|
content,
|
|
category,
|
|
confidence,
|
|
created_at: created_at as u64,
|
|
source,
|
|
});
|
|
}
|
|
Ok(facts)
|
|
}
|
|
|
|
fn row_to_hand_run(
|
|
row: (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>),
|
|
) -> Result<HandRun> {
|
|
let (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at) = row;
|
|
|
|
let run_id: HandRunId = id.parse()
|
|
.map_err(|e| ZclawError::StorageError(format!("Invalid HandRunId: {}", e)))?;
|
|
let trigger: zclaw_types::TriggerSource = serde_json::from_str(&trigger_source)?;
|
|
let params_val: serde_json::Value = serde_json::from_str(¶ms)?;
|
|
let run_status: HandRunStatus = status.parse()
|
|
.map_err(|e| ZclawError::StorageError(e))?;
|
|
let result_val: Option<serde_json::Value> = result.map(|r| serde_json::from_str(&r)).transpose()?;
|
|
let error_val: Option<String> = error.as_ref()
|
|
.map(|e| serde_json::from_str::<String>(e))
|
|
.transpose()
|
|
.unwrap_or_else(|_| error.clone());
|
|
|
|
Ok(HandRun {
|
|
id: run_id,
|
|
hand_name,
|
|
trigger_source: trigger,
|
|
params: params_val,
|
|
status: run_status,
|
|
result: result_val,
|
|
error: error_val,
|
|
duration_ms: duration_ms.map(|d| d as u64),
|
|
created_at,
|
|
started_at,
|
|
completed_at,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use zclaw_types::{AgentConfig, ModelConfig};

    /// Build an enabled agent config with a fresh ID and no optional settings.
    fn create_test_agent_config(name: &str) -> AgentConfig {
        AgentConfig {
            id: AgentId::new(),
            name: name.to_string(),
            description: None,
            model: ModelConfig::default(),
            system_prompt: None,
            soul: None,
            capabilities: vec![],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            workspace: None,
            compaction_threshold: None,
            enabled: true,
        }
    }

    /// Fresh in-memory store holding one persisted agent named `name`.
    async fn store_with_agent(name: &str) -> (MemoryStore, AgentConfig) {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config(name);
        store.save_agent(&config).await.unwrap();
        (store, config)
    }

    /// Same as `store_with_agent`, plus a newly created session for that agent.
    async fn store_with_session(name: &str) -> (MemoryStore, SessionId) {
        let (store, config) = store_with_agent(name).await;
        let session_id = store.create_session(&config.id).await.unwrap();
        (store, session_id)
    }

    #[tokio::test]
    async fn test_in_memory_store_creation() {
        assert!(MemoryStore::in_memory().await.is_ok());
    }

    #[tokio::test]
    async fn test_save_and_load_agent() {
        let (store, config) = store_with_agent("test-agent").await;

        let loaded = store
            .load_agent(&config.id)
            .await
            .unwrap()
            .expect("a saved agent must be loadable");
        assert_eq!(loaded.id, config.id);
        assert_eq!(loaded.name, config.name);
    }

    #[tokio::test]
    async fn test_load_nonexistent_agent() {
        let store = MemoryStore::in_memory().await.unwrap();

        let missing = store.load_agent(&AgentId::new()).await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_save_agent_updates_existing() {
        let (store, mut config) = store_with_agent("original").await;

        config.name = "updated".to_string();
        store.save_agent(&config).await.unwrap();

        let reloaded = store.load_agent(&config.id).await.unwrap().unwrap();
        assert_eq!(reloaded.name, "updated");
    }

    #[tokio::test]
    async fn test_list_agents() {
        let store = MemoryStore::in_memory().await.unwrap();
        for name in ["agent1", "agent2"] {
            store.save_agent(&create_test_agent_config(name)).await.unwrap();
        }

        assert_eq!(store.list_agents().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_delete_agent() {
        let (store, config) = store_with_agent("to-delete").await;

        store.delete_agent(&config.id).await.unwrap();

        assert!(store.load_agent(&config.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_agent_succeeds() {
        let store = MemoryStore::in_memory().await.unwrap();

        // Deletion is idempotent: a missing agent is not an error.
        assert!(store.delete_agent(&AgentId::new()).await.is_ok());
    }

    #[tokio::test]
    async fn test_create_session() {
        let (_store, session_id) = store_with_session("session-test").await;
        assert!(!session_id.as_uuid().is_nil());
    }

    #[tokio::test]
    async fn test_append_and_get_messages() {
        let (store, session_id) = store_with_session("msg-test").await;

        for msg in [Message::user("Hello"), Message::assistant("Hi there!")] {
            store.append_message(&session_id, &msg).await.unwrap();
        }

        assert_eq!(store.get_messages(&session_id).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_message_ordering() {
        let (store, session_id) = store_with_session("order-test").await;

        for i in 0..10 {
            let msg = Message::user(format!("Message {}", i));
            store.append_message(&session_id, &msg).await.unwrap();
        }

        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 10);

        // Messages must come back in insertion order.
        for (i, msg) in messages.iter().enumerate() {
            if let Message::User { content } = msg {
                assert_eq!(content, &format!("Message {}", i));
            }
        }
    }

    #[tokio::test]
    async fn test_kv_store_and_recall() {
        let (store, config) = store_with_agent("kv-test").await;
        let value = serde_json::json!({"key": "value", "number": 42});

        store.kv_store(&config.id, "test-key", &value).await.unwrap();

        let recalled = store.kv_recall(&config.id, "test-key").await.unwrap();
        assert!(recalled.is_some());
        assert_eq!(recalled.unwrap(), value);
    }

    #[tokio::test]
    async fn test_kv_recall_nonexistent() {
        let (store, config) = store_with_agent("kv-missing").await;

        let missing = store.kv_recall(&config.id, "nonexistent").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_kv_update_existing() {
        let (store, config) = store_with_agent("kv-update").await;

        for version in [1, 2] {
            let value = serde_json::json!({ "version": version });
            store.kv_store(&config.id, "key", &value).await.unwrap();
        }

        let recalled = store.kv_recall(&config.id, "key").await.unwrap().unwrap();
        assert_eq!(recalled["version"], 2);
    }

    #[tokio::test]
    async fn test_kv_list() {
        let (store, config) = store_with_agent("kv-list").await;

        for (key, n) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            store.kv_store(&config.id, key, &serde_json::json!(n)).await.unwrap();
        }

        let keys = store.kv_list(&config.id).await.unwrap();
        assert_eq!(keys.len(), 3);
        for expected in ["key1", "key2", "key3"] {
            assert!(keys.contains(&expected.to_string()));
        }
    }

    // === Edge Case Tests ===

    #[tokio::test]
    async fn test_agent_with_empty_name() {
        let store = MemoryStore::in_memory().await.unwrap();

        // An empty name is accepted here; validation happens elsewhere.
        let outcome = store.save_agent(&create_test_agent_config("")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_agent_with_special_characters_in_name() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("agent-with-特殊字符-🎉");

        assert!(store.save_agent(&config).await.is_ok());

        let reloaded = store.load_agent(&config.id).await.unwrap().unwrap();
        assert_eq!(reloaded.name, "agent-with-特殊字符-🎉");
    }

    #[tokio::test]
    async fn test_large_message_content() {
        let (store, session_id) = store_with_session("large-msg").await;

        // Roughly 100 KB of payload.
        let large_content = "x".repeat(100_000);
        let msg = Message::user(&large_content);

        assert!(store.append_message(&session_id, &msg).await.is_ok());
        assert_eq!(store.get_messages(&session_id).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_message_with_tool_use() {
        let (store, session_id) = store_with_session("tool-msg").await;
        let tool_input = serde_json::json!({"query": "test", "options": {"limit": 10}});
        let msg = Message::tool_use("call-123", zclaw_types::ToolId::new("search"), tool_input.clone());

        store.append_message(&session_id, &msg).await.unwrap();

        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::ToolUse { id, tool, input } => {
                assert_eq!(id, "call-123");
                assert_eq!(tool.as_str(), "search");
                assert_eq!(*input, tool_input);
            }
            _ => panic!("Expected ToolUse message"),
        }
    }

    #[tokio::test]
    async fn test_message_with_tool_result() {
        let (store, session_id) = store_with_session("tool-result").await;
        let output = serde_json::json!({"results": ["a", "b", "c"]});
        let msg = Message::tool_result("call-123", zclaw_types::ToolId::new("search"), output.clone(), false);

        store.append_message(&session_id, &msg).await.unwrap();

        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::ToolResult { tool_call_id, tool, output: o, is_error } => {
                assert_eq!(tool_call_id, "call-123");
                assert_eq!(tool.as_str(), "search");
                assert_eq!(*o, output);
                assert!(!is_error);
            }
            _ => panic!("Expected ToolResult message"),
        }
    }

    #[tokio::test]
    async fn test_message_with_thinking() {
        let (store, session_id) = store_with_session("thinking").await;
        let msg = Message::assistant_with_thinking("Final answer", "My reasoning...");

        store.append_message(&session_id, &msg).await.unwrap();

        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Assistant { content, thinking } => {
                assert_eq!(content, "Final answer");
                assert_eq!(thinking.as_ref().unwrap(), "My reasoning...");
            }
            _ => panic!("Expected Assistant message"),
        }
    }
}
|