247 lines
7.4 KiB
Rust
247 lines
7.4 KiB
Rust
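//! SQLite-backed memory store: persists agent configurations, chat sessions with their ordered messages, and a per-agent key/value store.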
|
|
|
|
use sqlx::SqlitePool;
|
|
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError};
|
|
|
|
/// Memory store for persisting ZCLAW data
|
|
pub struct MemoryStore {
|
|
pool: SqlitePool,
|
|
}
|
|
|
|
impl MemoryStore {
|
|
/// Create a new memory store with the given database path
|
|
pub async fn new(database_url: &str) -> Result<Self> {
|
|
let pool = SqlitePool::connect(database_url).await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
let store = Self { pool };
|
|
store.run_migrations().await?;
|
|
Ok(store)
|
|
}
|
|
|
|
/// Create an in-memory database (for testing)
|
|
pub async fn in_memory() -> Result<Self> {
|
|
Self::new("sqlite::memory:").await
|
|
}
|
|
|
|
/// Run database migrations
|
|
async fn run_migrations(&self) -> Result<()> {
|
|
sqlx::query(crate::schema::CREATE_SCHEMA)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
// === Agent CRUD ===
|
|
|
|
/// Save an agent configuration
|
|
pub async fn save_agent(&self, agent: &AgentConfig) -> Result<()> {
|
|
let config_json = serde_json::to_string(agent)?;
|
|
let id = agent.id.to_string();
|
|
let name = &agent.name;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO agents (id, name, config, created_at, updated_at)
|
|
VALUES (?, ?, ?, datetime('now'), datetime('now'))
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
name = excluded.name,
|
|
config = excluded.config,
|
|
updated_at = datetime('now')
|
|
"#,
|
|
)
|
|
.bind(&id)
|
|
.bind(name)
|
|
.bind(&config_json)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Load an agent by ID
|
|
pub async fn load_agent(&self, id: &AgentId) -> Result<Option<AgentConfig>> {
|
|
let id_str = id.to_string();
|
|
|
|
let row = sqlx::query_as::<_, (String,)>(
|
|
"SELECT config FROM agents WHERE id = ?"
|
|
)
|
|
.bind(&id_str)
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
match row {
|
|
Some((config,)) => {
|
|
let agent: AgentConfig = serde_json::from_str(&config)?;
|
|
Ok(Some(agent))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
/// List all agents
|
|
pub async fn list_agents(&self) -> Result<Vec<AgentConfig>> {
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT config FROM agents"
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let agents = rows
|
|
.into_iter()
|
|
.filter_map(|(config,)| serde_json::from_str(&config).ok())
|
|
.collect();
|
|
Ok(agents)
|
|
}
|
|
|
|
/// Delete an agent
|
|
pub async fn delete_agent(&self, id: &AgentId) -> Result<()> {
|
|
let id_str = id.to_string();
|
|
|
|
sqlx::query("DELETE FROM agents WHERE id = ?")
|
|
.bind(&id_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// === Session Management ===
|
|
|
|
/// Create a new session for an agent
|
|
pub async fn create_session(&self, agent_id: &AgentId) -> Result<SessionId> {
|
|
let session_id = SessionId::new();
|
|
let session_str = session_id.to_string();
|
|
let agent_str = agent_id.to_string();
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO sessions (id, agent_id, created_at, updated_at)
|
|
VALUES (?, ?, datetime('now'), datetime('now'))
|
|
"#,
|
|
)
|
|
.bind(&session_str)
|
|
.bind(&agent_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(session_id)
|
|
}
|
|
|
|
/// Append a message to a session
|
|
pub async fn append_message(&self, session_id: &SessionId, message: &Message) -> Result<()> {
|
|
let session_str = session_id.to_string();
|
|
let message_json = serde_json::to_string(message)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO messages (session_id, seq, content, created_at)
|
|
SELECT ?, COALESCE(MAX(seq), 0) + 1, datetime('now')
|
|
FROM messages WHERE session_id = ?
|
|
"#,
|
|
)
|
|
.bind(&session_str)
|
|
.bind(&message_json)
|
|
.bind(&session_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
// Update session updated_at
|
|
sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?")
|
|
.bind(&session_str)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get all messages for a session
|
|
pub async fn get_messages(&self, session_id: &SessionId) -> Result<Vec<Message>> {
|
|
let session_str = session_id.to_string();
|
|
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT content FROM messages WHERE session_id = ? ORDER BY seq"
|
|
)
|
|
.bind(&session_str)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.filter_map(|(content,)| serde_json::from_str(&content).ok())
|
|
.collect();
|
|
Ok(messages)
|
|
}
|
|
|
|
// === KV Store ===
|
|
|
|
/// Store a key-value pair for an agent
|
|
pub async fn kv_store(&self, agent_id: &AgentId, key: &str, value: &serde_json::Value) -> Result<()> {
|
|
let agent_str = agent_id.to_string();
|
|
let value_json = serde_json::to_string(value)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO kv_store (agent_id, key, value, updated_at)
|
|
VALUES (?, ?, ?, datetime('now'))
|
|
ON CONFLICT(agent_id, key) DO UPDATE SET
|
|
value = excluded.value,
|
|
updated_at = datetime('now')
|
|
"#,
|
|
)
|
|
.bind(&agent_str)
|
|
.bind(key)
|
|
.bind(&value_json)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Recall a value by key
|
|
pub async fn kv_recall(&self, agent_id: &AgentId, key: &str) -> Result<Option<serde_json::Value>> {
|
|
let agent_str = agent_id.to_string();
|
|
|
|
let row = sqlx::query_as::<_, (String,)>(
|
|
"SELECT value FROM kv_store WHERE agent_id = ? AND key = ?"
|
|
)
|
|
.bind(&agent_str)
|
|
.bind(key)
|
|
.fetch_optional(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
match row {
|
|
Some((value,)) => {
|
|
let v: serde_json::Value = serde_json::from_str(&value)?;
|
|
Ok(Some(v))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
/// List all keys for an agent
|
|
pub async fn kv_list(&self, agent_id: &AgentId) -> Result<Vec<String>> {
|
|
let agent_str = agent_id.to_string();
|
|
|
|
let rows = sqlx::query_as::<_, (String,)>(
|
|
"SELECT key FROM kv_store WHERE agent_id = ?"
|
|
)
|
|
.bind(&agent_str)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
|
|
Ok(rows.into_iter().map(|(key,)| key).collect())
|
|
}
|
|
}
|