Files
zclaw_openfang/crates/zclaw-memory/src/store.rs
iven 978dc5cdd8
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(安全): 修复HTML导出中的XSS漏洞并清理调试日志
refactor(日志): 替换console.log为tracing日志系统
style(代码): 移除未使用的代码和依赖项

feat(测试): 添加端到端测试文档和CI工作流
docs(变更日志): 更新CHANGELOG.md记录0.1.0版本变更

perf(构建): 更新依赖版本并优化CI流程
2026-03-26 19:49:03 +08:00

618 lines
20 KiB
Rust

//! Memory store implementation
use sqlx::SqlitePool;
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError};
/// Memory store for persisting ZCLAW data
pub struct MemoryStore {
pool: SqlitePool,
}
impl MemoryStore {
/// Create a new memory store with the given database path
pub async fn new(database_url: &str) -> Result<Self> {
// Ensure parent directory exists for file-based SQLite databases
Self::ensure_database_dir(database_url)?;
let pool = SqlitePool::connect(database_url).await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
let store = Self { pool };
store.run_migrations().await?;
Ok(store)
}
/// Ensure the parent directory for the database file exists
fn ensure_database_dir(database_url: &str) -> Result<()> {
// Parse SQLite URL to extract file path
// Format: sqlite:/path/to/db or sqlite://path/to/db
if database_url.starts_with("sqlite:") {
let path_part = database_url.strip_prefix("sqlite:")
.ok_or_else(|| ZclawError::StorageError(
format!("Invalid database URL format: {}", database_url)
))?;
// Skip in-memory databases
if path_part == ":memory:" {
return Ok(());
}
// Remove query parameters (e.g., ?mode=rwc)
let path_without_query = path_part.split('?').next()
.ok_or_else(|| ZclawError::StorageError(
format!("Invalid database URL path: {}", path_part)
))?;
// Handle both absolute and relative paths
let path = std::path::Path::new(path_without_query);
// Get parent directory
if let Some(parent) = path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent)
.map_err(|e| ZclawError::StorageError(
format!("Failed to create database directory {}: {}", parent.display(), e)
))?;
}
}
}
Ok(())
}
/// Create an in-memory database (for testing)
pub async fn in_memory() -> Result<Self> {
Self::new("sqlite::memory:").await
}
/// Run database migrations
async fn run_migrations(&self) -> Result<()> {
sqlx::query(crate::schema::CREATE_SCHEMA)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
// === Agent CRUD ===
/// Save an agent configuration
pub async fn save_agent(&self, agent: &AgentConfig) -> Result<()> {
let config_json = serde_json::to_string(agent)?;
let id = agent.id.to_string();
let name = &agent.name;
sqlx::query(
r#"
INSERT INTO agents (id, name, config, created_at, updated_at)
VALUES (?, ?, ?, datetime('now'), datetime('now'))
ON CONFLICT(id) DO UPDATE SET
name = excluded.name,
config = excluded.config,
updated_at = datetime('now')
"#,
)
.bind(&id)
.bind(name)
.bind(&config_json)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Load an agent by ID
pub async fn load_agent(&self, id: &AgentId) -> Result<Option<AgentConfig>> {
let id_str = id.to_string();
let row = sqlx::query_as::<_, (String,)>(
"SELECT config FROM agents WHERE id = ?"
)
.bind(&id_str)
.fetch_optional(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
match row {
Some((config,)) => {
let agent: AgentConfig = serde_json::from_str(&config)?;
Ok(Some(agent))
}
None => Ok(None),
}
}
/// List all agents
pub async fn list_agents(&self) -> Result<Vec<AgentConfig>> {
let rows = sqlx::query_as::<_, (String,)>(
"SELECT config FROM agents"
)
.fetch_all(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
let agents = rows
.into_iter()
.filter_map(|(config,)| serde_json::from_str(&config).ok())
.collect();
Ok(agents)
}
/// Delete an agent
pub async fn delete_agent(&self, id: &AgentId) -> Result<()> {
let id_str = id.to_string();
sqlx::query("DELETE FROM agents WHERE id = ?")
.bind(&id_str)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
// === Session Management ===
/// Create a new session for an agent
pub async fn create_session(&self, agent_id: &AgentId) -> Result<SessionId> {
let session_id = SessionId::new();
let session_str = session_id.to_string();
let agent_str = agent_id.to_string();
sqlx::query(
r#"
INSERT INTO sessions (id, agent_id, created_at, updated_at)
VALUES (?, ?, datetime('now'), datetime('now'))
"#,
)
.bind(&session_str)
.bind(&agent_str)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(session_id)
}
/// Append a message to a session
pub async fn append_message(&self, session_id: &SessionId, message: &Message) -> Result<()> {
let session_str = session_id.to_string();
let message_json = serde_json::to_string(message)?;
sqlx::query(
r#"
INSERT INTO messages (session_id, seq, content, created_at)
SELECT ?, COALESCE(MAX(seq), 0) + 1, ?, datetime('now')
FROM messages WHERE session_id = ?
"#,
)
.bind(&session_str)
.bind(&message_json)
.bind(&session_str)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
// Update session updated_at
sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?")
.bind(&session_str)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Get all messages for a session
pub async fn get_messages(&self, session_id: &SessionId) -> Result<Vec<Message>> {
let session_str = session_id.to_string();
let rows = sqlx::query_as::<_, (String,)>(
"SELECT content FROM messages WHERE session_id = ? ORDER BY seq"
)
.bind(&session_str)
.fetch_all(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
let messages = rows
.into_iter()
.filter_map(|(content,)| serde_json::from_str(&content).ok())
.collect();
Ok(messages)
}
// === KV Store ===
/// Store a key-value pair for an agent
pub async fn kv_store(&self, agent_id: &AgentId, key: &str, value: &serde_json::Value) -> Result<()> {
let agent_str = agent_id.to_string();
let value_json = serde_json::to_string(value)?;
sqlx::query(
r#"
INSERT INTO kv_store (agent_id, key, value, updated_at)
VALUES (?, ?, ?, datetime('now'))
ON CONFLICT(agent_id, key) DO UPDATE SET
value = excluded.value,
updated_at = datetime('now')
"#,
)
.bind(&agent_str)
.bind(key)
.bind(&value_json)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(())
}
/// Recall a value by key
pub async fn kv_recall(&self, agent_id: &AgentId, key: &str) -> Result<Option<serde_json::Value>> {
let agent_str = agent_id.to_string();
let row = sqlx::query_as::<_, (String,)>(
"SELECT value FROM kv_store WHERE agent_id = ? AND key = ?"
)
.bind(&agent_str)
.bind(key)
.fetch_optional(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
match row {
Some((value,)) => {
let v: serde_json::Value = serde_json::from_str(&value)?;
Ok(Some(v))
}
None => Ok(None),
}
}
/// List all keys for an agent
pub async fn kv_list(&self, agent_id: &AgentId) -> Result<Vec<String>> {
let agent_str = agent_id.to_string();
let rows = sqlx::query_as::<_, (String,)>(
"SELECT key FROM kv_store WHERE agent_id = ?"
)
.bind(&agent_str)
.fetch_all(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
Ok(rows.into_iter().map(|(key,)| key).collect())
}
}
#[cfg(test)]
mod tests {
    //! Round-trip tests for `MemoryStore`. Almost all tests use a fresh in-memory database;
    //! one exercises a file-backed database under the system temp directory.
    use super::*;
    use zclaw_types::{AgentConfig, ModelConfig};

    /// Build a minimal, enabled agent config with a new `AgentId` and the given name.
    fn create_test_agent_config(name: &str) -> AgentConfig {
        AgentConfig {
            id: AgentId::new(),
            name: name.to_string(),
            description: None,
            model: ModelConfig::default(),
            system_prompt: None,
            capabilities: vec![],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            enabled: true,
        }
    }

    #[tokio::test]
    async fn test_in_memory_store_creation() {
        // `expect` surfaces the underlying error message if opening fails.
        MemoryStore::in_memory()
            .await
            .expect("in-memory store should open and migrate");
    }

    #[tokio::test]
    async fn test_new_creates_missing_parent_directory() {
        let base = std::env::temp_dir().join(format!(
            "zclaw-store-test-{}",
            AgentId::new().to_string()
        ));
        let db_path = base.join("nested").join("store.db");
        let url = format!("sqlite:{}?mode=rwc", db_path.display());

        let store = MemoryStore::new(&url).await.unwrap();
        assert!(db_path.parent().unwrap().is_dir());

        // Close connections before cleanup so the file is not held open (matters on Windows).
        store.pool.close().await;
        let _ = std::fs::remove_dir_all(&base);
    }

    #[tokio::test]
    async fn test_save_and_load_agent() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("test-agent");
        store.save_agent(&config).await.unwrap();
        let loaded = store.load_agent(&config.id).await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, config.id);
        assert_eq!(loaded.name, config.name);
    }

    #[tokio::test]
    async fn test_load_nonexistent_agent() {
        let store = MemoryStore::in_memory().await.unwrap();
        let fake_id = AgentId::new();
        let result = store.load_agent(&fake_id).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_save_agent_updates_existing() {
        let store = MemoryStore::in_memory().await.unwrap();
        let mut config = create_test_agent_config("original");
        store.save_agent(&config).await.unwrap();
        config.name = "updated".to_string();
        store.save_agent(&config).await.unwrap();
        let loaded = store.load_agent(&config.id).await.unwrap().unwrap();
        assert_eq!(loaded.name, "updated");
    }

    #[tokio::test]
    async fn test_list_agents() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config1 = create_test_agent_config("agent1");
        let config2 = create_test_agent_config("agent2");
        store.save_agent(&config1).await.unwrap();
        store.save_agent(&config2).await.unwrap();
        let agents = store.list_agents().await.unwrap();
        assert_eq!(agents.len(), 2);
    }

    #[tokio::test]
    async fn test_delete_agent() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("to-delete");
        store.save_agent(&config).await.unwrap();
        store.delete_agent(&config.id).await.unwrap();
        let loaded = store.load_agent(&config.id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_agent_succeeds() {
        let store = MemoryStore::in_memory().await.unwrap();
        let fake_id = AgentId::new();
        // Deleting a nonexistent agent should succeed (idempotent).
        let result = store.delete_agent(&fake_id).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_create_session() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("session-test");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        assert!(!session_id.as_uuid().is_nil());
    }

    #[tokio::test]
    async fn test_append_and_get_messages() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("msg-test");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        let msg1 = Message::user("Hello");
        let msg2 = Message::assistant("Hi there!");
        store.append_message(&session_id, &msg1).await.unwrap();
        store.append_message(&session_id, &msg2).await.unwrap();
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[tokio::test]
    async fn test_message_ordering() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("order-test");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        for i in 0..10 {
            let msg = Message::user(format!("Message {}", i));
            store.append_message(&session_id, &msg).await.unwrap();
        }
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 10);
        // Every message must come back as a User message in insertion order; a different
        // variant is a failure, not something to skip.
        for (i, msg) in messages.iter().enumerate() {
            match msg {
                Message::User { content } => assert_eq!(content, &format!("Message {}", i)),
                _ => panic!("expected a User message at index {}", i),
            }
        }
    }

    #[tokio::test]
    async fn test_messages_are_isolated_per_session() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("isolation-test");
        store.save_agent(&config).await.unwrap();
        let first = store.create_session(&config.id).await.unwrap();
        let second = store.create_session(&config.id).await.unwrap();

        store.append_message(&first, &Message::user("a1")).await.unwrap();
        store.append_message(&second, &Message::user("b1")).await.unwrap();
        store.append_message(&first, &Message::user("a2")).await.unwrap();

        assert_eq!(store.get_messages(&first).await.unwrap().len(), 2);
        assert_eq!(store.get_messages(&second).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_kv_store_and_recall() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("kv-test");
        store.save_agent(&config).await.unwrap();
        let value = serde_json::json!({"key": "value", "number": 42});
        store.kv_store(&config.id, "test-key", &value).await.unwrap();
        let recalled = store.kv_recall(&config.id, "test-key").await.unwrap();
        assert!(recalled.is_some());
        assert_eq!(recalled.unwrap(), value);
    }

    #[tokio::test]
    async fn test_kv_recall_nonexistent() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("kv-missing");
        store.save_agent(&config).await.unwrap();
        let result = store.kv_recall(&config.id, "nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_kv_update_existing() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("kv-update");
        store.save_agent(&config).await.unwrap();
        let value1 = serde_json::json!({"version": 1});
        let value2 = serde_json::json!({"version": 2});
        store.kv_store(&config.id, "key", &value1).await.unwrap();
        store.kv_store(&config.id, "key", &value2).await.unwrap();
        let recalled = store.kv_recall(&config.id, "key").await.unwrap().unwrap();
        assert_eq!(recalled["version"], 2);
    }

    #[tokio::test]
    async fn test_kv_list() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("kv-list");
        store.save_agent(&config).await.unwrap();
        store.kv_store(&config.id, "key1", &serde_json::json!(1)).await.unwrap();
        store.kv_store(&config.id, "key2", &serde_json::json!(2)).await.unwrap();
        store.kv_store(&config.id, "key3", &serde_json::json!(3)).await.unwrap();
        let keys = store.kv_list(&config.id).await.unwrap();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));
        assert!(keys.contains(&"key3".to_string()));
    }

    #[tokio::test]
    async fn test_kv_isolated_per_agent() {
        let store = MemoryStore::in_memory().await.unwrap();
        let owner = create_test_agent_config("kv-owner");
        let other = create_test_agent_config("kv-other");
        store.save_agent(&owner).await.unwrap();
        store.save_agent(&other).await.unwrap();

        store.kv_store(&owner.id, "shared-key", &serde_json::json!("mine")).await.unwrap();

        assert!(store.kv_recall(&other.id, "shared-key").await.unwrap().is_none());
        assert!(store.kv_list(&other.id).await.unwrap().is_empty());
    }

    // === Edge Case Tests ===

    #[tokio::test]
    async fn test_agent_with_empty_name() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("");
        // An empty name is accepted by the store (validation happens elsewhere).
        let result = store.save_agent(&config).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_agent_with_special_characters_in_name() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("agent-with-特殊字符-🎉");
        let result = store.save_agent(&config).await;
        assert!(result.is_ok());
        let loaded = store.load_agent(&config.id).await.unwrap().unwrap();
        assert_eq!(loaded.name, "agent-with-特殊字符-🎉");
    }

    #[tokio::test]
    async fn test_large_message_content() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("large-msg");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        // Create a large message (100 KB).
        let large_content = "x".repeat(100_000);
        let msg = Message::user(&large_content);
        let result = store.append_message(&session_id, &msg).await;
        assert!(result.is_ok());
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[tokio::test]
    async fn test_message_with_tool_use() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("tool-msg");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        let tool_input = serde_json::json!({"query": "test", "options": {"limit": 10}});
        let msg = Message::tool_use("call-123", zclaw_types::ToolId::new("search"), tool_input.clone());
        store.append_message(&session_id, &msg).await.unwrap();
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        if let Message::ToolUse { id, tool, input } = &messages[0] {
            assert_eq!(id, "call-123");
            assert_eq!(tool.as_str(), "search");
            assert_eq!(*input, tool_input);
        } else {
            panic!("Expected ToolUse message");
        }
    }

    #[tokio::test]
    async fn test_message_with_tool_result() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("tool-result");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        let output = serde_json::json!({"results": ["a", "b", "c"]});
        let msg = Message::tool_result("call-123", zclaw_types::ToolId::new("search"), output.clone(), false);
        store.append_message(&session_id, &msg).await.unwrap();
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        if let Message::ToolResult { tool_call_id, tool, output: o, is_error } = &messages[0] {
            assert_eq!(tool_call_id, "call-123");
            assert_eq!(tool.as_str(), "search");
            assert_eq!(*o, output);
            assert!(!is_error);
        } else {
            panic!("Expected ToolResult message");
        }
    }

    #[tokio::test]
    async fn test_message_with_thinking() {
        let store = MemoryStore::in_memory().await.unwrap();
        let config = create_test_agent_config("thinking");
        store.save_agent(&config).await.unwrap();
        let session_id = store.create_session(&config.id).await.unwrap();
        let msg = Message::assistant_with_thinking("Final answer", "My reasoning...");
        store.append_message(&session_id, &msg).await.unwrap();
        let messages = store.get_messages(&session_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        if let Message::Assistant { content, thinking } = &messages[0] {
            assert_eq!(content, "Final answer");
            assert_eq!(thinking.as_ref().unwrap(), "My reasoning...");
        } else {
            panic!("Expected Assistant message");
        }
    }
}