初始化提交
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled

This commit is contained in:
iven
2026-03-01 16:24:24 +08:00
commit 92e5def702
492 changed files with 211343 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
# Manifest for the openfang-memory crate (the Agent OS memory substrate).
[package]
name = "openfang-memory"
# Version, edition and license are inherited from the workspace root.
version.workspace = true
edition.workspace = true
license.workspace = true
description = "Memory substrate for the OpenFang Agent OS"
[dependencies]
# Shared type definitions (errors, memory/graph types) from the sibling crate.
openfang-types = { path = "../openfang-types" }
tokio = { workspace = true }
# serde + serde_json: (de)serialization of entity types, properties, metadata.
serde = { workspace = true }
serde_json = { workspace = true }
# rmp-serde: MessagePack serialization — presumably for the BLOB columns; confirm at call sites.
rmp-serde = { workspace = true }
# rusqlite: SQLite backend used by the structured/semantic/knowledge stores.
rusqlite = { workspace = true }
# chrono: RFC 3339 timestamps stored as TEXT columns.
chrono = { workspace = true }
# uuid: primary-key generation for entities/relations/memories.
uuid = { workspace = true }
thiserror = { workspace = true }
async-trait = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
tokio-test = { workspace = true }
# tempfile: on-disk database fixtures in tests.
tempfile = { workspace = true }

View File

@@ -0,0 +1,101 @@
//! Memory consolidation and decay logic.
//!
//! Reduces confidence of old, unaccessed memories and merges
//! duplicate/similar memories.
use chrono::Utc;
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::memory::ConsolidationReport;
use rusqlite::Connection;
use std::sync::{Arc, Mutex};
/// Memory consolidation engine.
///
/// Cheap to clone: the SQLite connection is shared behind `Arc<Mutex<_>>`,
/// so clones all operate on the same database.
#[derive(Clone)]
pub struct ConsolidationEngine {
    // Shared connection; every operation serializes through this mutex.
    conn: Arc<Mutex<Connection>>,
    /// Decay rate: how much to reduce confidence per consolidation cycle.
    decay_rate: f32,
}
impl ConsolidationEngine {
    /// Create a new consolidation engine.
    ///
    /// `decay_rate` is the fraction of confidence removed per cycle and is
    /// clamped into `[0.0, 1.0]`: a misconfigured value outside that range
    /// would otherwise make the multiplier negative (or > 1), corrupting
    /// stored confidences instead of decaying them.
    pub fn new(conn: Arc<Mutex<Connection>>, decay_rate: f32) -> Self {
        Self {
            conn,
            decay_rate: decay_rate.clamp(0.0, 1.0),
        }
    }
    /// Run a consolidation cycle: decay old memories.
    ///
    /// Multiplies the confidence of every live memory not accessed in the
    /// last 7 days by `1 - decay_rate`, flooring at 0.1 so memories never
    /// decay to zero. Returns a report with the affected row count and the
    /// wall-clock duration.
    ///
    /// # Errors
    /// Returns `OpenFangError::Internal` if the connection mutex is poisoned,
    /// or `OpenFangError::Memory` if the UPDATE fails.
    pub fn consolidate(&self) -> OpenFangResult<ConsolidationReport> {
        let start = std::time::Instant::now();
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Decay confidence of memories not accessed in the last 7 days
        let cutoff = (Utc::now() - chrono::Duration::days(7)).to_rfc3339();
        let decay_factor = 1.0 - self.decay_rate as f64;
        let decayed = conn
            .execute(
                "UPDATE memories SET confidence = MAX(0.1, confidence * ?1)
                 WHERE deleted = 0 AND accessed_at < ?2 AND confidence > 0.1",
                rusqlite::params![decay_factor, cutoff],
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let duration_ms = start.elapsed().as_millis() as u64;
        Ok(ConsolidationReport {
            memories_merged: 0, // Phase 1: no merging
            memories_decayed: decayed as u64,
            duration_ms,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;

    /// Build an engine over a fresh in-memory database with a 10% decay rate.
    fn setup() -> ConsolidationEngine {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        ConsolidationEngine::new(Arc::new(Mutex::new(conn)), 0.1)
    }

    #[test]
    fn test_consolidation_empty() {
        // A database with no memories decays nothing.
        let report = setup().consolidate().unwrap();
        assert_eq!(report.memories_decayed, 0);
    }

    #[test]
    fn test_consolidation_decays_old_memories() {
        let engine = setup();
        // Seed one memory whose timestamps lie 30 days in the past — well
        // beyond the 7-day cutoff used by consolidate().
        let stale = (Utc::now() - chrono::Duration::days(30)).to_rfc3339();
        {
            let guard = engine.conn.lock().unwrap();
            guard.execute(
                "INSERT INTO memories (id, agent_id, content, source, scope, confidence, metadata, created_at, accessed_at, access_count, deleted)
                 VALUES ('test-id', 'agent-1', 'old memory', '\"conversation\"', 'episodic', 0.9, '{}', ?1, ?1, 0, 0)",
                rusqlite::params![stale],
            ).unwrap();
        } // release the lock before consolidate() takes it

        let report = engine.consolidate().unwrap();
        assert_eq!(report.memories_decayed, 1);

        // The stored confidence must have dropped below its initial 0.9.
        let guard = engine.conn.lock().unwrap();
        let confidence: f64 = guard
            .query_row(
                "SELECT confidence FROM memories WHERE id = 'test-id'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(confidence < 0.9);
    }
}

View File

@@ -0,0 +1,857 @@
//! Knowledge graph backed by SQLite.
//!
//! Stores entities and relations with support for graph pattern queries.
//! Supports recursive traversal with configurable depth limits.
use chrono::Utc;
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::memory::{
Entity, EntityType, GraphMatch, GraphPattern, Relation, RelationType,
};
use rusqlite::Connection;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use uuid::Uuid;
/// Maximum allowed depth for recursive graph queries (safety limit).
/// Requests above this are silently capped in `query_graph`.
const MAX_ALLOWED_DEPTH: u32 = 10;
/// Default result limit for graph queries.
/// Applied per level in `query_from_sources` and to the accumulated total
/// in `query_recursive`.
const DEFAULT_RESULT_LIMIT: u32 = 1000;
/// Knowledge graph store backed by SQLite.
///
/// Cheap to clone: clones share one connection behind `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct KnowledgeStore {
    // Shared connection; every query serializes through this mutex.
    conn: Arc<Mutex<Connection>>,
}
impl KnowledgeStore {
    /// Create a new knowledge store wrapping the given connection.
    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
        Self { conn }
    }
    /// Add an entity to the knowledge graph.
    ///
    /// An empty `entity.id` gets a freshly generated UUID; an existing id is
    /// upserted (name, properties and `updated_at` refreshed in place).
    /// Returns the id actually stored.
    ///
    /// # Errors
    /// `Internal` on mutex poisoning, `Serialization` if the entity type or
    /// properties fail to encode, `Memory` on SQL failure.
    pub fn add_entity(&self, entity: Entity) -> OpenFangResult<String> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let id = if entity.id.is_empty() {
            Uuid::new_v4().to_string()
        } else {
            entity.id.clone()
        };
        let entity_type_str = serde_json::to_string(&entity.entity_type)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let props_str = serde_json::to_string(&entity.properties)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO entities (id, entity_type, name, properties, created_at, updated_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?5)
             ON CONFLICT(id) DO UPDATE SET name = ?3, properties = ?4, updated_at = ?5",
            rusqlite::params![id, entity_type_str, entity.name, props_str, now],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(id)
    }
    /// Add a relation between two entities.
    ///
    /// Always inserts a new row with a fresh UUID (no dedup/upsert); the
    /// caller-supplied `relation.source`/`relation.target` are stored as-is.
    /// Returns the generated relation id.
    pub fn add_relation(&self, relation: Relation) -> OpenFangResult<String> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let id = Uuid::new_v4().to_string();
        let rel_type_str = serde_json::to_string(&relation.relation)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let props_str = serde_json::to_string(&relation.properties)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO relations (id, source_entity, relation_type, target_entity, properties, confidence, created_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                id,
                relation.source,
                rel_type_str,
                relation.target,
                props_str,
                relation.confidence as f64,
                now,
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(id)
    }
    /// Query the knowledge graph with a pattern.
    ///
    /// Supports recursive traversal via `max_depth` parameter:
    /// - `max_depth = 1` (default): Single-hop query (source -> target)
    /// - `max_depth > 1`: Recursive traversal following relation chains
    /// - Maximum allowed depth: 10 (safety limit)
    pub fn query_graph(&self, pattern: GraphPattern) -> OpenFangResult<Vec<GraphMatch>> {
        // Normalize the requested depth into [1, MAX_ALLOWED_DEPTH].
        let max_depth = pattern.max_depth.clamp(1, MAX_ALLOWED_DEPTH);
        if max_depth == 1 {
            // Single-hop query (original behavior)
            self.query_single_hop(&pattern)
        } else {
            // Recursive traversal using iterative approach
            self.query_recursive(&pattern, max_depth)
        }
    }
    /// Single-hop query (depth = 1).
    ///
    /// Source/target filters match either entity id or entity name.
    /// NOTE(review): this path caps results at 100 while the recursive path
    /// uses DEFAULT_RESULT_LIMIT (1000) — intentional? Worth unifying.
    fn query_single_hop(&self, pattern: &GraphPattern) -> OpenFangResult<Vec<GraphMatch>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut sql = String::from(
            "SELECT
                s.id, s.entity_type, s.name, s.properties, s.created_at, s.updated_at,
                r.id, r.source_entity, r.relation_type, r.target_entity, r.properties, r.confidence, r.created_at,
                t.id, t.entity_type, t.name, t.properties, t.created_at, t.updated_at
             FROM relations r
             JOIN entities s ON r.source_entity = s.id
             JOIN entities t ON r.target_entity = t.id
             WHERE 1=1",
        );
        let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
        let mut idx = 1;
        if let Some(ref source) = pattern.source {
            sql.push_str(&format!(" AND (s.id = ?{idx} OR s.name = ?{idx})"));
            params.push(Box::new(source.clone()));
            idx += 1;
        }
        if let Some(ref relation) = pattern.relation {
            // Relation types are stored as their JSON encoding (see add_relation).
            let rel_str = serde_json::to_string(relation)
                .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
            sql.push_str(&format!(" AND r.relation_type = ?{idx}"));
            params.push(Box::new(rel_str));
            idx += 1;
        }
        if let Some(ref target) = pattern.target {
            sql.push_str(&format!(" AND (t.id = ?{idx} OR t.name = ?{idx})"));
            params.push(Box::new(target.clone()));
        }
        sql.push_str(" LIMIT 100");
        let mut stmt = conn
            .prepare(&sql)
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let param_refs: Vec<&dyn rusqlite::types::ToSql> =
            params.iter().map(|p| p.as_ref()).collect();
        let rows = stmt
            .query_map(param_refs.as_slice(), |row| {
                Ok(RawGraphRow {
                    s_id: row.get(0)?,
                    s_type: row.get(1)?,
                    s_name: row.get(2)?,
                    s_props: row.get(3)?,
                    s_created: row.get(4)?,
                    s_updated: row.get(5)?,
                    r_id: row.get(6)?,
                    r_source: row.get(7)?,
                    r_type: row.get(8)?,
                    r_target: row.get(9)?,
                    r_props: row.get(10)?,
                    r_confidence: row.get(11)?,
                    r_created: row.get(12)?,
                    t_id: row.get(13)?,
                    t_type: row.get(14)?,
                    t_name: row.get(15)?,
                    t_props: row.get(16)?,
                    t_created: row.get(17)?,
                    t_updated: row.get(18)?,
                })
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut matches = Vec::new();
        for row_result in rows {
            let r = row_result.map_err(|e| OpenFangError::Memory(e.to_string()))?;
            matches.push(GraphMatch {
                source: parse_entity(
                    &r.s_id,
                    &r.s_type,
                    &r.s_name,
                    &r.s_props,
                    &r.s_created,
                    &r.s_updated,
                ),
                relation: parse_relation(
                    &r.r_source,
                    &r.r_type,
                    &r.r_target,
                    &r.r_props,
                    r.r_confidence,
                    &r.r_created,
                ),
                target: parse_entity(
                    &r.t_id,
                    &r.t_type,
                    &r.t_name,
                    &r.t_props,
                    &r.t_created,
                    &r.t_updated,
                ),
            });
        }
        Ok(matches)
    }
    /// Recursive graph traversal using iterative BFS approach.
    ///
    /// This method iteratively traverses the graph up to the specified depth,
    /// tracking visited relations to avoid cycles.
    fn query_recursive(&self, pattern: &GraphPattern, max_depth: u32) -> OpenFangResult<Vec<GraphMatch>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut all_matches: Vec<GraphMatch> = Vec::new();
        let mut visited_relations: HashSet<String> = HashSet::new();
        // Start with initial sources
        let mut current_sources: HashSet<String> = if let Some(ref source) = pattern.source {
            // Resolve source to entity ID(s)
            self.resolve_entity_ids(&conn, source)?
        } else {
            // No source filter - start from all entities
            self.get_all_entity_ids(&conn)?
        };
        // Iteratively traverse each depth level
        for _depth in 1..=max_depth {
            if current_sources.is_empty() {
                break;
            }
            // Query relations from current sources
            let matches = self.query_from_sources(&conn, &current_sources, pattern)?;
            // Collect new targets for next iteration
            let mut next_sources: HashSet<String> = HashSet::new();
            for m in matches {
                // Cycle-detection key. The relation TYPE is part of the key:
                // keying on source->target alone would wrongly collapse
                // parallel edges of different types between the same pair,
                // while the same edge revisited via a cycle is still skipped.
                let rel_type_key =
                    serde_json::to_string(&m.relation.relation).unwrap_or_default();
                let rel_key = format!(
                    "{}-[{}]->{}",
                    m.relation.source, rel_type_key, m.relation.target
                );
                if visited_relations.insert(rel_key) {
                    // New relation (not visited)
                    next_sources.insert(m.target.id.clone());
                    all_matches.push(m);
                }
            }
            // Move to next level
            current_sources = next_sources;
            // Safety: limit total results
            if all_matches.len() >= DEFAULT_RESULT_LIMIT as usize {
                all_matches.truncate(DEFAULT_RESULT_LIMIT as usize);
                break;
            }
        }
        Ok(all_matches)
    }
    /// Resolve an entity identifier (ID or name) to entity IDs.
    fn resolve_entity_ids(&self, conn: &Connection, identifier: &str) -> OpenFangResult<HashSet<String>> {
        let mut stmt = conn
            .prepare("SELECT id FROM entities WHERE id = ?1 OR name = ?1")
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let rows: Result<Vec<String>, _> = stmt
            .query_map([identifier], |row| row.get(0))
            .map_err(|e| OpenFangError::Memory(e.to_string()))?
            .collect();
        Ok(rows.map_err(|e| OpenFangError::Memory(e.to_string()))?.into_iter().collect())
    }
    /// Get all entity IDs in the graph.
    fn get_all_entity_ids(&self, conn: &Connection) -> OpenFangResult<HashSet<String>> {
        let mut stmt = conn
            .prepare("SELECT id FROM entities")
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let rows: Result<Vec<String>, _> = stmt
            .query_map([], |row| row.get(0))
            .map_err(|e| OpenFangError::Memory(e.to_string()))?
            .collect();
        Ok(rows.map_err(|e| OpenFangError::Memory(e.to_string()))?.into_iter().collect())
    }
    /// Query relations from a set of source entities.
    ///
    /// Builds a dynamic `IN (?1, ?2, …)` clause sized to `sources`; the
    /// relation/target filters from `pattern` are appended after it.
    fn query_from_sources(
        &self,
        conn: &Connection,
        sources: &HashSet<String>,
        pattern: &GraphPattern,
    ) -> OpenFangResult<Vec<GraphMatch>> {
        if sources.is_empty() {
            return Ok(Vec::new());
        }
        // Build query with IN clause for sources
        let placeholders: Vec<String> = (0..sources.len()).map(|i| format!("?{}", i + 1)).collect();
        let placeholders_str = placeholders.join(", ");
        let mut sql = format!(
            "SELECT
                s.id, s.entity_type, s.name, s.properties, s.created_at, s.updated_at,
                r.id, r.source_entity, r.relation_type, r.target_entity, r.properties, r.confidence, r.created_at,
                t.id, t.entity_type, t.name, t.properties, t.created_at, t.updated_at
             FROM relations r
             JOIN entities s ON r.source_entity = s.id
             JOIN entities t ON r.target_entity = t.id
             WHERE s.id IN ({})",
            placeholders_str
        );
        let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = sources
            .iter()
            .map(|s| Box::new(s.clone()) as Box<dyn rusqlite::types::ToSql>)
            .collect();
        let mut idx = sources.len() + 1;
        // Add relation type filter if specified
        if let Some(ref relation) = pattern.relation {
            let rel_str = serde_json::to_string(relation)
                .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
            sql.push_str(&format!(" AND r.relation_type = ?{idx}"));
            params.push(Box::new(rel_str));
            idx += 1;
        }
        // Add target filter if specified
        if let Some(ref target) = pattern.target {
            sql.push_str(&format!(" AND (t.id = ?{idx} OR t.name = ?{idx})"));
            params.push(Box::new(target.clone()));
        }
        sql.push_str(&format!(" LIMIT {}", DEFAULT_RESULT_LIMIT));
        let mut stmt = conn
            .prepare(&sql)
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let param_refs: Vec<&dyn rusqlite::types::ToSql> =
            params.iter().map(|p| p.as_ref()).collect();
        let rows = stmt
            .query_map(param_refs.as_slice(), |row| {
                Ok(RawGraphRow {
                    s_id: row.get(0)?,
                    s_type: row.get(1)?,
                    s_name: row.get(2)?,
                    s_props: row.get(3)?,
                    s_created: row.get(4)?,
                    s_updated: row.get(5)?,
                    r_id: row.get(6)?,
                    r_source: row.get(7)?,
                    r_type: row.get(8)?,
                    r_target: row.get(9)?,
                    r_props: row.get(10)?,
                    r_confidence: row.get(11)?,
                    r_created: row.get(12)?,
                    t_id: row.get(13)?,
                    t_type: row.get(14)?,
                    t_name: row.get(15)?,
                    t_props: row.get(16)?,
                    t_created: row.get(17)?,
                    t_updated: row.get(18)?,
                })
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut matches = Vec::new();
        for row_result in rows {
            let r = row_result.map_err(|e| OpenFangError::Memory(e.to_string()))?;
            matches.push(GraphMatch {
                source: parse_entity(
                    &r.s_id,
                    &r.s_type,
                    &r.s_name,
                    &r.s_props,
                    &r.s_created,
                    &r.s_updated,
                ),
                relation: parse_relation(
                    &r.r_source,
                    &r.r_type,
                    &r.r_target,
                    &r.r_props,
                    r.r_confidence,
                    &r.r_created,
                ),
                target: parse_entity(
                    &r.t_id,
                    &r.t_type,
                    &r.t_name,
                    &r.t_props,
                    &r.t_created,
                    &r.t_updated,
                ),
            });
        }
        Ok(matches)
    }
}
/// Raw row from a graph query.
///
/// Field order mirrors the 19-column SELECT list used by the graph queries:
/// source entity (`s_*`), relation (`r_*`), target entity (`t_*`).
struct RawGraphRow {
    s_id: String,
    s_type: String,
    s_name: String,
    s_props: String,
    s_created: String,
    s_updated: String,
    /// Relation row id — read from the query but not currently surfaced in
    /// `GraphMatch`, so it is allowed to be dead rather than hidden behind
    /// an unused accessor method.
    #[allow(dead_code)]
    r_id: String,
    r_source: String,
    r_type: String,
    r_target: String,
    r_props: String,
    r_confidence: f64,
    r_created: String,
    t_id: String,
    t_type: String,
    t_name: String,
    t_props: String,
    t_created: String,
    t_updated: String,
}
/// Reconstruct an [`Entity`] from its raw TEXT column values.
///
/// Malformed stored data degrades gracefully rather than erroring:
/// an unparseable type becomes `Custom("unknown")`, unparseable properties
/// become an empty map, and unparseable timestamps fall back to now.
fn parse_entity(
    id: &str,
    etype: &str,
    name: &str,
    props: &str,
    created: &str,
    updated: &str,
) -> Entity {
    // Shared lenient RFC 3339 parser for both timestamp columns.
    let parse_ts = |raw: &str| {
        chrono::DateTime::parse_from_rfc3339(raw)
            .map(|dt| dt.with_timezone(&Utc))
            .unwrap_or_else(|_| Utc::now())
    };
    let entity_type: EntityType =
        serde_json::from_str(etype).unwrap_or(EntityType::Custom("unknown".to_string()));
    let properties: HashMap<String, serde_json::Value> =
        serde_json::from_str(props).unwrap_or_default();
    Entity {
        id: id.to_string(),
        entity_type,
        name: name.to_string(),
        properties,
        created_at: parse_ts(created),
        updated_at: parse_ts(updated),
    }
}
/// Reconstruct a [`Relation`] from its raw column values.
///
/// Like `parse_entity`, this never fails: an unparseable relation type
/// falls back to `RelatedTo`, bad properties to an empty map, and a bad
/// timestamp to the current time.
fn parse_relation(
    source: &str,
    rtype: &str,
    target: &str,
    props: &str,
    confidence: f64,
    created: &str,
) -> Relation {
    let properties: HashMap<String, serde_json::Value> =
        serde_json::from_str(props).unwrap_or_default();
    let created_at = match chrono::DateTime::parse_from_rfc3339(created) {
        Ok(dt) => dt.with_timezone(&Utc),
        Err(_) => Utc::now(),
    };
    Relation {
        source: source.to_string(),
        relation: serde_json::from_str::<RelationType>(rtype).unwrap_or(RelationType::RelatedTo),
        target: target.to_string(),
        properties,
        confidence: confidence as f32,
        created_at,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;

    /// Fresh store over an in-memory database with the full schema applied.
    fn setup() -> KnowledgeStore {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        KnowledgeStore::new(Arc::new(Mutex::new(conn)))
    }

    /// Insert an entity (an empty `id` lets the store mint a UUID) and
    /// return the stored id.
    fn put_entity(store: &KnowledgeStore, id: &str, entity_type: EntityType, name: &str) -> String {
        store
            .add_entity(Entity {
                id: id.to_string(),
                entity_type,
                name: name.to_string(),
                properties: HashMap::new(),
                created_at: Utc::now(),
                updated_at: Utc::now(),
            })
            .unwrap()
    }

    /// Insert a relation between two stored entities.
    fn put_relation(
        store: &KnowledgeStore,
        source: &str,
        relation: RelationType,
        target: &str,
        confidence: f32,
    ) {
        store
            .add_relation(Relation {
                source: source.to_string(),
                relation,
                target: target.to_string(),
                properties: HashMap::new(),
                confidence,
                created_at: Utc::now(),
            })
            .unwrap();
    }

    #[test]
    fn test_add_and_query_entity() {
        let store = setup();
        // An empty id triggers UUID generation.
        let id = put_entity(&store, "", EntityType::Person, "Alice");
        assert!(!id.is_empty());
    }

    #[test]
    fn test_add_relation_and_query() {
        let store = setup();
        let alice_id = put_entity(&store, "alice", EntityType::Person, "Alice");
        let company_id = put_entity(&store, "acme", EntityType::Organization, "Acme Corp");
        put_relation(&store, &alice_id, RelationType::WorksAt, &company_id, 0.95);

        let matches = store
            .query_graph(GraphPattern {
                source: Some(alice_id),
                relation: Some(RelationType::WorksAt),
                target: None,
                max_depth: 1,
            })
            .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].target.name, "Acme Corp");
    }

    #[test]
    fn test_recursive_traversal_depth_2() {
        // Chain: Alice -[WorksAt]-> Acme -[RelatedTo]-> NYC.
        let store = setup();
        let alice_id = put_entity(&store, "alice", EntityType::Person, "Alice");
        let acme_id = put_entity(&store, "acme", EntityType::Organization, "Acme Corp");
        let nyc_id = put_entity(&store, "nyc", EntityType::Location, "New York City");
        put_relation(&store, &alice_id, RelationType::WorksAt, &acme_id, 0.95);
        // A LocatedIn variant doesn't exist, so the second hop uses RelatedTo.
        put_relation(&store, &acme_id, RelationType::RelatedTo, &nyc_id, 1.0);

        // depth = 1: only the first hop is visible.
        let one_hop = store
            .query_graph(GraphPattern {
                source: Some(alice_id.clone()),
                relation: None,
                target: None,
                max_depth: 1,
            })
            .unwrap();
        assert_eq!(one_hop.len(), 1);
        assert_eq!(one_hop[0].target.name, "Acme Corp");

        // depth = 2: both hops come back.
        let two_hops = store
            .query_graph(GraphPattern {
                source: Some(alice_id),
                relation: None,
                target: None,
                max_depth: 2,
            })
            .unwrap();
        assert_eq!(two_hops.len(), 2);
        let target_names: Vec<&str> = two_hops.iter().map(|m| m.target.name.as_str()).collect();
        assert!(target_names.contains(&"Acme Corp"));
        assert!(target_names.contains(&"New York City"));
    }

    #[test]
    fn test_recursive_traversal_with_cycle() {
        // Cycle: A -> B -> C -> A; traversal must terminate.
        let store = setup();
        let a_id = put_entity(&store, "a", EntityType::Person, "Entity A");
        let b_id = put_entity(&store, "b", EntityType::Person, "Entity B");
        let c_id = put_entity(&store, "c", EntityType::Person, "Entity C");
        put_relation(&store, &a_id, RelationType::RelatedTo, &b_id, 1.0);
        put_relation(&store, &b_id, RelationType::RelatedTo, &c_id, 1.0);
        put_relation(&store, &c_id, RelationType::RelatedTo, &a_id, 1.0);

        let matches = store
            .query_graph(GraphPattern {
                source: Some(a_id),
                relation: None,
                target: None,
                max_depth: 5,
            })
            .unwrap();
        // Exactly the 3 unique edges (A->B, B->C, C->A); cycle detection
        // prevents revisiting A->B on the wrap-around.
        assert_eq!(matches.len(), 3);
    }

    #[test]
    fn test_recursive_traversal_with_relation_filter() {
        // Alice -[Knows]-> Bob -[WorksAt]-> Carol; filtering on Knows must
        // exclude the second hop.
        let store = setup();
        let alice_id = put_entity(&store, "alice", EntityType::Person, "Alice");
        let bob_id = put_entity(&store, "bob", EntityType::Person, "Bob");
        let carol_id = put_entity(&store, "carol", EntityType::Person, "Carol");
        put_relation(&store, &alice_id, RelationType::Knows, &bob_id, 0.9);
        put_relation(&store, &bob_id, RelationType::WorksAt, &carol_id, 0.9);

        let matches = store
            .query_graph(GraphPattern {
                source: Some(alice_id),
                relation: Some(RelationType::Knows),
                target: None,
                max_depth: 3,
            })
            .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].target.name, "Bob");
    }

    #[test]
    fn test_max_depth_safety_limit() {
        // A requested depth of 100 must be capped (to 10) instead of
        // spinning; completing at all is the sanity check here.
        let store = setup();
        let alice_id = put_entity(&store, "alice", EntityType::Person, "Alice");
        let _matches = store
            .query_graph(GraphPattern {
                source: Some(alice_id),
                relation: None,
                target: None,
                max_depth: 100,
            })
            .unwrap();
    }
}

View File

@@ -0,0 +1,19 @@
//! Memory substrate for the OpenFang Agent Operating System.
//!
//! Provides a unified memory API over three storage backends:
//! - **Structured store** (SQLite): Key-value pairs, sessions, agent state
//! - **Semantic store**: Text-based search (Phase 1: LIKE matching, Phase 2: Qdrant vectors)
//! - **Knowledge graph** (SQLite): Entities and relations
//!
//! Agents interact with a single `Memory` trait that abstracts over all three stores.
pub mod consolidation; // Confidence decay / duplicate-merge cycle.
pub mod knowledge; // SQLite-backed entity/relation knowledge graph.
pub mod migration; // Schema creation and versioned migrations.
pub mod semantic; // Semantic (text search) store — see crate docs above.
pub mod session; // Session storage; presumably backs the `sessions` table — confirm in module.
pub mod structured; // Structured (key-value) store — see crate docs above.
pub mod usage; // Usage/cost tracking; presumably backs `usage_events` — confirm in module.
// Private glue: the unified substrate fronting all stores.
mod substrate;
// Re-exported so callers use `openfang_memory::MemorySubstrate` directly.
pub use substrate::MemorySubstrate;

View File

@@ -0,0 +1,436 @@
//! SQLite schema creation and migration.
//!
//! Creates all tables needed by the memory substrate on first boot.
use rusqlite::Connection;
/// Current schema version.
const SCHEMA_VERSION: u32 = 8;
/// Run all migrations to bring the database up to date.
pub fn run_migrations(conn: &Connection) -> Result<(), rusqlite::Error> {
let current_version = get_schema_version(conn);
if current_version < 1 {
migrate_v1(conn)?;
}
if current_version < 2 {
migrate_v2(conn)?;
}
if current_version < 3 {
migrate_v3(conn)?;
}
if current_version < 4 {
migrate_v4(conn)?;
}
if current_version < 5 {
migrate_v5(conn)?;
}
if current_version < 6 {
migrate_v6(conn)?;
}
if current_version < 7 {
migrate_v7(conn)?;
}
if current_version < 8 {
migrate_v8(conn)?;
}
set_schema_version(conn, SCHEMA_VERSION)?;
Ok(())
}
/// Get the current schema version from the database.
///
/// Reads SQLite's `user_version` pragma; a read failure is treated as a
/// brand-new database (version 0).
fn get_schema_version(conn: &Connection) -> u32 {
    match conn.pragma_query_value(None, "user_version", |row| row.get(0)) {
        Ok(version) => version,
        Err(_) => 0,
    }
}
/// Check if a column exists in a table (SQLite has no ADD COLUMN IF NOT EXISTS).
///
/// Any failure to prepare or run `PRAGMA table_info` is reported as
/// "column absent", which makes the ALTER-TABLE callers attempt the add
/// and surface the real error there.
fn column_exists(conn: &Connection, table: &str, column: &str) -> bool {
    let sql = format!("PRAGMA table_info({})", table);
    let Ok(mut stmt) = conn.prepare(&sql) else {
        return false;
    };
    // Column 1 of table_info is the column name.
    let Ok(rows) = stmt.query_map([], |row| row.get::<_, String>(1)) else {
        return false;
    };
    // Stream the names instead of collecting into a Vec first.
    rows.flatten().any(|name| name == column)
}
/// Set the schema version in the database.
///
/// Writes SQLite's `user_version` pragma; read back by `get_schema_version`.
fn set_schema_version(conn: &Connection, version: u32) -> Result<(), rusqlite::Error> {
    conn.pragma_update(None, "user_version", version)
}
/// Version 1: Create all core tables.
///
/// The whole baseline schema runs in one batch; every statement is
/// idempotent (`IF NOT EXISTS` / `INSERT OR IGNORE`) so re-running is safe.
fn migrate_v1(conn: &Connection) -> Result<(), rusqlite::Error> {
    conn.execute_batch(
        "
        -- Agent registry
        CREATE TABLE IF NOT EXISTS agents (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            manifest BLOB NOT NULL,
            state TEXT NOT NULL,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        -- Session history
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            agent_id TEXT NOT NULL,
            messages BLOB NOT NULL,
            context_window_tokens INTEGER DEFAULT 0,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        -- Event log
        CREATE TABLE IF NOT EXISTS events (
            id TEXT PRIMARY KEY,
            source_agent TEXT NOT NULL,
            target TEXT NOT NULL,
            payload BLOB NOT NULL,
            timestamp TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp);
        CREATE INDEX IF NOT EXISTS idx_events_source ON events(source_agent);
        -- Key-value store (per-agent)
        CREATE TABLE IF NOT EXISTS kv_store (
            agent_id TEXT NOT NULL,
            key TEXT NOT NULL,
            value BLOB NOT NULL,
            version INTEGER NOT NULL DEFAULT 1,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (agent_id, key)
        );
        -- Task queue
        CREATE TABLE IF NOT EXISTS task_queue (
            id TEXT PRIMARY KEY,
            agent_id TEXT NOT NULL,
            task_type TEXT NOT NULL,
            payload BLOB NOT NULL,
            status TEXT NOT NULL DEFAULT 'pending',
            priority INTEGER NOT NULL DEFAULT 0,
            scheduled_at TEXT,
            created_at TEXT NOT NULL,
            completed_at TEXT
        );
        CREATE INDEX IF NOT EXISTS idx_task_status_priority ON task_queue(status, priority DESC);
        -- Semantic memories
        CREATE TABLE IF NOT EXISTS memories (
            id TEXT PRIMARY KEY,
            agent_id TEXT NOT NULL,
            content TEXT NOT NULL,
            source TEXT NOT NULL,
            scope TEXT NOT NULL DEFAULT 'episodic',
            confidence REAL NOT NULL DEFAULT 1.0,
            metadata TEXT NOT NULL DEFAULT '{}',
            created_at TEXT NOT NULL,
            accessed_at TEXT NOT NULL,
            access_count INTEGER NOT NULL DEFAULT 0,
            deleted INTEGER NOT NULL DEFAULT 0
        );
        CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id);
        CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(scope);
        -- Knowledge graph entities
        CREATE TABLE IF NOT EXISTS entities (
            id TEXT PRIMARY KEY,
            entity_type TEXT NOT NULL,
            name TEXT NOT NULL,
            properties TEXT NOT NULL DEFAULT '{}',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        -- Knowledge graph relations
        CREATE TABLE IF NOT EXISTS relations (
            id TEXT PRIMARY KEY,
            source_entity TEXT NOT NULL,
            relation_type TEXT NOT NULL,
            target_entity TEXT NOT NULL,
            properties TEXT NOT NULL DEFAULT '{}',
            confidence REAL NOT NULL DEFAULT 1.0,
            created_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_relations_source ON relations(source_entity);
        CREATE INDEX IF NOT EXISTS idx_relations_target ON relations(target_entity);
        CREATE INDEX IF NOT EXISTS idx_relations_type ON relations(relation_type);
        -- Migration tracking
        CREATE TABLE IF NOT EXISTS migrations (
            version INTEGER PRIMARY KEY,
            applied_at TEXT NOT NULL,
            description TEXT
        );
        INSERT OR IGNORE INTO migrations (version, applied_at, description)
        VALUES (1, datetime('now'), 'Initial schema');
        ",
    )?;
    Ok(())
}
/// Version 2: Add collaboration columns to task_queue for agent task delegation.
fn migrate_v2(conn: &Connection) -> Result<(), rusqlite::Error> {
    // New collaboration columns. SQLite allows only one ALTER TABLE per
    // statement and has no ADD COLUMN IF NOT EXISTS, so each column is
    // checked for and added individually.
    let new_columns: [(&str, &str); 5] = [
        ("title", "TEXT DEFAULT ''"),
        ("description", "TEXT DEFAULT ''"),
        ("assigned_to", "TEXT DEFAULT ''"),
        ("created_by", "TEXT DEFAULT ''"),
        ("result", "TEXT DEFAULT ''"),
    ];
    for (column, definition) in new_columns {
        if column_exists(conn, "task_queue", column) {
            continue;
        }
        let ddl = format!("ALTER TABLE task_queue ADD COLUMN {} {}", column, definition);
        conn.execute(&ddl, [])?;
    }
    conn.execute(
        "INSERT OR IGNORE INTO migrations (version, applied_at, description) VALUES (2, datetime('now'), 'Add collaboration columns to task_queue')",
        [],
    )?;
    Ok(())
}
/// Version 3: Add embedding column to memories table for vector search.
fn migrate_v3(conn: &Connection) -> Result<(), rusqlite::Error> {
    // Guarded ALTER: SQLite has no ADD COLUMN IF NOT EXISTS.
    if !column_exists(conn, "memories", "embedding") {
        conn.execute(
            "ALTER TABLE memories ADD COLUMN embedding BLOB DEFAULT NULL",
            [],
        )?;
    }
    // Record the migration regardless; OR IGNORE keeps re-runs idempotent.
    conn.execute(
        "INSERT OR IGNORE INTO migrations (version, applied_at, description) VALUES (3, datetime('now'), 'Add embedding column to memories')",
        [],
    )?;
    Ok(())
}
/// Version 4: Add usage_events table for cost tracking and metering.
fn migrate_v4(conn: &Connection) -> Result<(), rusqlite::Error> {
    // All statements are idempotent (IF NOT EXISTS / INSERT OR IGNORE).
    conn.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS usage_events (
            id TEXT PRIMARY KEY,
            agent_id TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            model TEXT NOT NULL,
            input_tokens INTEGER NOT NULL DEFAULT 0,
            output_tokens INTEGER NOT NULL DEFAULT 0,
            cost_usd REAL NOT NULL DEFAULT 0.0,
            tool_calls INTEGER NOT NULL DEFAULT 0
        );
        CREATE INDEX IF NOT EXISTS idx_usage_agent_time ON usage_events(agent_id, timestamp);
        CREATE INDEX IF NOT EXISTS idx_usage_timestamp ON usage_events(timestamp);
        INSERT OR IGNORE INTO migrations (version, applied_at, description)
        VALUES (4, datetime('now'), 'Add usage_events table for cost tracking');
        ",
    )?;
    Ok(())
}
/// Version 5: Add canonical_sessions table for cross-channel persistent memory.
fn migrate_v5(conn: &Connection) -> Result<(), rusqlite::Error> {
    // One canonical session row per agent (agent_id is the primary key).
    conn.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS canonical_sessions (
            agent_id TEXT PRIMARY KEY,
            messages BLOB NOT NULL,
            compaction_cursor INTEGER NOT NULL DEFAULT 0,
            compacted_summary TEXT,
            updated_at TEXT NOT NULL
        );
        INSERT OR IGNORE INTO migrations (version, applied_at, description)
        VALUES (5, datetime('now'), 'Add canonical_sessions for cross-channel memory');
        ",
    )?;
    Ok(())
}
/// Version 6: Add label column to sessions table.
fn migrate_v6(conn: &Connection) -> Result<(), rusqlite::Error> {
    // SQLite has no ADD COLUMN IF NOT EXISTS, so probe the schema first.
    let needs_column = !column_exists(conn, "sessions", "label");
    if needs_column {
        conn.execute("ALTER TABLE sessions ADD COLUMN label TEXT", [])?;
    }
    // Record the migration; OR IGNORE makes repeated runs a no-op.
    conn.execute(
        "INSERT OR IGNORE INTO migrations (version, applied_at, description) VALUES (6, datetime('now'), 'Add label column to sessions for human-readable labels')",
        [],
    )?;
    Ok(())
}
/// Version 7: Add paired_devices table for device pairing persistence.
///
/// `push_token` is nullable — not every paired device registers for push.
/// Idempotent via IF NOT EXISTS / INSERT OR IGNORE.
fn migrate_v7(conn: &Connection) -> Result<(), rusqlite::Error> {
    conn.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS paired_devices (
            device_id TEXT PRIMARY KEY,
            display_name TEXT NOT NULL,
            platform TEXT NOT NULL,
            paired_at TEXT NOT NULL,
            last_seen TEXT NOT NULL,
            push_token TEXT
        );
        INSERT OR IGNORE INTO migrations (version, applied_at, description)
        VALUES (7, datetime('now'), 'Add paired_devices table for device pairing');
        ",
    )?;
    Ok(())
}
/// Version 8: Add annotations tables for real-time collaboration (comments, highlights, reactions).
///
/// Creates four tables in one idempotent batch:
/// - `annotations`: threaded (via `parent_id`) comments/highlights anchored to a
///   message index and character range within a session.
/// - `annotation_reactions`: one reaction per (annotation, connection, type) —
///   enforced by the UNIQUE constraint.
/// - `collab_sessions`: tracks active shared sessions and their owner.
/// - `presence_log`: join/leave/activity records per collab session.
///
/// Foreign keys cascade so deleting a session cleans up its collaboration data.
/// NOTE(review): cascades only fire when `PRAGMA foreign_keys = ON` is set on
/// the connection — confirm the connection setup enables it.
fn migrate_v8(conn: &Connection) -> Result<(), rusqlite::Error> {
    conn.execute_batch(
        "
        -- Annotations table for comments, highlights, and suggestions
        CREATE TABLE IF NOT EXISTS annotations (
            id TEXT PRIMARY KEY,
            session_id TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            connection_id TEXT NOT NULL,
            author_name TEXT NOT NULL,
            annotation_type TEXT NOT NULL DEFAULT 'comment',
            content TEXT NOT NULL,
            message_index INTEGER NOT NULL,
            char_start INTEGER NOT NULL DEFAULT 0,
            char_end INTEGER NOT NULL DEFAULT 0,
            line_start INTEGER,
            line_end INTEGER,
            parent_id TEXT,
            status TEXT NOT NULL DEFAULT 'open',
            priority TEXT DEFAULT 'normal',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL,
            resolved_at TEXT,
            resolved_by TEXT,
            FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE,
            FOREIGN KEY (parent_id) REFERENCES annotations(id) ON DELETE SET NULL
        );
        CREATE INDEX IF NOT EXISTS idx_annotations_session ON annotations(session_id);
        CREATE INDEX IF NOT EXISTS idx_annotations_message ON annotations(session_id, message_index);
        CREATE INDEX IF NOT EXISTS idx_annotations_parent ON annotations(parent_id);
        CREATE INDEX IF NOT EXISTS idx_annotations_status ON annotations(status);
        -- Annotation reactions (emoji reactions, votes)
        CREATE TABLE IF NOT EXISTS annotation_reactions (
            id TEXT PRIMARY KEY,
            annotation_id TEXT NOT NULL,
            connection_id TEXT NOT NULL,
            reaction_type TEXT NOT NULL,
            reaction_value TEXT NOT NULL,
            created_at TEXT NOT NULL,
            FOREIGN KEY (annotation_id) REFERENCES annotations(id) ON DELETE CASCADE,
            UNIQUE(annotation_id, connection_id, reaction_type)
        );
        CREATE INDEX IF NOT EXISTS idx_reactions_annotation ON annotation_reactions(annotation_id);
        -- Collaboration sessions for tracking active collaborative sessions
        CREATE TABLE IF NOT EXISTS collab_sessions (
            id TEXT PRIMARY KEY,
            session_id TEXT NOT NULL,
            owner_connection_id TEXT NOT NULL,
            share_mode TEXT NOT NULL DEFAULT 'collaborative',
            max_participants INTEGER DEFAULT 10,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL,
            FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
        );
        CREATE INDEX IF NOT EXISTS idx_collab_session ON collab_sessions(session_id);
        -- Presence log for tracking user activity in collaborative sessions
        CREATE TABLE IF NOT EXISTS presence_log (
            id TEXT PRIMARY KEY,
            collab_session_id TEXT NOT NULL,
            connection_id TEXT NOT NULL,
            display_name TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'active',
            joined_at TEXT NOT NULL,
            left_at TEXT,
            last_activity TEXT NOT NULL,
            FOREIGN KEY (collab_session_id) REFERENCES collab_sessions(id) ON DELETE CASCADE
        );
        CREATE INDEX IF NOT EXISTS idx_presence_collab ON presence_log(collab_session_id);
        CREATE INDEX IF NOT EXISTS idx_presence_connection ON presence_log(connection_id);
        INSERT OR IGNORE INTO migrations (version, applied_at, description)
        VALUES (8, datetime('now'), 'Add annotations and collaboration tables for real-time collaboration');
        ",
    )?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: a fresh in-memory DB migrated from scratch exposes every
    // table created by migrations v1 through v8.
    #[test]
    fn test_migration_creates_tables() {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        // Verify tables exist
        let tables: Vec<String> = conn
            .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
            .unwrap()
            .query_map([], |row| row.get(0))
            .unwrap()
            .filter_map(|r| r.ok())
            .collect();
        assert!(tables.contains(&"agents".to_string()));
        assert!(tables.contains(&"sessions".to_string()));
        assert!(tables.contains(&"kv_store".to_string()));
        assert!(tables.contains(&"memories".to_string()));
        assert!(tables.contains(&"entities".to_string()));
        assert!(tables.contains(&"relations".to_string()));
        // v8 collaboration tables
        assert!(tables.contains(&"annotations".to_string()));
        assert!(tables.contains(&"annotation_reactions".to_string()));
        assert!(tables.contains(&"collab_sessions".to_string()));
        assert!(tables.contains(&"presence_log".to_string()));
    }
    // Running the full migration chain twice must succeed — every step is
    // guarded by IF NOT EXISTS / column probes / INSERT OR IGNORE.
    #[test]
    fn test_migration_idempotent() {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        run_migrations(&conn).unwrap(); // Should not error
    }
    // The recorded schema version must match the latest migration (v8).
    #[test]
    fn test_schema_version_is_8() {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        let version = get_schema_version(&conn);
        assert_eq!(version, 8);
    }
}

View File

@@ -0,0 +1,556 @@
//! Semantic memory store with vector embedding support.
//!
//! Phase 1: SQLite LIKE matching (fallback when no embeddings).
//! Phase 2: Vector cosine similarity search using stored embeddings.
//!
//! Embeddings are stored as BLOBs in the `embedding` column of the memories table.
//! When a query embedding is provided, recall uses cosine similarity ranking.
//! When no embeddings are available, falls back to LIKE matching.
use chrono::Utc;
use openfang_types::agent::AgentId;
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::memory::{MemoryFilter, MemoryFragment, MemoryId, MemorySource};
use rusqlite::Connection;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tracing::debug;
/// Semantic store backed by SQLite with optional vector search.
///
/// Cloning is cheap: clones share the same underlying connection.
#[derive(Clone)]
pub struct SemanticStore {
    // Shared SQLite handle; the Mutex serializes access across clones/threads.
    conn: Arc<Mutex<Connection>>,
}
impl SemanticStore {
    /// Create a new semantic store wrapping the given connection.
    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
        Self { conn }
    }
    /// Store a new memory fragment (without embedding).
    ///
    /// Thin wrapper over [`Self::remember_with_embedding`] with `embedding = None`.
    pub fn remember(
        &self,
        agent_id: AgentId,
        content: &str,
        source: MemorySource,
        scope: &str,
        metadata: HashMap<String, serde_json::Value>,
    ) -> OpenFangResult<MemoryId> {
        self.remember_with_embedding(agent_id, content, source, scope, metadata, None)
    }
    /// Store a new memory fragment with an optional embedding vector.
    ///
    /// New memories start with confidence 1.0, access_count 0, and
    /// `accessed_at == created_at` (both bound to the same `?7` placeholder).
    /// Returns the freshly generated [`MemoryId`].
    pub fn remember_with_embedding(
        &self,
        agent_id: AgentId,
        content: &str,
        source: MemorySource,
        scope: &str,
        metadata: HashMap<String, serde_json::Value>,
        embedding: Option<&[f32]>,
    ) -> OpenFangResult<MemoryId> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let id = MemoryId::new();
        let now = Utc::now().to_rfc3339();
        // Source and metadata are stored as JSON strings.
        let source_str = serde_json::to_string(&source)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let meta_str = serde_json::to_string(&metadata)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        // Embedding (if any) is packed as little-endian f32 bytes into a BLOB.
        let embedding_bytes: Option<Vec<u8>> = embedding.map(embedding_to_bytes);
        conn.execute(
            "INSERT INTO memories (id, agent_id, content, source, scope, confidence, metadata, created_at, accessed_at, access_count, deleted, embedding)
             VALUES (?1, ?2, ?3, ?4, ?5, 1.0, ?6, ?7, ?7, 0, 0, ?8)",
            rusqlite::params![
                id.0.to_string(),
                agent_id.0.to_string(),
                content,
                source_str,
                scope,
                meta_str,
                now,
                embedding_bytes,
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(id)
    }
    /// Search for memories using text matching (fallback, no embeddings).
    ///
    /// Thin wrapper over [`Self::recall_with_embedding`] with no query vector.
    pub fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<MemoryFilter>,
    ) -> OpenFangResult<Vec<MemoryFragment>> {
        self.recall_with_embedding(query, limit, filter, None)
    }
    /// Search for memories using vector similarity when a query embedding is provided,
    /// falling back to LIKE matching otherwise.
    ///
    /// With an embedding: fetches a wider candidate pool via SQL, then re-ranks
    /// in memory by cosine similarity and truncates to `limit`. Without one:
    /// SQL does LIKE filtering and `LIMIT` directly, ordered by recency.
    /// Side effect: bumps `access_count` and `accessed_at` on every returned row.
    pub fn recall_with_embedding(
        &self,
        query: &str,
        limit: usize,
        filter: Option<MemoryFilter>,
        query_embedding: Option<&[f32]>,
    ) -> OpenFangResult<Vec<MemoryFragment>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Build SQL: fetch candidates (broader than limit for vector re-ranking)
        let fetch_limit = if query_embedding.is_some() {
            // Fetch more candidates for vector search re-ranking
            (limit * 10).max(100)
        } else {
            limit
        };
        let mut sql = String::from(
            "SELECT id, agent_id, content, source, scope, confidence, metadata, created_at, accessed_at, access_count, embedding
             FROM memories WHERE deleted = 0",
        );
        // Params are boxed so heterogeneous types can share one Vec; positional
        // placeholders (?N) are numbered by `param_idx` as clauses are appended.
        let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
        let mut param_idx = 1;
        // Text search filter (only when no embeddings — vector search handles relevance)
        if query_embedding.is_none() && !query.is_empty() {
            sql.push_str(&format!(" AND content LIKE ?{param_idx}"));
            params.push(Box::new(format!("%{query}%")));
            param_idx += 1;
        }
        // Apply filters
        if let Some(ref f) = filter {
            if let Some(agent_id) = f.agent_id {
                sql.push_str(&format!(" AND agent_id = ?{param_idx}"));
                params.push(Box::new(agent_id.0.to_string()));
                param_idx += 1;
            }
            if let Some(ref scope) = f.scope {
                sql.push_str(&format!(" AND scope = ?{param_idx}"));
                params.push(Box::new(scope.clone()));
                param_idx += 1;
            }
            if let Some(min_conf) = f.min_confidence {
                sql.push_str(&format!(" AND confidence >= ?{param_idx}"));
                params.push(Box::new(min_conf as f64));
                param_idx += 1;
            }
            if let Some(ref source) = f.source {
                // Source is compared against its JSON-serialized form (see remember).
                let source_str = serde_json::to_string(source)
                    .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
                sql.push_str(&format!(" AND source = ?{param_idx}"));
                params.push(Box::new(source_str));
                // Last placeholder: suppress the unused-assignment warning.
                let _ = param_idx;
            }
        }
        sql.push_str(" ORDER BY accessed_at DESC, access_count DESC");
        sql.push_str(&format!(" LIMIT {fetch_limit}"));
        let mut stmt = conn
            .prepare(&sql)
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let param_refs: Vec<&dyn rusqlite::types::ToSql> =
            params.iter().map(|p| p.as_ref()).collect();
        // Extract raw column values first; conversion to domain types happens
        // outside the rusqlite closure so errors map cleanly to OpenFangError.
        let rows = stmt
            .query_map(param_refs.as_slice(), |row| {
                let id_str: String = row.get(0)?;
                let agent_str: String = row.get(1)?;
                let content: String = row.get(2)?;
                let source_str: String = row.get(3)?;
                let scope: String = row.get(4)?;
                let confidence: f64 = row.get(5)?;
                let meta_str: String = row.get(6)?;
                let created_str: String = row.get(7)?;
                let accessed_str: String = row.get(8)?;
                let access_count: i64 = row.get(9)?;
                let embedding_bytes: Option<Vec<u8>> = row.get(10)?;
                Ok((
                    id_str,
                    agent_str,
                    content,
                    source_str,
                    scope,
                    confidence,
                    meta_str,
                    created_str,
                    accessed_str,
                    access_count,
                    embedding_bytes,
                ))
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut fragments = Vec::new();
        for row_result in rows {
            let (
                id_str,
                agent_str,
                content,
                source_str,
                scope,
                confidence,
                meta_str,
                created_str,
                accessed_str,
                access_count,
                embedding_bytes,
            ) = row_result.map_err(|e| OpenFangError::Memory(e.to_string()))?;
            let id = uuid::Uuid::parse_str(&id_str)
                .map(MemoryId)
                .map_err(|e| OpenFangError::Memory(e.to_string()))?;
            let agent_id = uuid::Uuid::parse_str(&agent_str)
                .map(openfang_types::agent::AgentId)
                .map_err(|e| OpenFangError::Memory(e.to_string()))?;
            // Lenient decoding: malformed source/metadata/timestamps fall back
            // to defaults rather than failing the whole recall.
            let source: MemorySource =
                serde_json::from_str(&source_str).unwrap_or(MemorySource::System);
            let metadata: HashMap<String, serde_json::Value> =
                serde_json::from_str(&meta_str).unwrap_or_default();
            let created_at = chrono::DateTime::parse_from_rfc3339(&created_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());
            let accessed_at = chrono::DateTime::parse_from_rfc3339(&accessed_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());
            let embedding = embedding_bytes.as_deref().map(embedding_from_bytes);
            fragments.push(MemoryFragment {
                id,
                agent_id,
                content,
                embedding,
                metadata,
                source,
                confidence: confidence as f32,
                created_at,
                accessed_at,
                access_count: access_count as u64,
                scope,
            });
        }
        // If we have a query embedding, re-rank by cosine similarity
        if let Some(qe) = query_embedding {
            fragments.sort_by(|a, b| {
                // Memories without an embedding score -1.0 so they sort below
                // any real similarity (cosine range is [-1, 1]).
                let sim_a = a
                    .embedding
                    .as_deref()
                    .map(|e| cosine_similarity(qe, e))
                    .unwrap_or(-1.0);
                let sim_b = b
                    .embedding
                    .as_deref()
                    .map(|e| cosine_similarity(qe, e))
                    .unwrap_or(-1.0);
                sim_b
                    .partial_cmp(&sim_a)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            fragments.truncate(limit);
            debug!(
                "Vector recall: {} results from {} candidates",
                fragments.len(),
                fetch_limit
            );
        }
        // Update access counts for returned memories
        for frag in &fragments {
            let _ = conn.execute(
                "UPDATE memories SET access_count = access_count + 1, accessed_at = ?1 WHERE id = ?2",
                rusqlite::params![Utc::now().to_rfc3339(), frag.id.0.to_string()],
            );
        }
        Ok(fragments)
    }
    /// Soft-delete a memory fragment.
    ///
    /// The row is kept but flagged `deleted = 1`, which excludes it from recall.
    pub fn forget(&self, id: MemoryId) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        conn.execute(
            "UPDATE memories SET deleted = 1 WHERE id = ?1",
            rusqlite::params![id.0.to_string()],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// Update the embedding for an existing memory.
    ///
    /// NOTE(review): silently a no-op if `id` matches no row — confirm callers
    /// don't need to distinguish that case.
    pub fn update_embedding(&self, id: MemoryId, embedding: &[f32]) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let bytes = embedding_to_bytes(embedding);
        conn.execute(
            "UPDATE memories SET embedding = ?1 WHERE id = ?2",
            rusqlite::params![bytes, id.0.to_string()],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
}
/// Compute cosine similarity between two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or zero-magnitude
/// vectors, so callers never see NaN from a zero denominator.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < f32::EPSILON {
        0.0
    } else {
        dot / denom
    }
}
/// Serialize embedding to bytes for SQLite BLOB storage.
///
/// Each f32 is packed as 4 little-endian bytes, so the output is always
/// `4 * embedding.len()` bytes long.
fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    bytes.extend(embedding.iter().flat_map(|v| v.to_le_bytes()));
    bytes
}
/// Deserialize embedding from bytes (inverse of `embedding_to_bytes`).
///
/// Trailing bytes that do not form a whole little-endian f32 are ignored.
fn embedding_from_bytes(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let quad = [chunk[0], chunk[1], chunk[2], chunk[3]];
        values.push(f32::from_le_bytes(quad));
    }
    values
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;
    // Build a SemanticStore over a fully-migrated in-memory database.
    fn setup() -> SemanticStore {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        SemanticStore::new(Arc::new(Mutex::new(conn)))
    }
    // Basic round-trip: a stored memory is found by LIKE-based recall.
    #[test]
    fn test_remember_and_recall() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .remember(
                agent_id,
                "The user likes Rust programming",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        let results = store.recall("Rust", 10, None).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust"));
    }
    // An agent filter must exclude memories belonging to other agents.
    #[test]
    fn test_recall_with_filter() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .remember(
                agent_id,
                "Memory A",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        store
            .remember(
                AgentId::new(),
                "Memory B",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        let filter = MemoryFilter::agent(agent_id);
        let results = store.recall("Memory", 10, Some(filter)).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Memory A");
    }
    // forget() soft-deletes: the memory must no longer appear in recall.
    #[test]
    fn test_forget() {
        let store = setup();
        let agent_id = AgentId::new();
        let id = store
            .remember(
                agent_id,
                "To forget",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        store.forget(id).unwrap();
        let results = store.recall("To forget", 10, None).unwrap();
        assert!(results.is_empty());
    }
    // Storing with an embedding succeeds and returns a valid id.
    #[test]
    fn test_remember_with_embedding() {
        let store = setup();
        let agent_id = AgentId::new();
        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        let id = store
            .remember_with_embedding(
                agent_id,
                "Rust is great",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
                Some(&embedding),
            )
            .unwrap();
        assert_ne!(id.0.to_string(), "");
    }
    // Vector recall must order results by cosine similarity to the query.
    #[test]
    fn test_vector_recall_ranking() {
        let store = setup();
        let agent_id = AgentId::new();
        // Store 3 memories with embeddings pointing in different directions
        let emb_rust = vec![0.9, 0.1, 0.0, 0.0]; // "Rust" direction
        let emb_python = vec![0.0, 0.0, 0.9, 0.1]; // "Python" direction
        let emb_mixed = vec![0.5, 0.5, 0.0, 0.0]; // mixed
        store
            .remember_with_embedding(
                agent_id,
                "Rust is a systems language",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
                Some(&emb_rust),
            )
            .unwrap();
        store
            .remember_with_embedding(
                agent_id,
                "Python is interpreted",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
                Some(&emb_python),
            )
            .unwrap();
        store
            .remember_with_embedding(
                agent_id,
                "Both are popular",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
                Some(&emb_mixed),
            )
            .unwrap();
        // Query with a "Rust"-like embedding
        let query_emb = vec![0.85, 0.15, 0.0, 0.0];
        let results = store
            .recall_with_embedding("", 3, None, Some(&query_emb))
            .unwrap();
        assert_eq!(results.len(), 3);
        // Rust memory should be first (highest cosine similarity)
        assert!(results[0].content.contains("Rust"));
        // Python memory should be last (lowest similarity)
        assert!(results[2].content.contains("Python"));
    }
    // An embedding added after the fact must round-trip through vector recall.
    #[test]
    fn test_update_embedding() {
        let store = setup();
        let agent_id = AgentId::new();
        let id = store
            .remember(
                agent_id,
                "No embedding yet",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        // Update with embedding
        let emb = vec![1.0, 0.0, 0.0];
        store.update_embedding(id, &emb).unwrap();
        // Verify the embedding is stored by doing vector recall
        let query_emb = vec![1.0, 0.0, 0.0];
        let results = store
            .recall_with_embedding("", 10, None, Some(&query_emb))
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].embedding.is_some());
        assert_eq!(results[0].embedding.as_ref().unwrap().len(), 3);
    }
    // Memories lacking embeddings score -1.0 and so rank below embedded ones.
    #[test]
    fn test_mixed_embedded_and_non_embedded() {
        let store = setup();
        let agent_id = AgentId::new();
        // One memory with embedding, one without
        store
            .remember_with_embedding(
                agent_id,
                "Has embedding",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
                Some(&[1.0, 0.0]),
            )
            .unwrap();
        store
            .remember(
                agent_id,
                "No embedding",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .unwrap();
        // Vector recall should rank embedded memory higher
        let results = store
            .recall_with_embedding("", 10, None, Some(&[1.0, 0.0]))
            .unwrap();
        assert_eq!(results.len(), 2);
        // Embedded memory should rank first
        assert_eq!(results[0].content, "Has embedding");
    }
}

View File

@@ -0,0 +1,796 @@
//! Session management — load/save conversation history.
use chrono::Utc;
use openfang_types::agent::{AgentId, SessionId};
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::message::{ContentBlock, Message, MessageContent, Role};
use rusqlite::Connection;
use std::io::Write;
use std::path::Path;
use std::sync::{Arc, Mutex};
/// A conversation session with message history.
///
/// Persisted by [`SessionStore`]: messages are serialized as a MessagePack
/// blob, and the token count is stored as an i64 column.
#[derive(Debug, Clone)]
pub struct Session {
    /// Session ID.
    pub id: SessionId,
    /// Owning agent ID.
    pub agent_id: AgentId,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Estimated token count for the context window.
    pub context_window_tokens: u64,
    /// Optional human-readable session label.
    pub label: Option<String>,
}
/// Session store backed by SQLite.
///
/// Cloning is cheap: clones share the same underlying connection.
#[derive(Clone)]
pub struct SessionStore {
    // Shared SQLite handle; the Mutex serializes access across clones/threads.
    conn: Arc<Mutex<Connection>>,
}
impl SessionStore {
    /// Create a new session store wrapping the given connection.
    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
        Self { conn }
    }
    /// Load a session from the database.
    ///
    /// Returns `Ok(None)` when no row exists for `session_id`; other SQL or
    /// deserialization failures surface as errors.
    pub fn get_session(&self, session_id: SessionId) -> OpenFangResult<Option<Session>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare("SELECT agent_id, messages, context_window_tokens, label FROM sessions WHERE id = ?1")
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let result = stmt.query_row(rusqlite::params![session_id.0.to_string()], |row| {
            let agent_str: String = row.get(0)?;
            let messages_blob: Vec<u8> = row.get(1)?;
            let tokens: i64 = row.get(2)?;
            // unwrap_or(None) tolerates rows written before the v6 label column.
            let label: Option<String> = row.get(3).unwrap_or(None);
            Ok((agent_str, messages_blob, tokens, label))
        });
        match result {
            Ok((agent_str, messages_blob, tokens, label)) => {
                let agent_id = uuid::Uuid::parse_str(&agent_str)
                    .map(AgentId)
                    .map_err(|e| OpenFangError::Memory(e.to_string()))?;
                // Messages are stored as a MessagePack blob.
                let messages: Vec<Message> = rmp_serde::from_slice(&messages_blob)
                    .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
                Ok(Some(Session {
                    id: session_id,
                    agent_id,
                    messages,
                    context_window_tokens: tokens as u64,
                    label,
                }))
            }
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(OpenFangError::Memory(e.to_string())),
        }
    }
    /// Save a session to the database.
    ///
    /// Upserts: inserts a new row or, on id conflict, updates messages,
    /// token count, label, and `updated_at` (leaving `created_at` intact).
    pub fn save_session(&self, session: &Session) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let messages_blob = rmp_serde::to_vec_named(&session.messages)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO sessions (id, agent_id, messages, context_window_tokens, label, created_at, updated_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?6)
             ON CONFLICT(id) DO UPDATE SET messages = ?3, context_window_tokens = ?4, label = ?5, updated_at = ?6",
            rusqlite::params![
                session.id.0.to_string(),
                session.agent_id.0.to_string(),
                messages_blob,
                session.context_window_tokens as i64,
                session.label.as_deref(),
                now,
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// Delete a session from the database.
    ///
    /// Hard delete; succeeds even if the id does not exist.
    pub fn delete_session(&self, session_id: SessionId) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        conn.execute(
            "DELETE FROM sessions WHERE id = ?1",
            rusqlite::params![session_id.0.to_string()],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// Delete all sessions belonging to an agent.
    pub fn delete_agent_sessions(&self, agent_id: AgentId) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        conn.execute(
            "DELETE FROM sessions WHERE agent_id = ?1",
            rusqlite::params![agent_id.0.to_string()],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// List all sessions with metadata (session_id, agent_id, message_count, created_at).
    ///
    /// Newest first. Each message blob is deserialized solely to count
    /// messages; an undecodable blob is reported as count 0 rather than
    /// failing the listing.
    pub fn list_sessions(&self) -> OpenFangResult<Vec<serde_json::Value>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare(
                "SELECT id, agent_id, messages, created_at, label FROM sessions ORDER BY created_at DESC",
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let rows = stmt
            .query_map([], |row| {
                let session_id: String = row.get(0)?;
                let agent_id: String = row.get(1)?;
                let messages_blob: Vec<u8> = row.get(2)?;
                let created_at: String = row.get(3)?;
                let label: Option<String> = row.get(4)?;
                // Deserialize just to count messages
                let msg_count = rmp_serde::from_slice::<Vec<Message>>(&messages_blob)
                    .map(|m| m.len())
                    .unwrap_or(0);
                Ok(serde_json::json!({
                    "session_id": session_id,
                    "agent_id": agent_id,
                    "message_count": msg_count,
                    "created_at": created_at,
                    "label": label,
                }))
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut sessions = Vec::new();
        for row in rows {
            sessions.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
        }
        Ok(sessions)
    }
    /// Create a new empty session for an agent.
    ///
    /// Persists immediately and returns the in-memory handle.
    pub fn create_session(&self, agent_id: AgentId) -> OpenFangResult<Session> {
        let session = Session {
            id: SessionId::new(),
            agent_id,
            messages: Vec::new(),
            context_window_tokens: 0,
            label: None,
        };
        self.save_session(&session)?;
        Ok(session)
    }
    /// Set the label on an existing session.
    ///
    /// Passing `None` clears the label. No-op if the id does not exist.
    pub fn set_session_label(
        &self,
        session_id: SessionId,
        label: Option<&str>,
    ) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        conn.execute(
            "UPDATE sessions SET label = ?1, updated_at = ?2 WHERE id = ?3",
            rusqlite::params![label, Utc::now().to_rfc3339(), session_id.0.to_string()],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// Find a session by label for a given agent.
    ///
    /// Exact label match; if several sessions share a label, an arbitrary one
    /// is returned (LIMIT 1 with no ORDER BY). Returns `Ok(None)` on no match.
    pub fn find_session_by_label(
        &self,
        agent_id: AgentId,
        label: &str,
    ) -> OpenFangResult<Option<Session>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare(
                "SELECT id, messages, context_window_tokens, label FROM sessions \
                 WHERE agent_id = ?1 AND label = ?2 LIMIT 1",
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let result = stmt.query_row(rusqlite::params![agent_id.0.to_string(), label], |row| {
            let id_str: String = row.get(0)?;
            let messages_blob: Vec<u8> = row.get(1)?;
            let tokens: i64 = row.get(2)?;
            let lbl: Option<String> = row.get(3).unwrap_or(None);
            Ok((id_str, messages_blob, tokens, lbl))
        });
        match result {
            Ok((id_str, messages_blob, tokens, lbl)) => {
                let session_id = uuid::Uuid::parse_str(&id_str)
                    .map(SessionId)
                    .map_err(|e| OpenFangError::Memory(e.to_string()))?;
                let messages: Vec<Message> = rmp_serde::from_slice(&messages_blob)
                    .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
                Ok(Some(Session {
                    id: session_id,
                    agent_id,
                    messages,
                    context_window_tokens: tokens as u64,
                    label: lbl,
                }))
            }
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(OpenFangError::Memory(e.to_string())),
        }
    }
}
impl SessionStore {
    /// List all sessions for a specific agent.
    ///
    /// Newest first; same JSON shape as `list_sessions` but without the
    /// (implied) agent_id field. Undecodable message blobs count as 0.
    pub fn list_agent_sessions(&self, agent_id: AgentId) -> OpenFangResult<Vec<serde_json::Value>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare(
                "SELECT id, messages, created_at, label FROM sessions WHERE agent_id = ?1 ORDER BY created_at DESC",
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let rows = stmt
            .query_map(rusqlite::params![agent_id.0.to_string()], |row| {
                let session_id: String = row.get(0)?;
                let messages_blob: Vec<u8> = row.get(1)?;
                let created_at: String = row.get(2)?;
                let label: Option<String> = row.get(3)?;
                // Deserialize only to count messages.
                let msg_count = rmp_serde::from_slice::<Vec<Message>>(&messages_blob)
                    .map(|m| m.len())
                    .unwrap_or(0);
                Ok(serde_json::json!({
                    "session_id": session_id,
                    "message_count": msg_count,
                    "created_at": created_at,
                    "label": label,
                }))
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut sessions = Vec::new();
        for row in rows {
            sessions.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
        }
        Ok(sessions)
    }
    /// Create a new session with an optional label.
    ///
    /// Like `create_session`, but the label is set at creation time.
    pub fn create_session_with_label(
        &self,
        agent_id: AgentId,
        label: Option<&str>,
    ) -> OpenFangResult<Session> {
        let session = Session {
            id: SessionId::new(),
            agent_id,
            messages: Vec::new(),
            context_window_tokens: 0,
            label: label.map(|s| s.to_string()),
        };
        self.save_session(&session)?;
        Ok(session)
    }
    /// Store an LLM-generated summary, replacing older messages with the summary
    /// and keeping only the specified recent messages.
    ///
    /// This is used by the LLM-based compactor to replace text-truncation compaction
    /// with an intelligent, LLM-generated summary of older conversation history.
    ///
    /// Overwrites any previous `compacted_summary`, replaces the message list
    /// with `kept_messages`, and resets the compaction cursor to 0.
    pub fn store_llm_summary(
        &self,
        agent_id: AgentId,
        summary: &str,
        kept_messages: Vec<Message>,
    ) -> OpenFangResult<()> {
        let mut canonical = self.load_canonical(agent_id)?;
        canonical.compacted_summary = Some(summary.to_string());
        canonical.messages = kept_messages;
        canonical.compaction_cursor = 0;
        canonical.updated_at = Utc::now().to_rfc3339();
        self.save_canonical(&canonical)
    }
}
/// Default number of recent messages to include from canonical session.
const DEFAULT_CANONICAL_WINDOW: usize = 50;
/// Default compaction threshold: when message count exceeds this, compact older messages.
///
/// Must stay larger than `DEFAULT_CANONICAL_WINDOW` so compaction always
/// leaves a recent window behind (100 > 50 holds).
const DEFAULT_COMPACTION_THRESHOLD: usize = 100;
/// A canonical session stores persistent cross-channel context for an agent.
///
/// Unlike regular sessions (one per channel interaction), there is one canonical
/// session per agent. All channels contribute to it, so what a user tells an agent
/// on Telegram is remembered on Discord.
#[derive(Debug, Clone)]
pub struct CanonicalSession {
    /// The agent this session belongs to.
    pub agent_id: AgentId,
    /// Full message history (post-compaction window).
    pub messages: Vec<Message>,
    /// Index marking how far compaction has processed.
    ///
    /// Reset to 0 whenever compaction trims the message list, since indices
    /// are relative to `messages`.
    pub compaction_cursor: usize,
    /// Summary of compacted (older) messages.
    pub compacted_summary: Option<String>,
    /// Last update time (RFC 3339 string).
    pub updated_at: String,
}
impl SessionStore {
/// Load the canonical session for an agent, creating one if it doesn't exist.
pub fn load_canonical(&self, agent_id: AgentId) -> OpenFangResult<CanonicalSession> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
let mut stmt = conn
.prepare(
"SELECT messages, compaction_cursor, compacted_summary, updated_at \
FROM canonical_sessions WHERE agent_id = ?1",
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let result = stmt.query_row(rusqlite::params![agent_id.0.to_string()], |row| {
let messages_blob: Vec<u8> = row.get(0)?;
let cursor: i64 = row.get(1)?;
let summary: Option<String> = row.get(2)?;
let updated_at: String = row.get(3)?;
Ok((messages_blob, cursor, summary, updated_at))
});
match result {
Ok((messages_blob, cursor, summary, updated_at)) => {
let messages: Vec<Message> = rmp_serde::from_slice(&messages_blob)
.map_err(|e| OpenFangError::Serialization(e.to_string()))?;
Ok(CanonicalSession {
agent_id,
messages,
compaction_cursor: cursor as usize,
compacted_summary: summary,
updated_at,
})
}
Err(rusqlite::Error::QueryReturnedNoRows) => {
let now = Utc::now().to_rfc3339();
Ok(CanonicalSession {
agent_id,
messages: Vec::new(),
compaction_cursor: 0,
compacted_summary: None,
updated_at: now,
})
}
Err(e) => Err(OpenFangError::Memory(e.to_string())),
}
}
/// Append new messages to the canonical session and compact if over threshold.
///
/// Compaction summarizes old messages into a text summary and trims the
/// message list. The `compaction_threshold` controls when this happens
/// (default: 100 messages).
pub fn append_canonical(
&self,
agent_id: AgentId,
new_messages: &[Message],
compaction_threshold: Option<usize>,
) -> OpenFangResult<CanonicalSession> {
let mut canonical = self.load_canonical(agent_id)?;
canonical.messages.extend(new_messages.iter().cloned());
let threshold = compaction_threshold.unwrap_or(DEFAULT_COMPACTION_THRESHOLD);
// Compact if over threshold
if canonical.messages.len() > threshold {
let keep_count = DEFAULT_CANONICAL_WINDOW;
let to_compact = canonical.messages.len().saturating_sub(keep_count);
if to_compact > canonical.compaction_cursor {
// Build a summary from the messages being compacted
let compacting = &canonical.messages[canonical.compaction_cursor..to_compact];
let mut summary_parts: Vec<String> = Vec::new();
if let Some(ref existing) = canonical.compacted_summary {
summary_parts.push(existing.clone());
}
for msg in compacting {
let role = match msg.role {
openfang_types::message::Role::User => "User",
openfang_types::message::Role::Assistant => "Assistant",
openfang_types::message::Role::System => "System",
};
let text = msg.content.text_content();
if !text.is_empty() {
// Truncate individual messages in summary to keep it compact (UTF-8 safe)
let truncated = if text.len() > 200 {
format!("{}...", openfang_types::truncate_str(&text, 200))
} else {
text
};
summary_parts.push(format!("{role}: {truncated}"));
}
}
// Keep summary under ~4000 chars (UTF-8 safe)
let mut full_summary = summary_parts.join("\n");
if full_summary.len() > 4000 {
let start = full_summary.len() - 4000;
// Find the next char boundary at or after `start`
let safe_start = (start..full_summary.len())
.find(|&i| full_summary.is_char_boundary(i))
.unwrap_or(full_summary.len());
full_summary = full_summary[safe_start..].to_string();
}
canonical.compacted_summary = Some(full_summary);
canonical.compaction_cursor = to_compact;
// Trim messages: keep only the recent window
canonical.messages = canonical.messages.split_off(to_compact);
canonical.compaction_cursor = 0; // reset cursor since we trimmed
}
}
canonical.updated_at = Utc::now().to_rfc3339();
self.save_canonical(&canonical)?;
Ok(canonical)
}
/// Get recent messages from canonical session for context injection.
///
/// Returns up to `window_size` recent messages (default 50), plus
/// the compacted summary if available.
pub fn canonical_context(
&self,
agent_id: AgentId,
window_size: Option<usize>,
) -> OpenFangResult<(Option<String>, Vec<Message>)> {
let canonical = self.load_canonical(agent_id)?;
let window = window_size.unwrap_or(DEFAULT_CANONICAL_WINDOW);
let start = canonical.messages.len().saturating_sub(window);
let recent = canonical.messages[start..].to_vec();
Ok((canonical.compacted_summary.clone(), recent))
}
    /// Persist a canonical session to SQLite.
    ///
    /// Upserts on `agent_id`: messages are serialized as a MessagePack blob,
    /// and the cursor/summary/timestamp columns are overwritten. A poisoned
    /// connection lock is surfaced as `OpenFangError::Internal`.
    fn save_canonical(&self, canonical: &CanonicalSession) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let messages_blob = rmp_serde::to_vec(&canonical.messages)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        conn.execute(
            "INSERT INTO canonical_sessions (agent_id, messages, compaction_cursor, compacted_summary, updated_at)
             VALUES (?1, ?2, ?3, ?4, ?5)
             ON CONFLICT(agent_id) DO UPDATE SET messages = ?2, compaction_cursor = ?3, compacted_summary = ?4, updated_at = ?5",
            rusqlite::params![
                canonical.agent_id.0.to_string(),
                messages_blob,
                canonical.compaction_cursor as i64,
                canonical.compacted_summary,
                canonical.updated_at,
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
}
/// A single JSONL line in the session mirror file.
#[derive(serde::Serialize)]
struct JsonlLine {
    /// RFC 3339 time the mirror line was written (not a per-message time;
    /// `write_jsonl_mirror` stamps all lines with one timestamp).
    timestamp: String,
    /// Message role: "user", "assistant", or "system".
    role: String,
    /// Concatenated text content of the message (JSON string).
    content: serde_json::Value,
    /// Tool-use / tool-result blocks; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_use: Option<serde_json::Value>,
}
impl SessionStore {
    /// Write a human-readable JSONL mirror of a session to disk.
    ///
    /// One JSON object per message is written to `<sessions_dir>/<id>.jsonl`.
    /// Text-like blocks are joined into `content`; tool_use / tool_result
    /// blocks go into the optional `tool_use` array.
    ///
    /// Best-effort: errors are returned but should be logged and never
    /// affect the primary SQLite store.
    pub fn write_jsonl_mirror(
        &self,
        session: &Session,
        sessions_dir: &Path,
    ) -> Result<(), std::io::Error> {
        std::fs::create_dir_all(sessions_dir)?;
        let path = sessions_dir.join(format!("{}.jsonl", session.id.0));
        let mut file = std::fs::File::create(&path)?;
        // All lines share the mirror-write time; per-message timestamps are
        // not available here.
        let now = Utc::now().to_rfc3339();
        for msg in &session.messages {
            let role_str = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
            };
            let mut text_parts: Vec<String> = Vec::new();
            let mut tool_parts: Vec<serde_json::Value> = Vec::new();
            match &msg.content {
                MessageContent::Text(t) => {
                    text_parts.push(t.clone());
                }
                MessageContent::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            ContentBlock::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            ContentBlock::ToolUse { id, name, input } => {
                                tool_parts.push(serde_json::json!({
                                    "type": "tool_use",
                                    "id": id,
                                    "name": name,
                                    "input": input,
                                }));
                            }
                            ContentBlock::ToolResult {
                                tool_use_id,
                                content,
                                is_error,
                            } => {
                                tool_parts.push(serde_json::json!({
                                    "type": "tool_result",
                                    "tool_use_id": tool_use_id,
                                    "content": content,
                                    "is_error": is_error,
                                }));
                            }
                            ContentBlock::Image { media_type, .. } => {
                                text_parts.push(format!("[image: {media_type}]"));
                            }
                            ContentBlock::Thinking { thinking } => {
                                // Truncate to at most 200 bytes at a valid UTF-8
                                // char boundary. The previous raw byte slice
                                // `&thinking[..len.min(200)]` panicked when byte
                                // 200 fell inside a multi-byte codepoint.
                                let mut end = thinking.len().min(200);
                                while !thinking.is_char_boundary(end) {
                                    end -= 1;
                                }
                                text_parts.push(format!("[thinking: {}]", &thinking[..end]));
                            }
                            ContentBlock::Unknown => {}
                        }
                    }
                }
            }
            let line = JsonlLine {
                timestamp: now.clone(),
                role: role_str.to_string(),
                content: serde_json::Value::String(text_parts.join("\n")),
                tool_use: if tool_parts.is_empty() {
                    None
                } else {
                    Some(serde_json::Value::Array(tool_parts))
                },
            };
            serde_json::to_writer(&mut file, &line).map_err(std::io::Error::other)?;
            file.write_all(b"\n")?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;
    /// Fresh `SessionStore` over an in-memory, fully-migrated database.
    fn setup() -> SessionStore {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        SessionStore::new(Arc::new(Mutex::new(conn)))
    }
    // Round-trip: a created session is loadable and starts with no messages.
    #[test]
    fn test_create_and_load_session() {
        let store = setup();
        let agent_id = AgentId::new();
        let session = store.create_session(agent_id).unwrap();
        let loaded = store.get_session(session.id).unwrap().unwrap();
        assert_eq!(loaded.agent_id, agent_id);
        assert!(loaded.messages.is_empty());
    }
    // Messages appended to a session survive save/load.
    #[test]
    fn test_save_and_load_with_messages() {
        let store = setup();
        let agent_id = AgentId::new();
        let mut session = store.create_session(agent_id).unwrap();
        session.messages.push(Message::user("Hello"));
        session.messages.push(Message::assistant("Hi there!"));
        store.save_session(&session).unwrap();
        let loaded = store.get_session(session.id).unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 2);
    }
    // Unknown session ids yield Ok(None), not an error.
    #[test]
    fn test_get_missing_session() {
        let store = setup();
        let result = store.get_session(SessionId::new()).unwrap();
        assert!(result.is_none());
    }
    #[test]
    fn test_delete_session() {
        let store = setup();
        let agent_id = AgentId::new();
        let session = store.create_session(agent_id).unwrap();
        let sid = session.id;
        assert!(store.get_session(sid).unwrap().is_some());
        store.delete_session(sid).unwrap();
        assert!(store.get_session(sid).unwrap().is_none());
    }
    // Deleting by agent removes every session belonging to that agent.
    #[test]
    fn test_delete_agent_sessions() {
        let store = setup();
        let agent_id = AgentId::new();
        let s1 = store.create_session(agent_id).unwrap();
        let s2 = store.create_session(agent_id).unwrap();
        assert!(store.get_session(s1.id).unwrap().is_some());
        assert!(store.get_session(s2.id).unwrap().is_some());
        store.delete_agent_sessions(agent_id).unwrap();
        assert!(store.get_session(s1.id).unwrap().is_none());
        assert!(store.get_session(s2.id).unwrap().is_none());
    }
    // Loading a canonical session that was never written creates an empty one.
    #[test]
    fn test_canonical_load_creates_empty() {
        let store = setup();
        let agent_id = AgentId::new();
        let canonical = store.load_canonical(agent_id).unwrap();
        assert_eq!(canonical.agent_id, agent_id);
        assert!(canonical.messages.is_empty());
        assert!(canonical.compacted_summary.is_none());
        assert_eq!(canonical.compaction_cursor, 0);
    }
    // Appends from different channels accumulate into one canonical session.
    #[test]
    fn test_canonical_append_and_load() {
        let store = setup();
        let agent_id = AgentId::new();
        // Append from "Telegram"
        let msgs1 = vec![
            Message::user("Hello from Telegram"),
            Message::assistant("Hi! I'm your agent."),
        ];
        store.append_canonical(agent_id, &msgs1, None).unwrap();
        // Append from "Discord"
        let msgs2 = vec![
            Message::user("Now I'm on Discord"),
            Message::assistant("I remember you from Telegram!"),
        ];
        let canonical = store.append_canonical(agent_id, &msgs2, None).unwrap();
        // Should have all 4 messages
        assert_eq!(canonical.messages.len(), 4);
    }
    // canonical_context honors the requested window size.
    #[test]
    fn test_canonical_context_window() {
        let store = setup();
        let agent_id = AgentId::new();
        // Add 10 messages
        let msgs: Vec<Message> = (0..10)
            .map(|i| Message::user(format!("Message {i}")))
            .collect();
        store.append_canonical(agent_id, &msgs, None).unwrap();
        // Request window of 3
        let (summary, recent) = store.canonical_context(agent_id, Some(3)).unwrap();
        assert_eq!(recent.len(), 3);
        assert!(summary.is_none()); // No compaction yet
    }
    // Exceeding the threshold trims messages and produces a summary.
    #[test]
    fn test_canonical_compaction() {
        let store = setup();
        let agent_id = AgentId::new();
        // Add 120 messages (over the default 100 threshold)
        let msgs: Vec<Message> = (0..120)
            .map(|i| Message::user(format!("Message number {i} with some content")))
            .collect();
        let canonical = store.append_canonical(agent_id, &msgs, Some(100)).unwrap();
        // After compaction: should keep DEFAULT_CANONICAL_WINDOW (50) messages
        assert!(canonical.messages.len() <= 60); // some tolerance
        assert!(canonical.compacted_summary.is_some());
    }
    // Context written on one channel is visible when queried from another.
    #[test]
    fn test_canonical_cross_channel_roundtrip() {
        let store = setup();
        let agent_id = AgentId::new();
        // Channel 1: user tells agent their name
        store
            .append_canonical(
                agent_id,
                &[
                    Message::user("My name is Jaber"),
                    Message::assistant("Nice to meet you, Jaber!"),
                ],
                None,
            )
            .unwrap();
        // Channel 2: different channel queries same agent
        let (summary, recent) = store.canonical_context(agent_id, None).unwrap();
        // The agent should have context about "Jaber" from the previous channel
        let all_text: String = recent.iter().map(|m| m.content.text_content()).collect();
        assert!(all_text.contains("Jaber"));
        assert!(summary.is_none()); // Only 2 messages, no compaction
    }
    // The JSONL mirror writes one well-formed JSON object per message.
    #[test]
    fn test_jsonl_mirror_write() {
        let store = setup();
        let agent_id = AgentId::new();
        let mut session = store.create_session(agent_id).unwrap();
        session
            .messages
            .push(openfang_types::message::Message::user("Hello"));
        session
            .messages
            .push(openfang_types::message::Message::assistant("Hi there!"));
        store.save_session(&session).unwrap();
        let dir = tempfile::TempDir::new().unwrap();
        let sessions_dir = dir.path().join("sessions");
        store.write_jsonl_mirror(&session, &sessions_dir).unwrap();
        let jsonl_path = sessions_dir.join(format!("{}.jsonl", session.id.0));
        assert!(jsonl_path.exists());
        let content = std::fs::read_to_string(&jsonl_path).unwrap();
        let lines: Vec<&str> = content.trim().split('\n').collect();
        assert_eq!(lines.len(), 2);
        // Verify first line is user message
        let line1: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(line1["role"], "user");
        assert_eq!(line1["content"], "Hello");
        // Verify second line is assistant message
        let line2: serde_json::Value = serde_json::from_str(lines[1]).unwrap();
        assert_eq!(line2["role"], "assistant");
        assert_eq!(line2["content"], "Hi there!");
        assert!(line2.get("tool_use").is_none());
    }
}

View File

@@ -0,0 +1,448 @@
//! SQLite structured store for key-value pairs and agent persistence.
use chrono::Utc;
use openfang_types::agent::{AgentEntry, AgentId};
use openfang_types::error::{OpenFangError, OpenFangResult};
use rusqlite::Connection;
use std::sync::{Arc, Mutex};
/// Structured store backed by SQLite for key-value operations and agent storage.
#[derive(Clone)]
pub struct StructuredStore {
    /// Shared connection; every operation locks it for the duration of a call.
    conn: Arc<Mutex<Connection>>,
}
impl StructuredStore {
/// Create a new structured store wrapping the given connection.
pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
Self { conn }
}
/// Get a value from the key-value store.
pub fn get(&self, agent_id: AgentId, key: &str) -> OpenFangResult<Option<serde_json::Value>> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
let mut stmt = conn
.prepare("SELECT value FROM kv_store WHERE agent_id = ?1 AND key = ?2")
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let result = stmt.query_row(rusqlite::params![agent_id.0.to_string(), key], |row| {
let blob: Vec<u8> = row.get(0)?;
Ok(blob)
});
match result {
Ok(blob) => {
let value: serde_json::Value = serde_json::from_slice(&blob)
.map_err(|e| OpenFangError::Serialization(e.to_string()))?;
Ok(Some(value))
}
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(OpenFangError::Memory(e.to_string())),
}
}
/// Set a value in the key-value store.
pub fn set(
&self,
agent_id: AgentId,
key: &str,
value: serde_json::Value,
) -> OpenFangResult<()> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
let blob =
serde_json::to_vec(&value).map_err(|e| OpenFangError::Serialization(e.to_string()))?;
let now = Utc::now().to_rfc3339();
conn.execute(
"INSERT INTO kv_store (agent_id, key, value, version, updated_at) VALUES (?1, ?2, ?3, 1, ?4)
ON CONFLICT(agent_id, key) DO UPDATE SET value = ?3, version = version + 1, updated_at = ?4",
rusqlite::params![agent_id.0.to_string(), key, blob, now],
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(())
}
/// Delete a value from the key-value store.
pub fn delete(&self, agent_id: AgentId, key: &str) -> OpenFangResult<()> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
conn.execute(
"DELETE FROM kv_store WHERE agent_id = ?1 AND key = ?2",
rusqlite::params![agent_id.0.to_string(), key],
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(())
}
/// List all key-value pairs for an agent.
pub fn list_kv(&self, agent_id: AgentId) -> OpenFangResult<Vec<(String, serde_json::Value)>> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
let mut stmt = conn
.prepare("SELECT key, value FROM kv_store WHERE agent_id = ?1 ORDER BY key")
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let rows = stmt
.query_map(rusqlite::params![agent_id.0.to_string()], |row| {
let key: String = row.get(0)?;
let val_str: String = row.get(1)?;
Ok((key, val_str))
})
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let mut pairs = Vec::new();
for row in rows {
let (key, val_str) = row.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let value: serde_json::Value =
serde_json::from_str(&val_str).unwrap_or(serde_json::Value::String(val_str));
pairs.push((key, value));
}
Ok(pairs)
}
    /// Save an agent entry to the database.
    ///
    /// Upserts on the agent id: `created_at` is only written on first insert,
    /// while name, manifest, state, `updated_at`, and `session_id` are
    /// refreshed on every save.
    pub fn save_agent(&self, entry: &AgentEntry) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Use named-field encoding so new fields with #[serde(default)] are
        // handled gracefully when the struct evolves between versions.
        let manifest_blob = rmp_serde::to_vec_named(&entry.manifest)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let state_str = serde_json::to_string(&entry.state)
            .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
        let now = Utc::now().to_rfc3339();
        // Add session_id column if it doesn't exist yet (migration compat).
        // The result is deliberately ignored: the ALTER fails harmlessly once
        // the column already exists.
        let _ = conn.execute(
            "ALTER TABLE agents ADD COLUMN session_id TEXT DEFAULT ''",
            [],
        );
        conn.execute(
            "INSERT INTO agents (id, name, manifest, state, created_at, updated_at, session_id)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
             ON CONFLICT(id) DO UPDATE SET name = ?2, manifest = ?3, state = ?4, updated_at = ?6, session_id = ?7",
            rusqlite::params![
                entry.id.0.to_string(),
                entry.name,
                manifest_blob,
                state_str,
                entry.created_at.to_rfc3339(),
                now,
                entry.session_id.0.to_string(),
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
    /// Load an agent entry from the database.
    ///
    /// Returns `Ok(None)` when no agent has this id. Fields this store does
    /// not persist (`mode`, `parent`, `children`, `tags`, `identity`, the
    /// onboarding flags) come back as defaults, and `last_active` is reset to
    /// the load time.
    pub fn load_agent(&self, agent_id: AgentId) -> OpenFangResult<Option<AgentEntry>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare("SELECT id, name, manifest, state, created_at, updated_at, session_id FROM agents WHERE id = ?1")
            .or_else(|_| {
                // Fallback without session_id column for old DBs
                conn.prepare("SELECT id, name, manifest, state, created_at, updated_at FROM agents WHERE id = ?1")
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        // Distinguishes which of the two SELECTs above was actually prepared.
        let col_count = stmt.column_count();
        let result = stmt.query_row(rusqlite::params![agent_id.0.to_string()], |row| {
            let manifest_blob: Vec<u8> = row.get(2)?;
            let state_str: String = row.get(3)?;
            let created_str: String = row.get(4)?;
            let name: String = row.get(1)?;
            let session_id_str: Option<String> = if col_count >= 7 {
                row.get(6).ok()
            } else {
                None
            };
            Ok((name, manifest_blob, state_str, created_str, session_id_str))
        });
        match result {
            Ok((name, manifest_blob, state_str, created_str, session_id_str)) => {
                let manifest = rmp_serde::from_slice(&manifest_blob)
                    .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
                let state = serde_json::from_str(&state_str)
                    .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
                // Unparseable timestamps degrade to "now" rather than failing the load.
                let created_at = chrono::DateTime::parse_from_rfc3339(&created_str)
                    .map(|dt| dt.with_timezone(&Utc))
                    .unwrap_or_else(|_| Utc::now());
                // Missing or invalid stored session id gets a fresh one.
                let session_id = session_id_str
                    .and_then(|s| uuid::Uuid::parse_str(&s).ok())
                    .map(openfang_types::agent::SessionId)
                    .unwrap_or_else(openfang_types::agent::SessionId::new);
                Ok(Some(AgentEntry {
                    id: agent_id,
                    name,
                    manifest,
                    state,
                    mode: Default::default(),
                    created_at,
                    last_active: Utc::now(),
                    parent: None,
                    children: vec![],
                    session_id,
                    tags: vec![],
                    identity: Default::default(),
                    onboarding_completed: false,
                    onboarding_completed_at: None,
                }))
            }
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(OpenFangError::Memory(e.to_string())),
        }
    }
/// Remove an agent from the database.
pub fn remove_agent(&self, agent_id: AgentId) -> OpenFangResult<()> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
conn.execute(
"DELETE FROM agents WHERE id = ?1",
rusqlite::params![agent_id.0.to_string()],
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(())
}
    /// Load all agent entries from the database.
    ///
    /// Rows that fail to parse (unreadable row, bad UUID, incompatible
    /// manifest, bad state) are skipped with a warning instead of failing the
    /// whole load — note this is skip-on-error, not lenient deserialization.
    /// Manifests whose stored blob differs from the current schema's
    /// round-trip encoding are re-saved ("auto-repair") so future boots read
    /// an up-to-date blob. Duplicate agent names are deduplicated
    /// case-insensitively (first occurrence wins).
    pub fn load_all_agents(&self) -> OpenFangResult<Vec<AgentEntry>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Try with session_id column first, fall back without
        let mut stmt = conn
            .prepare(
                "SELECT id, name, manifest, state, created_at, updated_at, session_id FROM agents",
            )
            .or_else(|_| {
                conn.prepare("SELECT id, name, manifest, state, created_at, updated_at FROM agents")
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        // Distinguishes which of the two SELECTs above was actually prepared.
        let col_count = stmt.column_count();
        let rows = stmt
            .query_map([], |row| {
                let id_str: String = row.get(0)?;
                let name: String = row.get(1)?;
                let manifest_blob: Vec<u8> = row.get(2)?;
                let state_str: String = row.get(3)?;
                let created_str: String = row.get(4)?;
                let session_id_str: Option<String> = if col_count >= 7 {
                    row.get(6).ok()
                } else {
                    None
                };
                Ok((
                    id_str,
                    name,
                    manifest_blob,
                    state_str,
                    created_str,
                    session_id_str,
                ))
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut agents = Vec::new();
        let mut seen_names = std::collections::HashSet::new();
        // (id, upgraded manifest blob, name) triples, re-saved after the scan.
        let mut repair_queue: Vec<(String, Vec<u8>, String)> = Vec::new();
        for row in rows {
            let (id_str, name, manifest_blob, state_str, created_str, session_id_str) = match row {
                Ok(r) => r,
                Err(e) => {
                    tracing::warn!("Skipping agent row with read error: {e}");
                    continue;
                }
            };
            // Deduplicate: skip agents with names we've already seen
            let name_lower = name.to_lowercase();
            if !seen_names.insert(name_lower) {
                tracing::info!(agent = %name, id = %id_str, "Skipping duplicate agent name");
                continue;
            }
            let agent_id = match uuid::Uuid::parse_str(&id_str).map(openfang_types::agent::AgentId)
            {
                Ok(id) => id,
                Err(e) => {
                    tracing::warn!(agent = %name, "Skipping agent with bad UUID '{id_str}': {e}");
                    continue;
                }
            };
            let manifest: openfang_types::agent::AgentManifest = match rmp_serde::from_slice(
                &manifest_blob,
            ) {
                Ok(m) => m,
                Err(e) => {
                    tracing::warn!(
                        agent = %name, id = %id_str,
                        "Skipping agent with incompatible manifest (schema may have changed): {e}"
                    );
                    continue;
                }
            };
            // Auto-repair: re-serialize with current schema and queue for update.
            // This upgrades the stored blob so future boots don't hit lenient paths.
            let new_blob = rmp_serde::to_vec_named(&manifest)
                .map_err(|e| OpenFangError::Serialization(e.to_string()))?;
            if new_blob != manifest_blob {
                tracing::info!(
                    agent = %name, id = %id_str,
                    "Auto-repaired agent manifest (schema upgraded)"
                );
                repair_queue.push((id_str.clone(), new_blob, name.clone()));
            }
            let state = match serde_json::from_str(&state_str) {
                Ok(s) => s,
                Err(e) => {
                    tracing::warn!(agent = %name, "Skipping agent with bad state: {e}");
                    continue;
                }
            };
            // Unparseable timestamps degrade to "now" rather than skipping the row.
            let created_at = chrono::DateTime::parse_from_rfc3339(&created_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());
            // Missing or invalid stored session id gets a fresh one.
            let session_id = session_id_str
                .and_then(|s| uuid::Uuid::parse_str(&s).ok())
                .map(openfang_types::agent::SessionId)
                .unwrap_or_else(openfang_types::agent::SessionId::new);
            agents.push(AgentEntry {
                id: agent_id,
                name,
                manifest,
                state,
                mode: Default::default(),
                created_at,
                last_active: Utc::now(),
                parent: None,
                children: vec![],
                session_id,
                tags: vec![],
                identity: Default::default(),
                onboarding_completed: false,
                onboarding_completed_at: None,
            });
        }
        // Apply queued repairs (re-save upgraded blobs)
        for (id_str, new_blob, name) in repair_queue {
            if let Err(e) = conn.execute(
                "UPDATE agents SET manifest = ?1 WHERE id = ?2",
                rusqlite::params![new_blob, id_str],
            ) {
                tracing::warn!(agent = %name, "Failed to auto-repair agent blob: {e}");
            }
        }
        Ok(agents)
    }
/// List all agents in the database.
pub fn list_agents(&self) -> OpenFangResult<Vec<(String, String, String)>> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Internal(e.to_string()))?;
let mut stmt = conn
.prepare("SELECT id, name, state FROM agents")
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let rows = stmt
.query_map([], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
))
})
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let mut agents = Vec::new();
for row in rows {
agents.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
}
Ok(agents)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;
    /// Fresh `StructuredStore` over an in-memory, fully-migrated database.
    fn setup() -> StructuredStore {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        StructuredStore::new(Arc::new(Mutex::new(conn)))
    }
    // Round-trip: a value written with `set` is readable with `get`.
    #[test]
    fn test_kv_set_get() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .set(agent_id, "test_key", serde_json::json!("test_value"))
            .unwrap();
        let value = store.get(agent_id, "test_key").unwrap();
        assert_eq!(value, Some(serde_json::json!("test_value")));
    }
    // Missing keys yield Ok(None), not an error.
    #[test]
    fn test_kv_get_missing() {
        let store = setup();
        let agent_id = AgentId::new();
        let value = store.get(agent_id, "nonexistent").unwrap();
        assert!(value.is_none());
    }
    #[test]
    fn test_kv_delete() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .set(agent_id, "to_delete", serde_json::json!(42))
            .unwrap();
        store.delete(agent_id, "to_delete").unwrap();
        let value = store.get(agent_id, "to_delete").unwrap();
        assert!(value.is_none());
    }
    // Second `set` on the same key overwrites the first (upsert path).
    #[test]
    fn test_kv_update() {
        let store = setup();
        let agent_id = AgentId::new();
        store.set(agent_id, "key", serde_json::json!("v1")).unwrap();
        store.set(agent_id, "key", serde_json::json!("v2")).unwrap();
        let value = store.get(agent_id, "key").unwrap();
        assert_eq!(value, Some(serde_json::json!("v2")));
    }
}

View File

@@ -0,0 +1,756 @@
//! MemorySubstrate: unified implementation of the `Memory` trait.
//!
//! Composes the structured store, semantic store, knowledge store,
//! session store, and consolidation engine behind a single async API.
use crate::consolidation::ConsolidationEngine;
use crate::knowledge::KnowledgeStore;
use crate::migration::run_migrations;
use crate::semantic::SemanticStore;
use crate::session::{Session, SessionStore};
use crate::structured::StructuredStore;
use crate::usage::UsageStore;
use async_trait::async_trait;
use openfang_types::agent::{AgentEntry, AgentId, SessionId};
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::memory::{
ConsolidationReport, Entity, ExportFormat, GraphMatch, GraphPattern, ImportReport, Memory,
MemoryFilter, MemoryFragment, MemoryId, MemorySource, Relation,
};
use rusqlite::Connection;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
/// The unified memory substrate. Implements the `Memory` trait by delegating
/// to specialized stores backed by a shared SQLite connection.
pub struct MemorySubstrate {
    /// Shared SQLite connection; cloned into every sub-store below.
    conn: Arc<Mutex<Connection>>,
    /// Key-value pairs and agent persistence.
    structured: StructuredStore,
    /// Memory fragments with optional embedding vectors.
    semantic: SemanticStore,
    /// Entity/relation knowledge graph.
    knowledge: KnowledgeStore,
    /// Conversation sessions, including the per-agent canonical session.
    sessions: SessionStore,
    /// Confidence decay and duplicate-memory merging.
    consolidation: ConsolidationEngine,
    /// Usage records (see `UsageStore`).
    usage: UsageStore,
}
impl MemorySubstrate {
/// Open or create a memory substrate at the given database path.
pub fn open(db_path: &Path, decay_rate: f32) -> OpenFangResult<Self> {
let conn = Connection::open(db_path).map_err(|e| OpenFangError::Memory(e.to_string()))?;
conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
run_migrations(&conn).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let shared = Arc::new(Mutex::new(conn));
Ok(Self {
conn: Arc::clone(&shared),
structured: StructuredStore::new(Arc::clone(&shared)),
semantic: SemanticStore::new(Arc::clone(&shared)),
knowledge: KnowledgeStore::new(Arc::clone(&shared)),
sessions: SessionStore::new(Arc::clone(&shared)),
usage: UsageStore::new(Arc::clone(&shared)),
consolidation: ConsolidationEngine::new(shared, decay_rate),
})
}
/// Create an in-memory substrate (for testing).
pub fn open_in_memory(decay_rate: f32) -> OpenFangResult<Self> {
let conn =
Connection::open_in_memory().map_err(|e| OpenFangError::Memory(e.to_string()))?;
run_migrations(&conn).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let shared = Arc::new(Mutex::new(conn));
Ok(Self {
conn: Arc::clone(&shared),
structured: StructuredStore::new(Arc::clone(&shared)),
semantic: SemanticStore::new(Arc::clone(&shared)),
knowledge: KnowledgeStore::new(Arc::clone(&shared)),
sessions: SessionStore::new(Arc::clone(&shared)),
usage: UsageStore::new(Arc::clone(&shared)),
consolidation: ConsolidationEngine::new(shared, decay_rate),
})
}
    /// Get a reference to the usage store.
    pub fn usage(&self) -> &UsageStore {
        // Borrow, not clone: callers share the substrate's store directly.
        &self.usage
    }
    /// Get the shared database connection (for constructing stores from outside).
    ///
    /// Returns a clone of the `Arc`, not a new connection. NOTE(review): the
    /// `usage_` prefix is historical — this is the substrate-wide connection,
    /// not usage-specific.
    pub fn usage_conn(&self) -> Arc<Mutex<Connection>> {
        Arc::clone(&self.conn)
    }
    /// Save an agent entry to persistent storage.
    ///
    /// Thin delegation to [`StructuredStore::save_agent`] (upsert on agent id).
    pub fn save_agent(&self, entry: &AgentEntry) -> OpenFangResult<()> {
        self.structured.save_agent(entry)
    }
    /// Load an agent entry from persistent storage.
    ///
    /// Thin delegation to [`StructuredStore::load_agent`]; `Ok(None)` when
    /// the agent is unknown.
    pub fn load_agent(&self, agent_id: AgentId) -> OpenFangResult<Option<AgentEntry>> {
        self.structured.load_agent(agent_id)
    }
    /// Remove an agent from persistent storage and cascade-delete sessions.
    ///
    /// Session deletion is best-effort: its error is deliberately swallowed
    /// so the agent row is removed regardless.
    pub fn remove_agent(&self, agent_id: AgentId) -> OpenFangResult<()> {
        // Delete associated sessions first
        let _ = self.sessions.delete_agent_sessions(agent_id);
        self.structured.remove_agent(agent_id)
    }
    /// Load all agent entries from persistent storage.
    ///
    /// Thin delegation to [`StructuredStore::load_all_agents`], which skips
    /// unparseable rows and deduplicates agent names.
    pub fn load_all_agents(&self) -> OpenFangResult<Vec<AgentEntry>> {
        self.structured.load_all_agents()
    }
    /// List all saved agents.
    ///
    /// Delegates to [`StructuredStore::list_agents`], returning
    /// `(id, name, state)` string tuples.
    pub fn list_agents(&self) -> OpenFangResult<Vec<(String, String, String)>> {
        self.structured.list_agents()
    }
    /// Synchronous get from the structured store (for kernel handle use).
    ///
    /// Delegates to [`StructuredStore::get`]; `Ok(None)` for missing keys.
    pub fn structured_get(
        &self,
        agent_id: AgentId,
        key: &str,
    ) -> OpenFangResult<Option<serde_json::Value>> {
        self.structured.get(agent_id, key)
    }
    /// List all KV pairs for an agent.
    ///
    /// Thin delegation to [`StructuredStore::list_kv`].
    pub fn list_kv(&self, agent_id: AgentId) -> OpenFangResult<Vec<(String, serde_json::Value)>> {
        self.structured.list_kv(agent_id)
    }
    /// Delete a KV entry for an agent.
    ///
    /// Thin delegation to [`StructuredStore::delete`].
    pub fn structured_delete(&self, agent_id: AgentId, key: &str) -> OpenFangResult<()> {
        self.structured.delete(agent_id, key)
    }
    /// Synchronous set in the structured store (for kernel handle use).
    ///
    /// Delegates to [`StructuredStore::set`] (upsert with version bump).
    pub fn structured_set(
        &self,
        agent_id: AgentId,
        key: &str,
        value: serde_json::Value,
    ) -> OpenFangResult<()> {
        self.structured.set(agent_id, key, value)
    }
    /// Get a session by ID.
    ///
    /// Delegates to [`SessionStore::get_session`]; `Ok(None)` for unknown ids.
    pub fn get_session(&self, session_id: SessionId) -> OpenFangResult<Option<Session>> {
        self.sessions.get_session(session_id)
    }
    /// Save a session.
    ///
    /// Thin delegation to [`SessionStore::save_session`].
    pub fn save_session(&self, session: &Session) -> OpenFangResult<()> {
        self.sessions.save_session(session)
    }
    /// Create a new empty session for an agent.
    ///
    /// Thin delegation to [`SessionStore::create_session`].
    pub fn create_session(&self, agent_id: AgentId) -> OpenFangResult<Session> {
        self.sessions.create_session(agent_id)
    }
    /// List all sessions with metadata.
    ///
    /// Thin delegation to [`SessionStore::list_sessions`]; each entry is a
    /// JSON metadata object.
    pub fn list_sessions(&self) -> OpenFangResult<Vec<serde_json::Value>> {
        self.sessions.list_sessions()
    }
    /// Delete a session by ID.
    ///
    /// Thin delegation to [`SessionStore::delete_session`].
    pub fn delete_session(&self, session_id: SessionId) -> OpenFangResult<()> {
        self.sessions.delete_session(session_id)
    }
    /// Set or clear a session label.
    ///
    /// Passing `None` clears the label. Delegates to
    /// [`SessionStore::set_session_label`].
    pub fn set_session_label(
        &self,
        session_id: SessionId,
        label: Option<&str>,
    ) -> OpenFangResult<()> {
        self.sessions.set_session_label(session_id, label)
    }
    /// Find a session by label for a given agent.
    ///
    /// Delegates to [`SessionStore::find_session_by_label`]; `Ok(None)` when
    /// no matching session exists.
    pub fn find_session_by_label(
        &self,
        agent_id: AgentId,
        label: &str,
    ) -> OpenFangResult<Option<Session>> {
        self.sessions.find_session_by_label(agent_id, label)
    }
    /// List all sessions for a specific agent.
    ///
    /// Thin delegation to [`SessionStore::list_agent_sessions`].
    pub fn list_agent_sessions(&self, agent_id: AgentId) -> OpenFangResult<Vec<serde_json::Value>> {
        self.sessions.list_agent_sessions(agent_id)
    }
    /// Create a new session with an optional label.
    ///
    /// Thin delegation to [`SessionStore::create_session_with_label`].
    pub fn create_session_with_label(
        &self,
        agent_id: AgentId,
        label: Option<&str>,
    ) -> OpenFangResult<Session> {
        self.sessions.create_session_with_label(agent_id, label)
    }
    /// Load canonical session context for cross-channel memory.
    ///
    /// Returns the compacted summary (if any) and recent messages from the
    /// agent's persistent canonical session. Delegates to
    /// [`SessionStore::canonical_context`].
    pub fn canonical_context(
        &self,
        agent_id: AgentId,
        window_size: Option<usize>,
    ) -> OpenFangResult<(Option<String>, Vec<openfang_types::message::Message>)> {
        self.sessions.canonical_context(agent_id, window_size)
    }
    /// Store an LLM-generated summary, replacing older messages with the kept subset.
    ///
    /// Used by the compactor to replace text-truncation compaction with an
    /// LLM-generated summary of older conversation history. Delegates to
    /// [`SessionStore::store_llm_summary`].
    pub fn store_llm_summary(
        &self,
        agent_id: AgentId,
        summary: &str,
        kept_messages: Vec<openfang_types::message::Message>,
    ) -> OpenFangResult<()> {
        self.sessions
            .store_llm_summary(agent_id, summary, kept_messages)
    }
    /// Write a human-readable JSONL mirror of a session to disk.
    ///
    /// Best-effort — errors are returned but should be logged,
    /// never affecting the primary SQLite store. Delegates to
    /// [`SessionStore::write_jsonl_mirror`].
    pub fn write_jsonl_mirror(
        &self,
        session: &Session,
        sessions_dir: &Path,
    ) -> Result<(), std::io::Error> {
        self.sessions.write_jsonl_mirror(session, sessions_dir)
    }
    /// Append messages to the agent's canonical session for cross-channel persistence.
    ///
    /// Delegates to [`SessionStore::append_canonical`]; the updated
    /// `CanonicalSession` the session store returns is intentionally
    /// discarded here.
    pub fn append_canonical(
        &self,
        agent_id: AgentId,
        messages: &[openfang_types::message::Message],
        compaction_threshold: Option<usize>,
    ) -> OpenFangResult<()> {
        self.sessions
            .append_canonical(agent_id, messages, compaction_threshold)?;
        Ok(())
    }
// -----------------------------------------------------------------
// Paired devices persistence
// -----------------------------------------------------------------
/// Load all paired devices from the database.
pub fn load_paired_devices(&self) -> OpenFangResult<Vec<serde_json::Value>> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let mut stmt = conn.prepare(
"SELECT device_id, display_name, platform, paired_at, last_seen, push_token FROM paired_devices"
).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let rows = stmt
.query_map([], |row| {
Ok(serde_json::json!({
"device_id": row.get::<_, String>(0)?,
"display_name": row.get::<_, String>(1)?,
"platform": row.get::<_, String>(2)?,
"paired_at": row.get::<_, String>(3)?,
"last_seen": row.get::<_, String>(4)?,
"push_token": row.get::<_, Option<String>>(5)?,
}))
})
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
let mut devices = Vec::new();
for row in rows {
devices.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
}
Ok(devices)
}
    /// Save a paired device to the database (insert or replace).
    ///
    /// `INSERT OR REPLACE` is keyed on `device_id`, so re-pairing the same
    /// device overwrites its previous row. `push_token` may be `None`.
    pub fn save_paired_device(
        &self,
        device_id: &str,
        display_name: &str,
        platform: &str,
        paired_at: &str,
        last_seen: &str,
        push_token: Option<&str>,
    ) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        conn.execute(
            "INSERT OR REPLACE INTO paired_devices (device_id, display_name, platform, paired_at, last_seen, push_token) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            rusqlite::params![device_id, display_name, platform, paired_at, last_seen, push_token],
        ).map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }
/// Remove a paired device from the database.
pub fn remove_paired_device(&self, device_id: &str) -> OpenFangResult<()> {
let conn = self
.conn
.lock()
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
conn.execute(
"DELETE FROM paired_devices WHERE device_id = ?1",
rusqlite::params![device_id],
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(())
}
// -----------------------------------------------------------------
// Embedding-aware memory operations
// -----------------------------------------------------------------
    /// Store a memory with an embedding vector.
    ///
    /// Thin synchronous delegation to the semantic store. When `embedding`
    /// is `None` the fragment is stored without a vector (recall then relies
    /// on text matching — see the semantic store for the exact behavior).
    /// Returns the ID of the newly stored memory.
    pub fn remember_with_embedding(
        &self,
        agent_id: AgentId,
        content: &str,
        source: MemorySource,
        scope: &str,
        metadata: HashMap<String, serde_json::Value>,
        embedding: Option<&[f32]>,
    ) -> OpenFangResult<MemoryId> {
        self.semantic
            .remember_with_embedding(agent_id, content, source, scope, metadata, embedding)
    }
    /// Recall memories using vector similarity when a query embedding is provided.
    ///
    /// Thin synchronous delegation to the semantic store. With
    /// `query_embedding = None` the store chooses the matching strategy
    /// (presumably plain text search on `query` — confirm in the semantic
    /// store). At most `limit` fragments are returned, further narrowed by
    /// `filter` when given.
    pub fn recall_with_embedding(
        &self,
        query: &str,
        limit: usize,
        filter: Option<MemoryFilter>,
        query_embedding: Option<&[f32]>,
    ) -> OpenFangResult<Vec<MemoryFragment>> {
        self.semantic
            .recall_with_embedding(query, limit, filter, query_embedding)
    }
    /// Update the embedding for an existing memory.
    ///
    /// Thin delegation to the semantic store; behavior for an unknown `id`
    /// is defined by the store (not visible here).
    pub fn update_embedding(&self, id: MemoryId, embedding: &[f32]) -> OpenFangResult<()> {
        self.semantic.update_embedding(id, embedding)
    }
/// Async wrapper for `recall_with_embedding` — runs in a blocking thread.
pub async fn recall_with_embedding_async(
&self,
query: &str,
limit: usize,
filter: Option<MemoryFilter>,
query_embedding: Option<&[f32]>,
) -> OpenFangResult<Vec<MemoryFragment>> {
let store = self.semantic.clone();
let query = query.to_string();
let embedding_owned = query_embedding.map(|e| e.to_vec());
tokio::task::spawn_blocking(move || {
store.recall_with_embedding(&query, limit, filter, embedding_owned.as_deref())
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
/// Async wrapper for `remember_with_embedding` — runs in a blocking thread.
pub async fn remember_with_embedding_async(
&self,
agent_id: AgentId,
content: &str,
source: MemorySource,
scope: &str,
metadata: HashMap<String, serde_json::Value>,
embedding: Option<&[f32]>,
) -> OpenFangResult<MemoryId> {
let store = self.semantic.clone();
let content = content.to_string();
let scope = scope.to_string();
let embedding_owned = embedding.map(|e| e.to_vec());
tokio::task::spawn_blocking(move || {
store.remember_with_embedding(
agent_id,
&content,
source,
&scope,
metadata,
embedding_owned.as_deref(),
)
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
// -----------------------------------------------------------------
// Task queue operations
// -----------------------------------------------------------------
/// Post a new task to the shared queue. Returns the task ID.
pub async fn task_post(
&self,
title: &str,
description: &str,
assigned_to: Option<&str>,
created_by: Option<&str>,
) -> OpenFangResult<String> {
let conn = Arc::clone(&self.conn);
let title = title.to_string();
let description = description.to_string();
let assigned_to = assigned_to.unwrap_or("").to_string();
let created_by = created_by.unwrap_or("").to_string();
tokio::task::spawn_blocking(move || {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
let db = conn.lock().map_err(|e| OpenFangError::Internal(e.to_string()))?;
db.execute(
"INSERT INTO task_queue (id, agent_id, task_type, payload, status, priority, created_at, title, description, assigned_to, created_by)
VALUES (?1, ?2, ?3, ?4, 'pending', 0, ?5, ?6, ?7, ?8, ?9)",
rusqlite::params![id, &created_by, &title, b"", now, title, description, assigned_to, created_by],
)
.map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(id)
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
/// Claim the next pending task (optionally for a specific assignee). Returns task JSON or None.
pub async fn task_claim(&self, agent_id: &str) -> OpenFangResult<Option<serde_json::Value>> {
let conn = Arc::clone(&self.conn);
let agent_id = agent_id.to_string();
tokio::task::spawn_blocking(move || {
let db = conn.lock().map_err(|e| OpenFangError::Internal(e.to_string()))?;
// Find first pending task assigned to this agent, or any unassigned pending task
let mut stmt = db.prepare(
"SELECT id, title, description, assigned_to, created_by, created_at
FROM task_queue
WHERE status = 'pending' AND (assigned_to = ?1 OR assigned_to = '')
ORDER BY priority DESC, created_at ASC
LIMIT 1"
).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let result = stmt.query_row(rusqlite::params![agent_id], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
row.get::<_, String>(5)?,
))
});
match result {
Ok((id, title, description, assigned, created_by, created_at)) => {
// Update status to in_progress
db.execute(
"UPDATE task_queue SET status = 'in_progress', assigned_to = ?2 WHERE id = ?1",
rusqlite::params![id, agent_id],
).map_err(|e| OpenFangError::Memory(e.to_string()))?;
Ok(Some(serde_json::json!({
"id": id,
"title": title,
"description": description,
"status": "in_progress",
"assigned_to": if assigned.is_empty() { &agent_id } else { &assigned },
"created_by": created_by,
"created_at": created_at,
})))
}
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(OpenFangError::Memory(e.to_string())),
}
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
/// Mark a task as completed with a result string.
pub async fn task_complete(&self, task_id: &str, result: &str) -> OpenFangResult<()> {
let conn = Arc::clone(&self.conn);
let task_id = task_id.to_string();
let result = result.to_string();
tokio::task::spawn_blocking(move || {
let now = chrono::Utc::now().to_rfc3339();
let db = conn.lock().map_err(|e| OpenFangError::Internal(e.to_string()))?;
let rows = db.execute(
"UPDATE task_queue SET status = 'completed', result = ?2, completed_at = ?3 WHERE id = ?1",
rusqlite::params![task_id, result, now],
).map_err(|e| OpenFangError::Memory(e.to_string()))?;
if rows == 0 {
return Err(OpenFangError::Internal(format!("Task not found: {task_id}")));
}
Ok(())
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
/// List tasks, optionally filtered by status.
pub async fn task_list(&self, status: Option<&str>) -> OpenFangResult<Vec<serde_json::Value>> {
let conn = Arc::clone(&self.conn);
let status = status.map(|s| s.to_string());
tokio::task::spawn_blocking(move || {
let db = conn.lock().map_err(|e| OpenFangError::Internal(e.to_string()))?;
let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = match &status {
Some(s) => (
"SELECT id, title, description, status, assigned_to, created_by, created_at, completed_at, result FROM task_queue WHERE status = ?1 ORDER BY created_at DESC",
vec![Box::new(s.clone())],
),
None => (
"SELECT id, title, description, status, assigned_to, created_by, created_at, completed_at, result FROM task_queue ORDER BY created_at DESC",
vec![],
),
};
let mut stmt = db.prepare(sql).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let params_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect();
let rows = stmt.query_map(params_refs.as_slice(), |row| {
Ok(serde_json::json!({
"id": row.get::<_, String>(0)?,
"title": row.get::<_, String>(1).unwrap_or_default(),
"description": row.get::<_, String>(2).unwrap_or_default(),
"status": row.get::<_, String>(3)?,
"assigned_to": row.get::<_, String>(4).unwrap_or_default(),
"created_by": row.get::<_, String>(5).unwrap_or_default(),
"created_at": row.get::<_, String>(6).unwrap_or_default(),
"completed_at": row.get::<_, Option<String>>(7).unwrap_or(None),
"result": row.get::<_, Option<String>>(8).unwrap_or(None),
}))
}).map_err(|e| OpenFangError::Memory(e.to_string()))?;
let mut tasks = Vec::new();
for row in rows {
tasks.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
}
Ok(tasks)
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
}
#[async_trait]
impl Memory for MemorySubstrate {
async fn get(&self, agent_id: AgentId, key: &str) -> OpenFangResult<Option<serde_json::Value>> {
let store = self.structured.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || store.get(agent_id, &key))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn set(
&self,
agent_id: AgentId,
key: &str,
value: serde_json::Value,
) -> OpenFangResult<()> {
let store = self.structured.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || store.set(agent_id, &key, value))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn delete(&self, agent_id: AgentId, key: &str) -> OpenFangResult<()> {
let store = self.structured.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || store.delete(agent_id, &key))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn remember(
&self,
agent_id: AgentId,
content: &str,
source: MemorySource,
scope: &str,
metadata: HashMap<String, serde_json::Value>,
) -> OpenFangResult<MemoryId> {
let store = self.semantic.clone();
let content = content.to_string();
let scope = scope.to_string();
tokio::task::spawn_blocking(move || {
store.remember(agent_id, &content, source, &scope, metadata)
})
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn recall(
&self,
query: &str,
limit: usize,
filter: Option<MemoryFilter>,
) -> OpenFangResult<Vec<MemoryFragment>> {
let store = self.semantic.clone();
let query = query.to_string();
tokio::task::spawn_blocking(move || store.recall(&query, limit, filter))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn forget(&self, id: MemoryId) -> OpenFangResult<()> {
let store = self.semantic.clone();
tokio::task::spawn_blocking(move || store.forget(id))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn add_entity(&self, entity: Entity) -> OpenFangResult<String> {
let store = self.knowledge.clone();
tokio::task::spawn_blocking(move || store.add_entity(entity))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn add_relation(&self, relation: Relation) -> OpenFangResult<String> {
let store = self.knowledge.clone();
tokio::task::spawn_blocking(move || store.add_relation(relation))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn query_graph(&self, pattern: GraphPattern) -> OpenFangResult<Vec<GraphMatch>> {
let store = self.knowledge.clone();
tokio::task::spawn_blocking(move || store.query_graph(pattern))
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn consolidate(&self) -> OpenFangResult<ConsolidationReport> {
let engine = self.consolidation.clone();
tokio::task::spawn_blocking(move || engine.consolidate())
.await
.map_err(|e| OpenFangError::Internal(e.to_string()))?
}
async fn export(&self, format: ExportFormat) -> OpenFangResult<Vec<u8>> {
let _ = format;
Ok(Vec::new())
}
async fn import(&self, _data: &[u8], _format: ExportFormat) -> OpenFangResult<ImportReport> {
Ok(ImportReport {
entities_imported: 0,
relations_imported: 0,
memories_imported: 0,
errors: vec!["Import not yet implemented in Phase 1".to_string()],
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // The structured KV store round-trips a JSON value.
    #[tokio::test]
    async fn test_substrate_kv() {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let agent_id = AgentId::new();
        substrate
            .set(agent_id, "key", serde_json::json!("value"))
            .await
            .unwrap();
        let val = substrate.get(agent_id, "key").await.unwrap();
        assert_eq!(val, Some(serde_json::json!("value")));
    }
    // A remembered fragment is found again by keyword recall.
    #[tokio::test]
    async fn test_substrate_remember_recall() {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let agent_id = AgentId::new();
        substrate
            .remember(
                agent_id,
                "Rust is a great language",
                MemorySource::Conversation,
                "episodic",
                HashMap::new(),
            )
            .await
            .unwrap();
        let results = substrate.recall("Rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
    }
    // Posting a task makes it visible in the pending list with its metadata.
    #[tokio::test]
    async fn test_task_post_and_list() {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let id = substrate
            .task_post(
                "Review code",
                "Check the auth module for issues",
                Some("auditor"),
                Some("orchestrator"),
            )
            .await
            .unwrap();
        assert!(!id.is_empty());
        let tasks = substrate.task_list(Some("pending")).await.unwrap();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0]["title"], "Review code");
        assert_eq!(tasks[0]["assigned_to"], "auditor");
        assert_eq!(tasks[0]["status"], "pending");
    }
    // Full lifecycle: post -> claim (in_progress) -> complete (with result).
    #[tokio::test]
    async fn test_task_claim_and_complete() {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let task_id = substrate
            .task_post(
                "Audit endpoint",
                "Security audit the /api/login endpoint",
                Some("auditor"),
                None,
            )
            .await
            .unwrap();
        // Claim the task
        let claimed = substrate.task_claim("auditor").await.unwrap();
        assert!(claimed.is_some());
        let claimed = claimed.unwrap();
        assert_eq!(claimed["id"], task_id);
        assert_eq!(claimed["status"], "in_progress");
        // Complete the task
        substrate
            .task_complete(&task_id, "No vulnerabilities found")
            .await
            .unwrap();
        // Verify it shows as completed
        let tasks = substrate.task_list(Some("completed")).await.unwrap();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0]["result"], "No vulnerabilities found");
    }
    // Claiming from an empty queue yields None rather than an error.
    #[tokio::test]
    async fn test_task_claim_empty() {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let claimed = substrate.task_claim("nobody").await.unwrap();
        assert!(claimed.is_none());
    }
}

View File

@@ -0,0 +1,541 @@
//! Usage tracking store — records LLM usage events for cost monitoring.
use chrono::Utc;
use openfang_types::agent::AgentId;
use openfang_types::error::{OpenFangError, OpenFangResult};
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
/// A single usage event recording an LLM call.
///
/// The event's timestamp is not part of this struct — it is assigned by
/// `UsageStore::record` at insertion time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Which agent made the call.
    pub agent_id: AgentId,
    /// Model used (provider-specific model name).
    pub model: String,
    /// Input (prompt) tokens consumed.
    pub input_tokens: u64,
    /// Output (completion) tokens consumed.
    pub output_tokens: u64,
    /// Estimated cost in USD.
    pub cost_usd: f64,
    /// Number of tool calls in this interaction.
    pub tool_calls: u32,
}
/// Summary of usage over a period.
///
/// Aggregated over whatever scope the producing query applied
/// (one agent or all agents — see `UsageStore::query_summary`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSummary {
    /// Total input tokens.
    pub total_input_tokens: u64,
    /// Total output tokens.
    pub total_output_tokens: u64,
    /// Total estimated cost in USD.
    pub total_cost_usd: f64,
    /// Total number of calls (one per recorded usage event).
    pub call_count: u64,
    /// Total tool calls across all events.
    pub total_tool_calls: u64,
}
/// Usage grouped by model.
///
/// Produced by `UsageStore::query_by_model`, sorted by descending cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUsage {
    /// Model name.
    pub model: String,
    /// Total cost for this model in USD.
    pub total_cost_usd: f64,
    /// Total input tokens.
    pub total_input_tokens: u64,
    /// Total output tokens.
    pub total_output_tokens: u64,
    /// Number of calls recorded for this model.
    pub call_count: u64,
}
/// Daily usage breakdown.
///
/// One entry per calendar day that had at least one usage event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyBreakdown {
    /// Date string (YYYY-MM-DD).
    pub date: String,
    /// Total cost for this day in USD.
    pub cost_usd: f64,
    /// Total tokens (input + output).
    pub tokens: u64,
    /// Number of API calls.
    pub calls: u64,
}
/// Usage store backed by SQLite.
///
/// Cloning is cheap: the connection is shared behind `Arc<Mutex<_>>`, so
/// all clones serialize their database access through the same lock.
#[derive(Clone)]
pub struct UsageStore {
    // Shared SQLite handle; the same connection may also serve other stores.
    conn: Arc<Mutex<Connection>>,
}
impl UsageStore {
    /// Create a new usage store wrapping the given connection.
    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
        Self { conn }
    }

    /// Record a usage event.
    ///
    /// A fresh UUID and an RFC 3339 UTC timestamp are generated for the row.
    pub fn record(&self, record: &UsageRecord) -> OpenFangResult<()> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let id = uuid::Uuid::new_v4().to_string();
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO usage_events (id, agent_id, timestamp, model, input_tokens, output_tokens, cost_usd, tool_calls)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
            rusqlite::params![
                id,
                record.agent_id.0.to_string(),
                now,
                record.model,
                // SQLite integers are i64; token counts are cast down.
                record.input_tokens as i64,
                record.output_tokens as i64,
                record.cost_usd,
                record.tool_calls as i64,
            ],
        )
        .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(())
    }

    /// Query total cost in the last hour for an agent.
    pub fn query_hourly(&self, agent_id: AgentId) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE agent_id = ?1 AND timestamp > datetime('now', '-1 hour')",
                rusqlite::params![agent_id.0.to_string()],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Query total cost today (UTC day) for an agent.
    pub fn query_daily(&self, agent_id: AgentId) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE agent_id = ?1 AND timestamp > datetime('now', 'start of day')",
                rusqlite::params![agent_id.0.to_string()],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Query total cost in the current calendar month for an agent.
    pub fn query_monthly(&self, agent_id: AgentId) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE agent_id = ?1 AND timestamp > datetime('now', 'start of month')",
                rusqlite::params![agent_id.0.to_string()],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Query total cost across all agents for the last hour.
    pub fn query_global_hourly(&self) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE timestamp > datetime('now', '-1 hour')",
                [],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Query total cost across all agents for the current calendar month.
    pub fn query_global_monthly(&self) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE timestamp > datetime('now', 'start of month')",
                [],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Query usage summary, optionally filtered by agent.
    ///
    /// With `agent_id = None` the totals cover every agent.
    pub fn query_summary(&self, agent_id: Option<AgentId>) -> OpenFangResult<UsageSummary> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = match agent_id {
            Some(aid) => (
                "SELECT COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
                    COALESCE(SUM(cost_usd), 0.0), COUNT(*), COALESCE(SUM(tool_calls), 0)
             FROM usage_events WHERE agent_id = ?1",
                vec![Box::new(aid.0.to_string())],
            ),
            None => (
                "SELECT COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
                    COALESCE(SUM(cost_usd), 0.0), COUNT(*), COALESCE(SUM(tool_calls), 0)
             FROM usage_events",
                vec![],
            ),
        };
        let params_refs: Vec<&dyn rusqlite::types::ToSql> =
            params.iter().map(|p| p.as_ref()).collect();
        let summary = conn
            .query_row(sql, params_refs.as_slice(), |row| {
                Ok(UsageSummary {
                    total_input_tokens: row.get::<_, i64>(0)? as u64,
                    total_output_tokens: row.get::<_, i64>(1)? as u64,
                    total_cost_usd: row.get(2)?,
                    call_count: row.get::<_, i64>(3)? as u64,
                    total_tool_calls: row.get::<_, i64>(4)? as u64,
                })
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(summary)
    }

    /// Query usage grouped by model, most expensive model first.
    pub fn query_by_model(&self) -> OpenFangResult<Vec<ModelUsage>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let mut stmt = conn
            .prepare(
                "SELECT model, COALESCE(SUM(cost_usd), 0.0), COALESCE(SUM(input_tokens), 0),
                    COALESCE(SUM(output_tokens), 0), COUNT(*)
             FROM usage_events GROUP BY model ORDER BY SUM(cost_usd) DESC",
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let rows = stmt
            .query_map([], |row| {
                Ok(ModelUsage {
                    model: row.get(0)?,
                    total_cost_usd: row.get(1)?,
                    total_input_tokens: row.get::<_, i64>(2)? as u64,
                    total_output_tokens: row.get::<_, i64>(3)? as u64,
                    call_count: row.get::<_, i64>(4)? as u64,
                })
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut results = Vec::new();
        for row in rows {
            results.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
        }
        Ok(results)
    }

    /// Query daily usage breakdown for the last N days, oldest first.
    pub fn query_daily_breakdown(&self, days: u32) -> OpenFangResult<Vec<DailyBreakdown>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Bind the time modifier as a parameter instead of splicing it into
        // the SQL text with format!: `days` is numeric today, so the old code
        // was not injectable, but parameterizing keeps the statement constant
        // (cacheable) and safe if the argument type ever loosens.
        let mut stmt = conn
            .prepare(
                "SELECT date(timestamp) as day,
                    COALESCE(SUM(cost_usd), 0.0),
                    COALESCE(SUM(input_tokens) + SUM(output_tokens), 0),
                    COUNT(*)
             FROM usage_events
             WHERE timestamp > datetime('now', ?1)
             GROUP BY day
             ORDER BY day ASC",
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let modifier = format!("-{days} days");
        let rows = stmt
            .query_map(rusqlite::params![modifier], |row| {
                Ok(DailyBreakdown {
                    date: row.get(0)?,
                    cost_usd: row.get(1)?,
                    tokens: row.get::<_, i64>(2)? as u64,
                    calls: row.get::<_, i64>(3)? as u64,
                })
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        let mut results = Vec::new();
        for row in rows {
            results.push(row.map_err(|e| OpenFangError::Memory(e.to_string()))?);
        }
        Ok(results)
    }

    /// Query the timestamp of the earliest usage event.
    ///
    /// Returns `None` when no events have been recorded (SQL MIN over an
    /// empty table yields NULL).
    pub fn query_first_event_date(&self) -> OpenFangResult<Option<String>> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let result: Option<String> = conn
            .query_row("SELECT MIN(timestamp) FROM usage_events", [], |row| {
                row.get(0)
            })
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(result)
    }

    /// Query today's total cost across all agents.
    pub fn query_today_cost(&self) -> OpenFangResult<f64> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        let cost: f64 = conn
            .query_row(
                "SELECT COALESCE(SUM(cost_usd), 0.0) FROM usage_events
             WHERE timestamp > datetime('now', 'start of day')",
                [],
                |row| row.get(0),
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(cost)
    }

    /// Delete usage events older than the given number of days.
    ///
    /// Returns the number of rows removed.
    pub fn cleanup_old(&self, days: u32) -> OpenFangResult<usize> {
        let conn = self
            .conn
            .lock()
            .map_err(|e| OpenFangError::Internal(e.to_string()))?;
        // Same parameterization as `query_daily_breakdown`: bind the
        // datetime modifier rather than building SQL with format!.
        let deleted = conn
            .execute(
                "DELETE FROM usage_events WHERE timestamp < datetime('now', ?1)",
                rusqlite::params![format!("-{days} days")],
            )
            .map_err(|e| OpenFangError::Memory(e.to_string()))?;
        Ok(deleted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;
    // Fresh in-memory database with the schema applied.
    fn setup() -> UsageStore {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        UsageStore::new(Arc::new(Mutex::new(conn)))
    }
    // Two recorded events aggregate into one per-agent summary.
    #[test]
    fn test_record_and_query_summary() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .record(&UsageRecord {
                agent_id,
                model: "claude-haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.001,
                tool_calls: 2,
            })
            .unwrap();
        store
            .record(&UsageRecord {
                agent_id,
                model: "claude-sonnet".to_string(),
                input_tokens: 500,
                output_tokens: 200,
                cost_usd: 0.01,
                tool_calls: 1,
            })
            .unwrap();
        let summary = store.query_summary(Some(agent_id)).unwrap();
        assert_eq!(summary.call_count, 2);
        assert_eq!(summary.total_input_tokens, 600);
        assert_eq!(summary.total_output_tokens, 250);
        // Floating-point cost comparison uses a small tolerance.
        assert!((summary.total_cost_usd - 0.011).abs() < 0.0001);
        assert_eq!(summary.total_tool_calls, 3);
    }
    // A None agent filter aggregates events across all agents.
    #[test]
    fn test_query_summary_all_agents() {
        let store = setup();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        store
            .record(&UsageRecord {
                agent_id: a1,
                model: "haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.001,
                tool_calls: 0,
            })
            .unwrap();
        store
            .record(&UsageRecord {
                agent_id: a2,
                model: "sonnet".to_string(),
                input_tokens: 200,
                output_tokens: 100,
                cost_usd: 0.005,
                tool_calls: 1,
            })
            .unwrap();
        let summary = store.query_summary(None).unwrap();
        assert_eq!(summary.call_count, 2);
        assert_eq!(summary.total_input_tokens, 300);
    }
    // Per-model grouping sorts by total cost, descending.
    #[test]
    fn test_query_by_model() {
        let store = setup();
        let agent_id = AgentId::new();
        for _ in 0..3 {
            store
                .record(&UsageRecord {
                    agent_id,
                    model: "haiku".to_string(),
                    input_tokens: 100,
                    output_tokens: 50,
                    cost_usd: 0.001,
                    tool_calls: 0,
                })
                .unwrap();
        }
        store
            .record(&UsageRecord {
                agent_id,
                model: "sonnet".to_string(),
                input_tokens: 500,
                output_tokens: 200,
                cost_usd: 0.01,
                tool_calls: 1,
            })
            .unwrap();
        let by_model = store.query_by_model().unwrap();
        assert_eq!(by_model.len(), 2);
        // sonnet should be first (highest cost)
        assert_eq!(by_model[0].model, "sonnet");
        assert_eq!(by_model[1].model, "haiku");
        assert_eq!(by_model[1].call_count, 3);
    }
    // An event recorded just now falls inside the one-hour window.
    #[test]
    fn test_query_hourly() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .record(&UsageRecord {
                agent_id,
                model: "haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.05,
                tool_calls: 0,
            })
            .unwrap();
        let hourly = store.query_hourly(agent_id).unwrap();
        assert!((hourly - 0.05).abs() < 0.001);
    }
    // An event recorded just now falls inside the current UTC day.
    #[test]
    fn test_query_daily() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .record(&UsageRecord {
                agent_id,
                model: "haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.123,
                tool_calls: 0,
            })
            .unwrap();
        let daily = store.query_daily(agent_id).unwrap();
        assert!((daily - 0.123).abs() < 0.001);
    }
    // Cleanup with a 1-day horizon must not delete an event recorded today.
    #[test]
    fn test_cleanup_old() {
        let store = setup();
        let agent_id = AgentId::new();
        store
            .record(&UsageRecord {
                agent_id,
                model: "haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.001,
                tool_calls: 0,
            })
            .unwrap();
        // Cleanup events older than 1 day should not remove today's events
        let deleted = store.cleanup_old(1).unwrap();
        assert_eq!(deleted, 0);
        let summary = store.query_summary(None).unwrap();
        assert_eq!(summary.call_count, 1);
    }
    // An empty table produces an all-zero summary (COALESCE defaults).
    #[test]
    fn test_empty_summary() {
        let store = setup();
        let summary = store.query_summary(None).unwrap();
        assert_eq!(summary.call_count, 0);
        assert_eq!(summary.total_cost_usd, 0.0);
    }
}