//! Memory store implementation use sqlx::SqlitePool; use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError}; /// Memory store for persisting ZCLAW data pub struct MemoryStore { pool: SqlitePool, } impl MemoryStore { /// Create a new memory store with the given database path pub async fn new(database_url: &str) -> Result { // Ensure parent directory exists for file-based SQLite databases Self::ensure_database_dir(database_url)?; let pool = SqlitePool::connect(database_url).await .map_err(|e| ZclawError::StorageError(e.to_string()))?; let store = Self { pool }; store.run_migrations().await?; Ok(store) } /// Ensure the parent directory for the database file exists fn ensure_database_dir(database_url: &str) -> Result<()> { // Parse SQLite URL to extract file path // Format: sqlite:/path/to/db or sqlite://path/to/db if database_url.starts_with("sqlite:") { let path_part = database_url.strip_prefix("sqlite:") .ok_or_else(|| ZclawError::StorageError( format!("Invalid database URL format: {}", database_url) ))?; // Skip in-memory databases if path_part == ":memory:" { return Ok(()); } // Remove query parameters (e.g., ?mode=rwc) let path_without_query = path_part.split('?').next() .ok_or_else(|| ZclawError::StorageError( format!("Invalid database URL path: {}", path_part) ))?; // Handle both absolute and relative paths let path = std::path::Path::new(path_without_query); // Get parent directory if let Some(parent) = path.parent() { if !parent.exists() { std::fs::create_dir_all(parent) .map_err(|e| ZclawError::StorageError( format!("Failed to create database directory {}: {}", parent.display(), e) ))?; } } } Ok(()) } /// Create an in-memory database (for testing) pub async fn in_memory() -> Result { Self::new("sqlite::memory:").await } /// Run database migrations async fn run_migrations(&self) -> Result<()> { sqlx::query(crate::schema::CREATE_SCHEMA) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(()) } // === Agent CRUD === /// Save an agent configuration pub async fn save_agent(&self, agent: &AgentConfig) -> Result<()> { let config_json = serde_json::to_string(agent)?; let id = agent.id.to_string(); let name = &agent.name; sqlx::query( r#" INSERT INTO agents (id, name, config, created_at, updated_at) VALUES (?, ?, ?, datetime('now'), datetime('now')) ON CONFLICT(id) DO UPDATE SET name = excluded.name, config = excluded.config, updated_at = datetime('now') "#, ) .bind(&id) .bind(name) .bind(&config_json) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(()) } /// Load an agent by ID pub async fn load_agent(&self, id: &AgentId) -> Result> { let id_str = id.to_string(); let row = sqlx::query_as::<_, (String,)>( "SELECT config FROM agents WHERE id = ?" ) .bind(&id_str) .fetch_optional(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; match row { Some((config,)) => { let agent: AgentConfig = serde_json::from_str(&config)?; Ok(Some(agent)) } None => Ok(None), } } /// List all agents pub async fn list_agents(&self) -> Result> { let rows = sqlx::query_as::<_, (String,)>( "SELECT config FROM agents" ) .fetch_all(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; let agents = rows .into_iter() .filter_map(|(config,)| serde_json::from_str(&config).ok()) .collect(); Ok(agents) } /// Delete an agent pub async fn delete_agent(&self, id: &AgentId) -> Result<()> { let id_str = id.to_string(); sqlx::query("DELETE FROM agents WHERE id = ?") .bind(&id_str) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(()) } // === Session Management === /// Create a new session for an agent pub async fn create_session(&self, agent_id: &AgentId) -> Result { let session_id = SessionId::new(); let session_str = session_id.to_string(); let agent_str = agent_id.to_string(); sqlx::query( r#" INSERT INTO sessions (id, agent_id, created_at, updated_at) VALUES (?, ?, datetime('now'), datetime('now')) "#, ) .bind(&session_str) .bind(&agent_str) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(session_id) } /// Append a message to a session pub async fn append_message(&self, session_id: &SessionId, message: &Message) -> Result<()> { let session_str = session_id.to_string(); let message_json = serde_json::to_string(message)?; sqlx::query( r#" INSERT INTO messages (session_id, seq, content, created_at) SELECT ?, COALESCE(MAX(seq), 0) + 1, ?, datetime('now') FROM messages WHERE session_id = ? "#, ) .bind(&session_str) .bind(&message_json) .bind(&session_str) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; // Update session updated_at sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?") .bind(&session_str) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(()) } /// Get all messages for a session pub async fn get_messages(&self, session_id: &SessionId) -> Result> { let session_str = session_id.to_string(); let rows = sqlx::query_as::<_, (String,)>( "SELECT content FROM messages WHERE session_id = ? ORDER BY seq" ) .bind(&session_str) .fetch_all(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; let messages = rows .into_iter() .filter_map(|(content,)| serde_json::from_str(&content).ok()) .collect(); Ok(messages) } // === KV Store === /// Store a key-value pair for an agent pub async fn kv_store(&self, agent_id: &AgentId, key: &str, value: &serde_json::Value) -> Result<()> { let agent_str = agent_id.to_string(); let value_json = serde_json::to_string(value)?; sqlx::query( r#" INSERT INTO kv_store (agent_id, key, value, updated_at) VALUES (?, ?, ?, datetime('now')) ON CONFLICT(agent_id, key) DO UPDATE SET value = excluded.value, updated_at = datetime('now') "#, ) .bind(&agent_str) .bind(key) .bind(&value_json) .execute(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(()) } /// Recall a value by key pub async fn kv_recall(&self, agent_id: &AgentId, key: &str) -> Result> { let agent_str = agent_id.to_string(); let row = sqlx::query_as::<_, (String,)>( "SELECT value FROM kv_store WHERE agent_id = ? AND key = ?" ) .bind(&agent_str) .bind(key) .fetch_optional(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; match row { Some((value,)) => { let v: serde_json::Value = serde_json::from_str(&value)?; Ok(Some(v)) } None => Ok(None), } } /// List all keys for an agent pub async fn kv_list(&self, agent_id: &AgentId) -> Result> { let agent_str = agent_id.to_string(); let rows = sqlx::query_as::<_, (String,)>( "SELECT key FROM kv_store WHERE agent_id = ?" ) .bind(&agent_str) .fetch_all(&self.pool) .await .map_err(|e| ZclawError::StorageError(e.to_string()))?; Ok(rows.into_iter().map(|(key,)| key).collect()) } } #[cfg(test)] mod tests { use super::*; use zclaw_types::{AgentConfig, ModelConfig}; fn create_test_agent_config(name: &str) -> AgentConfig { AgentConfig { id: AgentId::new(), name: name.to_string(), description: None, model: ModelConfig::default(), system_prompt: None, capabilities: vec![], tools: vec![], max_tokens: None, temperature: None, enabled: true, } } #[tokio::test] async fn test_in_memory_store_creation() { let store = MemoryStore::in_memory().await; assert!(store.is_ok()); } #[tokio::test] async fn test_save_and_load_agent() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("test-agent"); store.save_agent(&config).await.unwrap(); let loaded = store.load_agent(&config.id).await.unwrap(); assert!(loaded.is_some()); let loaded = loaded.unwrap(); assert_eq!(loaded.id, config.id); assert_eq!(loaded.name, config.name); } #[tokio::test] async fn test_load_nonexistent_agent() { let store = MemoryStore::in_memory().await.unwrap(); let fake_id = AgentId::new(); let result = store.load_agent(&fake_id).await.unwrap(); assert!(result.is_none()); } #[tokio::test] async fn test_save_agent_updates_existing() { let store = MemoryStore::in_memory().await.unwrap(); let mut config = create_test_agent_config("original"); store.save_agent(&config).await.unwrap(); config.name = "updated".to_string(); store.save_agent(&config).await.unwrap(); let loaded = store.load_agent(&config.id).await.unwrap().unwrap(); assert_eq!(loaded.name, "updated"); } #[tokio::test] async fn test_list_agents() { let store = MemoryStore::in_memory().await.unwrap(); let config1 = create_test_agent_config("agent1"); let config2 = create_test_agent_config("agent2"); store.save_agent(&config1).await.unwrap(); store.save_agent(&config2).await.unwrap(); let agents = store.list_agents().await.unwrap(); assert_eq!(agents.len(), 2); } #[tokio::test] async fn test_delete_agent() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("to-delete"); store.save_agent(&config).await.unwrap(); store.delete_agent(&config.id).await.unwrap(); let loaded = store.load_agent(&config.id).await.unwrap(); assert!(loaded.is_none()); } #[tokio::test] async fn test_delete_nonexistent_agent_succeeds() { let store = MemoryStore::in_memory().await.unwrap(); let fake_id = AgentId::new(); // Deleting nonexistent agent should succeed (idempotent) let result = store.delete_agent(&fake_id).await; assert!(result.is_ok()); } #[tokio::test] async fn test_create_session() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("session-test"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); assert!(!session_id.as_uuid().is_nil()); } #[tokio::test] async fn test_append_and_get_messages() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("msg-test"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); let msg1 = Message::user("Hello"); let msg2 = Message::assistant("Hi there!"); store.append_message(&session_id, &msg1).await.unwrap(); store.append_message(&session_id, &msg2).await.unwrap(); let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 2); } #[tokio::test] async fn test_message_ordering() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("order-test"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); for i in 0..10 { let msg = Message::user(format!("Message {}", i)); store.append_message(&session_id, &msg).await.unwrap(); } let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 10); // Verify ordering for (i, msg) in messages.iter().enumerate() { if let Message::User { content } = msg { assert_eq!(content, &format!("Message {}", i)); } } } #[tokio::test] async fn test_kv_store_and_recall() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("kv-test"); store.save_agent(&config).await.unwrap(); let value = serde_json::json!({"key": "value", "number": 42}); store.kv_store(&config.id, "test-key", &value).await.unwrap(); let recalled = store.kv_recall(&config.id, "test-key").await.unwrap(); assert!(recalled.is_some()); assert_eq!(recalled.unwrap(), value); } #[tokio::test] async fn test_kv_recall_nonexistent() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("kv-missing"); store.save_agent(&config).await.unwrap(); let result = store.kv_recall(&config.id, "nonexistent").await.unwrap(); assert!(result.is_none()); } #[tokio::test] async fn test_kv_update_existing() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("kv-update"); store.save_agent(&config).await.unwrap(); let value1 = serde_json::json!({"version": 1}); let value2 = serde_json::json!({"version": 2}); store.kv_store(&config.id, "key", &value1).await.unwrap(); store.kv_store(&config.id, "key", &value2).await.unwrap(); let recalled = store.kv_recall(&config.id, "key").await.unwrap().unwrap(); assert_eq!(recalled["version"], 2); } #[tokio::test] async fn test_kv_list() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("kv-list"); store.save_agent(&config).await.unwrap(); store.kv_store(&config.id, "key1", &serde_json::json!(1)).await.unwrap(); store.kv_store(&config.id, "key2", &serde_json::json!(2)).await.unwrap(); store.kv_store(&config.id, "key3", &serde_json::json!(3)).await.unwrap(); let keys = store.kv_list(&config.id).await.unwrap(); assert_eq!(keys.len(), 3); assert!(keys.contains(&"key1".to_string())); assert!(keys.contains(&"key2".to_string())); assert!(keys.contains(&"key3".to_string())); } // === Edge Case Tests === #[tokio::test] async fn test_agent_with_empty_name() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config(""); // Empty name should still work (validation is elsewhere) let result = store.save_agent(&config).await; assert!(result.is_ok()); } #[tokio::test] async fn test_agent_with_special_characters_in_name() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("agent-with-特殊字符-🎉"); let result = store.save_agent(&config).await; assert!(result.is_ok()); let loaded = store.load_agent(&config.id).await.unwrap().unwrap(); assert_eq!(loaded.name, "agent-with-特殊字符-🎉"); } #[tokio::test] async fn test_large_message_content() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("large-msg"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); // Create a large message (100KB) let large_content = "x".repeat(100_000); let msg = Message::user(&large_content); let result = store.append_message(&session_id, &msg).await; assert!(result.is_ok()); let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 1); } #[tokio::test] async fn test_message_with_tool_use() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("tool-msg"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); let tool_input = serde_json::json!({"query": "test", "options": {"limit": 10}}); let msg = Message::tool_use("call-123", zclaw_types::ToolId::new("search"), tool_input.clone()); store.append_message(&session_id, &msg).await.unwrap(); let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 1); if let Message::ToolUse { id, tool, input } = &messages[0] { assert_eq!(id, "call-123"); assert_eq!(tool.as_str(), "search"); assert_eq!(*input, tool_input); } else { panic!("Expected ToolUse message"); } } #[tokio::test] async fn test_message_with_tool_result() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("tool-result"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); let output = serde_json::json!({"results": ["a", "b", "c"]}); let msg = Message::tool_result("call-123", zclaw_types::ToolId::new("search"), output.clone(), false); store.append_message(&session_id, &msg).await.unwrap(); let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 1); if let Message::ToolResult { tool_call_id, tool, output: o, is_error } = &messages[0] { assert_eq!(tool_call_id, "call-123"); assert_eq!(tool.as_str(), "search"); assert_eq!(*o, output); assert!(!is_error); } else { panic!("Expected ToolResult message"); } } #[tokio::test] async fn test_message_with_thinking() { let store = MemoryStore::in_memory().await.unwrap(); let config = create_test_agent_config("thinking"); store.save_agent(&config).await.unwrap(); let session_id = store.create_session(&config.id).await.unwrap(); let msg = Message::assistant_with_thinking("Final answer", "My reasoning..."); store.append_message(&session_id, &msg).await.unwrap(); let messages = store.get_messages(&session_id).await.unwrap(); assert_eq!(messages.len(), 1); if let Message::Assistant { content, thinking } = &messages[0] { assert_eq!(content, "Final answer"); assert_eq!(thinking.as_ref().unwrap(), "My reasoning..."); } else { panic!("Expected Assistant message"); } } }