初始化提交
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
This commit is contained in:
42
crates/openfang-kernel/Cargo.toml
Normal file
42
crates/openfang-kernel/Cargo.toml
Normal file
@@ -0,0 +1,42 @@
|
||||
[package]
name = "openfang-kernel"
version.workspace = true
edition.workspace = true
license.workspace = true
description = "Core kernel for the OpenFang Agent OS"

[dependencies]
# Sibling workspace crates.
openfang-types = { path = "../openfang-types" }
openfang-memory = { path = "../openfang-memory" }
openfang-runtime = { path = "../openfang-runtime" }
openfang-skills = { path = "../openfang-skills" }
openfang-hands = { path = "../openfang-hands" }
openfang-extensions = { path = "../openfang-extensions" }
openfang-wire = { path = "../openfang-wire" }
openfang-channels = { path = "../openfang-channels" }
# Shared third-party dependencies, versioned at the workspace root.
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }
dashmap = { workspace = true }
crossbeam = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
chrono = { workspace = true }
uuid = { workspace = true }
thiserror = { workspace = true }
async-trait = { workspace = true }
dirs = { workspace = true }
futures = { workspace = true }
subtle = { workspace = true }
rand = { workspace = true }
hex = { workspace = true }
reqwest = { workspace = true }
# Crate-local pin: cron expression parsing for scheduled tasks.
cron = "0.15"

# Unix-only: raw libc access (signals, process control).
[target.'cfg(unix)'.dependencies]
libc = "0.2"

[dev-dependencies]
tokio-test = { workspace = true }
tempfile = { workspace = true }
|
||||
405
crates/openfang-kernel/src/approval.rs
Normal file
405
crates/openfang-kernel/src/approval.rs
Normal file
@@ -0,0 +1,405 @@
|
||||
//! Execution approval manager — gates dangerous operations behind human approval.
|
||||
|
||||
use chrono::Utc;
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::approval::{
|
||||
ApprovalDecision, ApprovalPolicy, ApprovalRequest, ApprovalResponse, RiskLevel,
|
||||
};
|
||||
use tracing::{debug, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Max pending requests per agent.
|
||||
const MAX_PENDING_PER_AGENT: usize = 5;
|
||||
|
||||
/// Manages approval requests with oneshot channels for blocking resolution.
|
||||
pub struct ApprovalManager {
|
||||
pending: DashMap<Uuid, PendingRequest>,
|
||||
policy: std::sync::RwLock<ApprovalPolicy>,
|
||||
}
|
||||
|
||||
struct PendingRequest {
|
||||
request: ApprovalRequest,
|
||||
sender: tokio::sync::oneshot::Sender<ApprovalDecision>,
|
||||
}
|
||||
|
||||
impl ApprovalManager {
|
||||
pub fn new(policy: ApprovalPolicy) -> Self {
|
||||
Self {
|
||||
pending: DashMap::new(),
|
||||
policy: std::sync::RwLock::new(policy),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool requires approval based on current policy.
|
||||
pub fn requires_approval(&self, tool_name: &str) -> bool {
|
||||
let policy = self.policy.read().unwrap_or_else(|e| e.into_inner());
|
||||
policy.require_approval.iter().any(|t| t == tool_name)
|
||||
}
|
||||
|
||||
/// Submit an approval request. Returns a future that resolves when approved/denied/timed out.
|
||||
pub async fn request_approval(&self, req: ApprovalRequest) -> ApprovalDecision {
|
||||
// Check per-agent pending limit
|
||||
let agent_pending = self
|
||||
.pending
|
||||
.iter()
|
||||
.filter(|r| r.value().request.agent_id == req.agent_id)
|
||||
.count();
|
||||
if agent_pending >= MAX_PENDING_PER_AGENT {
|
||||
warn!(agent_id = %req.agent_id, "Approval request rejected: too many pending");
|
||||
return ApprovalDecision::Denied;
|
||||
}
|
||||
|
||||
let timeout = std::time::Duration::from_secs(req.timeout_secs);
|
||||
let id = req.id;
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
self.pending.insert(
|
||||
id,
|
||||
PendingRequest {
|
||||
request: req,
|
||||
sender: tx,
|
||||
},
|
||||
);
|
||||
|
||||
info!(request_id = %id, "Approval request submitted, waiting for resolution");
|
||||
|
||||
match tokio::time::timeout(timeout, rx).await {
|
||||
Ok(Ok(decision)) => {
|
||||
debug!(request_id = %id, ?decision, "Approval resolved");
|
||||
decision
|
||||
}
|
||||
_ => {
|
||||
self.pending.remove(&id);
|
||||
warn!(request_id = %id, "Approval request timed out");
|
||||
ApprovalDecision::TimedOut
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve a pending request (called by API/UI).
|
||||
pub fn resolve(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
decision: ApprovalDecision,
|
||||
decided_by: Option<String>,
|
||||
) -> Result<ApprovalResponse, String> {
|
||||
match self.pending.remove(&request_id) {
|
||||
Some((_, pending)) => {
|
||||
let response = ApprovalResponse {
|
||||
request_id,
|
||||
decision,
|
||||
decided_at: Utc::now(),
|
||||
decided_by,
|
||||
};
|
||||
// Send decision to waiting agent (ignore error if receiver dropped)
|
||||
let _ = pending.sender.send(decision);
|
||||
info!(request_id = %request_id, ?decision, "Approval request resolved");
|
||||
Ok(response)
|
||||
}
|
||||
None => Err(format!("No pending approval request with id {request_id}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// List all pending requests (for API/dashboard display).
|
||||
pub fn list_pending(&self) -> Vec<ApprovalRequest> {
|
||||
self.pending
|
||||
.iter()
|
||||
.map(|r| r.value().request.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Number of pending requests.
|
||||
pub fn pending_count(&self) -> usize {
|
||||
self.pending.len()
|
||||
}
|
||||
|
||||
/// Update the approval policy (for hot-reload).
|
||||
pub fn update_policy(&self, policy: ApprovalPolicy) {
|
||||
*self.policy.write().unwrap_or_else(|e| e.into_inner()) = policy;
|
||||
}
|
||||
|
||||
/// Get a copy of the current policy.
|
||||
pub fn policy(&self) -> ApprovalPolicy {
|
||||
self.policy
|
||||
.read()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Classify the risk level of a tool invocation.
|
||||
pub fn classify_risk(tool_name: &str) -> RiskLevel {
|
||||
match tool_name {
|
||||
"shell_exec" => RiskLevel::Critical,
|
||||
"file_write" | "file_delete" => RiskLevel::High,
|
||||
"web_fetch" | "browser_navigate" => RiskLevel::Medium,
|
||||
_ => RiskLevel::Low,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::approval::ApprovalPolicy;
    use std::sync::Arc;

    /// Manager configured with the stock default policy.
    fn default_manager() -> ApprovalManager {
        ApprovalManager::new(ApprovalPolicy::default())
    }

    /// Build a high-risk test request for the given agent/tool pair.
    fn make_request(agent_id: &str, tool_name: &str, timeout_secs: u64) -> ApprovalRequest {
        ApprovalRequest {
            id: Uuid::new_v4(),
            agent_id: agent_id.to_string(),
            tool_name: tool_name.to_string(),
            description: "test operation".to_string(),
            action_summary: "test action".to_string(),
            risk_level: RiskLevel::High,
            requested_at: Utc::now(),
            timeout_secs,
        }
    }

    // -- requires_approval ---------------------------------------------------

    #[test]
    fn test_requires_approval_default() {
        let manager = default_manager();
        assert!(manager.requires_approval("shell_exec"));
        assert!(!manager.requires_approval("file_read"));
    }

    #[test]
    fn test_requires_approval_custom_policy() {
        let policy = ApprovalPolicy {
            require_approval: vec!["file_write".to_string(), "file_delete".to_string()],
            timeout_secs: 30,
            auto_approve_autonomous: false,
            auto_approve: false,
        };
        let manager = ApprovalManager::new(policy);
        assert!(manager.requires_approval("file_write"));
        assert!(manager.requires_approval("file_delete"));
        assert!(!manager.requires_approval("shell_exec"));
        assert!(!manager.requires_approval("file_read"));
    }

    // -- classify_risk -------------------------------------------------------

    #[test]
    fn test_classify_risk() {
        assert_eq!(
            ApprovalManager::classify_risk("shell_exec"),
            RiskLevel::Critical
        );
        assert_eq!(
            ApprovalManager::classify_risk("file_write"),
            RiskLevel::High
        );
        assert_eq!(
            ApprovalManager::classify_risk("file_delete"),
            RiskLevel::High
        );
        assert_eq!(
            ApprovalManager::classify_risk("web_fetch"),
            RiskLevel::Medium
        );
        assert_eq!(
            ApprovalManager::classify_risk("browser_navigate"),
            RiskLevel::Medium
        );
        assert_eq!(ApprovalManager::classify_risk("file_read"), RiskLevel::Low);
        assert_eq!(
            ApprovalManager::classify_risk("unknown_tool"),
            RiskLevel::Low
        );
    }

    // -- resolve on a nonexistent id -----------------------------------------

    #[test]
    fn test_resolve_nonexistent() {
        let manager = default_manager();
        let outcome = manager.resolve(Uuid::new_v4(), ApprovalDecision::Approved, None);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("No pending approval request"));
    }

    // -- list_pending with nothing queued ------------------------------------

    #[test]
    fn test_list_pending_empty() {
        let manager = default_manager();
        assert!(manager.list_pending().is_empty());
    }

    // -- update_policy hot-swap ----------------------------------------------

    #[test]
    fn test_update_policy() {
        let manager = default_manager();
        assert!(manager.requires_approval("shell_exec"));
        assert!(!manager.requires_approval("file_write"));

        let new_policy = ApprovalPolicy {
            require_approval: vec!["file_write".to_string()],
            timeout_secs: 120,
            auto_approve_autonomous: true,
            auto_approve: false,
        };
        manager.update_policy(new_policy);

        assert!(!manager.requires_approval("shell_exec"));
        assert!(manager.requires_approval("file_write"));

        let policy = manager.policy();
        assert_eq!(policy.timeout_secs, 120);
        assert!(policy.auto_approve_autonomous);
    }

    // -- pending_count --------------------------------------------------------

    #[test]
    fn test_pending_count() {
        let manager = default_manager();
        assert_eq!(manager.pending_count(), 0);
    }

    // -- request_approval: timeout path ---------------------------------------

    #[tokio::test]
    async fn test_request_approval_timeout() {
        let manager = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 10);
        let decision = manager.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::TimedOut);
        // The stale entry must be removed once the request times out.
        assert_eq!(manager.pending_count(), 0);
    }

    // -- request_approval: approval path --------------------------------------

    #[tokio::test]
    async fn test_request_approval_approve() {
        let manager = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 60);
        let request_id = req.id;

        let resolver = Arc::clone(&manager);
        tokio::spawn(async move {
            // Short delay so the request has a chance to register first.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let outcome = resolver.resolve(
                request_id,
                ApprovalDecision::Approved,
                Some("admin".to_string()),
            );
            assert!(outcome.is_ok());
            let resp = outcome.unwrap();
            assert_eq!(resp.decision, ApprovalDecision::Approved);
            assert_eq!(resp.decided_by, Some("admin".to_string()));
        });

        let decision = manager.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::Approved);
    }

    // -- request_approval: denial path ----------------------------------------

    #[tokio::test]
    async fn test_request_approval_deny() {
        let manager = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 60);
        let request_id = req.id;

        let resolver = Arc::clone(&manager);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let outcome = resolver.resolve(request_id, ApprovalDecision::Denied, None);
            assert!(outcome.is_ok());
        });

        let decision = manager.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::Denied);
    }

    // -- per-agent pending cap ------------------------------------------------

    #[tokio::test]
    async fn test_max_pending_per_agent() {
        let manager = Arc::new(default_manager());

        // Fill the cap with requests for agent-1; each blocks in a task.
        let mut ids = Vec::new();
        for _ in 0..MAX_PENDING_PER_AGENT {
            let req = make_request("agent-1", "shell_exec", 300);
            ids.push(req.id);
            let worker = Arc::clone(&manager);
            tokio::spawn(async move {
                worker.request_approval(req).await;
            });
        }

        // Allow the spawned tasks time to register their requests.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(manager.pending_count(), MAX_PENDING_PER_AGENT);

        // One more request from the same agent is denied immediately.
        let overflow = make_request("agent-1", "shell_exec", 300);
        let decision = manager.request_approval(overflow).await;
        assert_eq!(decision, ApprovalDecision::Denied);

        // The cap is per-agent: a different agent may still submit.
        let req_other = make_request("agent-2", "shell_exec", 300);
        let other_id = req_other.id;
        let worker = Arc::clone(&manager);
        tokio::spawn(async move {
            worker.request_approval(req_other).await;
        });
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        assert_eq!(manager.pending_count(), MAX_PENDING_PER_AGENT + 1);

        // Cleanup: resolve everything so no task is left waiting.
        for id in &ids {
            let _ = manager.resolve(*id, ApprovalDecision::Denied, None);
        }
        let _ = manager.resolve(other_id, ApprovalDecision::Denied, None);
    }

    // -- default policy contents ----------------------------------------------

    #[test]
    fn test_policy_defaults() {
        let manager = default_manager();
        let policy = manager.policy();
        assert_eq!(policy.require_approval, vec!["shell_exec".to_string()]);
        assert_eq!(policy.timeout_secs, 60);
        assert!(!policy.auto_approve_autonomous);
    }
}
|
||||
316
crates/openfang-kernel/src/auth.rs
Normal file
316
crates/openfang-kernel/src/auth.rs
Normal file
@@ -0,0 +1,316 @@
|
||||
//! RBAC authentication and authorization for multi-user access control.
|
||||
//!
|
||||
//! The AuthManager maps platform user identities (Telegram ID, Discord ID, etc.)
|
||||
//! to OpenFang users with roles, then enforces permission checks on actions.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::UserId;
|
||||
use openfang_types::config::UserConfig;
|
||||
use openfang_types::error::{OpenFangError, OpenFangResult};
|
||||
use std::fmt;
|
||||
use tracing::info;
|
||||
|
||||
/// User roles with hierarchical permissions.
///
/// The discriminants establish a total order (`Viewer < User < Admin < Owner`)
/// so authorization can be a simple `>=` comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum UserRole {
    /// Read-only access — can view agent output but cannot interact.
    Viewer = 0,
    /// Standard user — can chat with agents.
    User = 1,
    /// Admin — can spawn/kill agents, install skills, view usage.
    Admin = 2,
    /// Owner — full access including user management and config changes.
    Owner = 3,
}

impl fmt::Display for UserRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lowercase names mirror the strings accepted by `from_str_role`.
        let label = match self {
            UserRole::Viewer => "viewer",
            UserRole::User => "user",
            UserRole::Admin => "admin",
            UserRole::Owner => "owner",
        };
        write!(f, "{label}")
    }
}

impl UserRole {
    /// Parse a role from a string, case-insensitively.
    ///
    /// Unrecognized input falls back to the least-privileged interactive
    /// role, `UserRole::User`.
    pub fn from_str_role(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "owner" => UserRole::Owner,
            "admin" => UserRole::Admin,
            "viewer" => UserRole::Viewer,
            _ => UserRole::User,
        }
    }
}
|
||||
|
||||
/// Actions that can be authorized.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Action {
|
||||
/// Chat with an agent.
|
||||
ChatWithAgent,
|
||||
/// Spawn a new agent.
|
||||
SpawnAgent,
|
||||
/// Kill a running agent.
|
||||
KillAgent,
|
||||
/// Install a skill.
|
||||
InstallSkill,
|
||||
/// View kernel configuration.
|
||||
ViewConfig,
|
||||
/// Modify kernel configuration.
|
||||
ModifyConfig,
|
||||
/// View usage/billing data.
|
||||
ViewUsage,
|
||||
/// Manage users (create, delete, change roles).
|
||||
ManageUsers,
|
||||
}
|
||||
|
||||
impl Action {
|
||||
/// Minimum role required for this action.
|
||||
fn required_role(&self) -> UserRole {
|
||||
match self {
|
||||
Action::ChatWithAgent => UserRole::User,
|
||||
Action::ViewConfig => UserRole::User,
|
||||
Action::ViewUsage => UserRole::Admin,
|
||||
Action::SpawnAgent => UserRole::Admin,
|
||||
Action::KillAgent => UserRole::Admin,
|
||||
Action::InstallSkill => UserRole::Admin,
|
||||
Action::ModifyConfig => UserRole::Owner,
|
||||
Action::ManageUsers => UserRole::Owner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A resolved user identity.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserIdentity {
|
||||
/// OpenFang user ID.
|
||||
pub id: UserId,
|
||||
/// Display name.
|
||||
pub name: String,
|
||||
/// Role.
|
||||
pub role: UserRole,
|
||||
}
|
||||
|
||||
/// RBAC authentication and authorization manager.
|
||||
pub struct AuthManager {
|
||||
/// Known users by their OpenFang user ID.
|
||||
users: DashMap<UserId, UserIdentity>,
|
||||
/// Channel binding index: "channel_type:platform_id" → UserId.
|
||||
channel_index: DashMap<String, UserId>,
|
||||
}
|
||||
|
||||
impl AuthManager {
|
||||
/// Create a new AuthManager from kernel user configuration.
|
||||
pub fn new(user_configs: &[UserConfig]) -> Self {
|
||||
let manager = Self {
|
||||
users: DashMap::new(),
|
||||
channel_index: DashMap::new(),
|
||||
};
|
||||
|
||||
for config in user_configs {
|
||||
let user_id = UserId::new();
|
||||
let role = UserRole::from_str_role(&config.role);
|
||||
let identity = UserIdentity {
|
||||
id: user_id,
|
||||
name: config.name.clone(),
|
||||
role,
|
||||
};
|
||||
|
||||
manager.users.insert(user_id, identity);
|
||||
|
||||
// Index channel bindings
|
||||
for (channel_type, platform_id) in &config.channel_bindings {
|
||||
let key = format!("{channel_type}:{platform_id}");
|
||||
manager.channel_index.insert(key, user_id);
|
||||
}
|
||||
|
||||
info!(
|
||||
user = %config.name,
|
||||
role = %role,
|
||||
bindings = config.channel_bindings.len(),
|
||||
"Registered user"
|
||||
);
|
||||
}
|
||||
|
||||
manager
|
||||
}
|
||||
|
||||
/// Identify a user from a channel identity.
|
||||
///
|
||||
/// Returns the OpenFang UserId if a matching channel binding exists,
|
||||
/// or None for unrecognized users.
|
||||
pub fn identify(&self, channel_type: &str, platform_id: &str) -> Option<UserId> {
|
||||
let key = format!("{channel_type}:{platform_id}");
|
||||
self.channel_index.get(&key).map(|r| *r.value())
|
||||
}
|
||||
|
||||
/// Get a user's identity by their UserId.
|
||||
pub fn get_user(&self, user_id: UserId) -> Option<UserIdentity> {
|
||||
self.users.get(&user_id).map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// Authorize a user for an action.
|
||||
///
|
||||
/// Returns Ok(()) if the user has sufficient permissions, or AuthDenied error.
|
||||
pub fn authorize(&self, user_id: UserId, action: &Action) -> OpenFangResult<()> {
|
||||
let identity = self
|
||||
.users
|
||||
.get(&user_id)
|
||||
.ok_or_else(|| OpenFangError::AuthDenied("Unknown user".to_string()))?;
|
||||
|
||||
let required = action.required_role();
|
||||
if identity.role >= required {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(OpenFangError::AuthDenied(format!(
|
||||
"User '{}' (role: {}) lacks permission for {:?} (requires: {})",
|
||||
identity.name, identity.role, action, required
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if RBAC is configured (any users registered).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
!self.users.is_empty()
|
||||
}
|
||||
|
||||
/// Get the count of registered users.
|
||||
pub fn user_count(&self) -> usize {
|
||||
self.users.len()
|
||||
}
|
||||
|
||||
/// List all registered users.
|
||||
pub fn list_users(&self) -> Vec<UserIdentity> {
|
||||
self.users.iter().map(|r| r.value().clone()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Three users covering the owner / user / viewer roles, with channel
    /// bindings exercising cross-platform identity.
    fn test_configs() -> Vec<UserConfig> {
        let mut alice_bindings = HashMap::new();
        alice_bindings.insert("telegram".to_string(), "123456".to_string());
        alice_bindings.insert("discord".to_string(), "987654".to_string());

        let mut guest_bindings = HashMap::new();
        guest_bindings.insert("telegram".to_string(), "999999".to_string());

        vec![
            UserConfig {
                name: "Alice".to_string(),
                role: "owner".to_string(),
                channel_bindings: alice_bindings,
                api_key_hash: None,
            },
            UserConfig {
                name: "Guest".to_string(),
                role: "user".to_string(),
                channel_bindings: guest_bindings,
                api_key_hash: None,
            },
            UserConfig {
                name: "ReadOnly".to_string(),
                role: "viewer".to_string(),
                channel_bindings: HashMap::new(),
                api_key_hash: None,
            },
        ]
    }

    #[test]
    fn test_user_registration() {
        let manager = AuthManager::new(&test_configs());
        assert!(manager.is_enabled());
        assert_eq!(manager.user_count(), 3);
    }

    #[test]
    fn test_identify_from_channel() {
        let manager = AuthManager::new(&test_configs());

        // Alice on Telegram
        let owner_tg = manager.identify("telegram", "123456");
        assert!(owner_tg.is_some());

        // Alice on Discord
        let owner_dc = manager.identify("discord", "987654");
        assert!(owner_dc.is_some());

        // Both bindings resolve to the same user.
        assert_eq!(owner_tg.unwrap(), owner_dc.unwrap());

        // Unbound platform id resolves to nobody.
        assert!(manager.identify("telegram", "unknown").is_none());
    }

    #[test]
    fn test_owner_can_do_everything() {
        let manager = AuthManager::new(&test_configs());
        let owner_id = manager.identify("telegram", "123456").unwrap();

        assert!(manager.authorize(owner_id, &Action::ChatWithAgent).is_ok());
        assert!(manager.authorize(owner_id, &Action::SpawnAgent).is_ok());
        assert!(manager.authorize(owner_id, &Action::KillAgent).is_ok());
        assert!(manager.authorize(owner_id, &Action::ManageUsers).is_ok());
        assert!(manager.authorize(owner_id, &Action::ModifyConfig).is_ok());
    }

    #[test]
    fn test_user_limited_access() {
        let manager = AuthManager::new(&test_configs());
        let guest_id = manager.identify("telegram", "999999").unwrap();

        // Standard users may chat and view config.
        assert!(manager.authorize(guest_id, &Action::ChatWithAgent).is_ok());
        assert!(manager.authorize(guest_id, &Action::ViewConfig).is_ok());

        // But cannot spawn/kill agents or manage users.
        assert!(manager.authorize(guest_id, &Action::SpawnAgent).is_err());
        assert!(manager.authorize(guest_id, &Action::KillAgent).is_err());
        assert!(manager.authorize(guest_id, &Action::ManageUsers).is_err());
    }

    #[test]
    fn test_viewer_read_only() {
        let manager = AuthManager::new(&test_configs());
        let users = manager.list_users();
        let viewer = users.iter().find(|u| u.name == "ReadOnly").unwrap();

        // Viewers may not even chat.
        assert!(manager
            .authorize(viewer.id, &Action::ChatWithAgent)
            .is_err());
    }

    #[test]
    fn test_unknown_user_denied() {
        let manager = AuthManager::new(&test_configs());
        let fake_id = UserId::new();
        assert!(manager.authorize(fake_id, &Action::ChatWithAgent).is_err());
    }

    #[test]
    fn test_no_users_means_disabled() {
        let manager = AuthManager::new(&[]);
        assert!(!manager.is_enabled());
        assert_eq!(manager.user_count(), 0);
    }

    #[test]
    fn test_role_parsing() {
        assert_eq!(UserRole::from_str_role("owner"), UserRole::Owner);
        assert_eq!(UserRole::from_str_role("admin"), UserRole::Admin);
        assert_eq!(UserRole::from_str_role("viewer"), UserRole::Viewer);
        assert_eq!(UserRole::from_str_role("user"), UserRole::User);
        assert_eq!(UserRole::from_str_role("OWNER"), UserRole::Owner);
        assert_eq!(UserRole::from_str_role("unknown"), UserRole::User);
    }
}
|
||||
211
crates/openfang-kernel/src/auto_reply.rs
Normal file
211
crates/openfang-kernel/src/auto_reply.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
//! Auto-reply background engine — trigger-driven background replies with concurrency control.
|
||||
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::config::AutoReplyConfig;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Delivery target for an auto-reply result.
#[derive(Debug, Clone)]
pub struct AutoReplyChannel {
    /// Channel type string (e.g., "telegram", "discord").
    pub channel_type: String,
    /// Peer/user ID to send the reply to.
    pub peer_id: String,
    /// Optional thread ID for threaded replies.
    pub thread_id: Option<String>,
}
|
||||
|
||||
/// Auto-reply engine with concurrency limits and suppression patterns.
|
||||
pub struct AutoReplyEngine {
|
||||
config: AutoReplyConfig,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl AutoReplyEngine {
|
||||
/// Create a new auto-reply engine from configuration.
|
||||
pub fn new(config: AutoReplyConfig) -> Self {
|
||||
let permits = config.max_concurrent.max(1);
|
||||
Self {
|
||||
semaphore: Arc::new(Semaphore::new(permits)),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a message should trigger auto-reply.
|
||||
/// Returns `None` if suppressed or disabled, `Some(agent_id)` if should auto-reply.
|
||||
pub fn should_reply(
|
||||
&self,
|
||||
message: &str,
|
||||
_channel_type: &str,
|
||||
agent_id: AgentId,
|
||||
) -> Option<AgentId> {
|
||||
if !self.config.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check suppression patterns
|
||||
let lower = message.to_lowercase();
|
||||
for pattern in &self.config.suppress_patterns {
|
||||
if lower.contains(&pattern.to_lowercase()) {
|
||||
debug!(pattern = %pattern, "Auto-reply suppressed by pattern");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(agent_id)
|
||||
}
|
||||
|
||||
/// Execute an auto-reply in the background.
|
||||
/// Returns a JoinHandle for the spawned task.
|
||||
///
|
||||
/// The `send_fn` is called with the agent response to deliver it back to the channel.
|
||||
pub async fn execute_reply<F>(
|
||||
&self,
|
||||
kernel_handle: Arc<dyn openfang_runtime::kernel_handle::KernelHandle>,
|
||||
agent_id: AgentId,
|
||||
message: String,
|
||||
reply_channel: AutoReplyChannel,
|
||||
send_fn: F,
|
||||
) -> Result<tokio::task::JoinHandle<()>, String>
|
||||
where
|
||||
F: Fn(String, AutoReplyChannel) -> futures::future::BoxFuture<'static, ()>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
{
|
||||
// Try to acquire a semaphore permit
|
||||
let permit = match self.semaphore.clone().try_acquire_owned() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(format!(
|
||||
"Auto-reply concurrency limit reached ({} max)",
|
||||
self.config.max_concurrent
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let timeout_secs = self.config.timeout_secs;
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = permit; // Hold permit until task completes
|
||||
|
||||
info!(
|
||||
agent = %agent_id,
|
||||
channel = %reply_channel.channel_type,
|
||||
peer = %reply_channel.peer_id,
|
||||
"Starting auto-reply"
|
||||
);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
kernel_handle.send_to_agent(&agent_id.to_string(), &message),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(response)) => {
|
||||
send_fn(response, reply_channel).await;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
warn!(agent = %agent_id, error = %e, "Auto-reply agent error");
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(agent = %agent_id, timeout = timeout_secs, "Auto-reply timed out");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
    /// Check if auto-reply is enabled.
    ///
    /// Mirrors `config.enabled`; callers should gate reply dispatch on this.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Get the current configuration (read-only).
    pub fn config(&self) -> &AutoReplyConfig {
        &self.config
    }

    /// Get available permits (for monitoring).
    ///
    /// Equals `max_concurrent` minus the number of in-flight auto-replies.
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed test configuration: 3 concurrent, 120s timeout, two suppress patterns.
    fn test_config(enabled: bool) -> AutoReplyConfig {
        AutoReplyConfig {
            enabled,
            max_concurrent: 3,
            timeout_secs: 120,
            suppress_patterns: vec![String::from("/stop"), String::from("/pause")],
        }
    }

    #[test]
    fn test_disabled_engine() {
        let engine = AutoReplyEngine::new(test_config(false));
        let agent = AgentId::new();
        // A disabled engine never proposes a reply.
        assert!(engine.should_reply("hello", "telegram", agent).is_none());
    }

    #[test]
    fn test_enabled_engine_allows() {
        let engine = AutoReplyEngine::new(test_config(true));
        let agent = AgentId::new();
        assert_eq!(
            engine.should_reply("hello there", "telegram", agent),
            Some(agent)
        );
    }

    #[test]
    fn test_suppression_patterns() {
        let engine = AutoReplyEngine::new(test_config(true));
        let agent = AgentId::new();

        // Messages containing a suppress pattern anywhere are blocked.
        for suppressed in ["/stop", "please /pause this"] {
            assert!(engine.should_reply(suppressed, "telegram", agent).is_none());
        }

        // Ordinary text passes through.
        assert!(engine.should_reply("hello", "telegram", agent).is_some());
    }

    #[test]
    fn test_concurrency_limit() {
        let config = AutoReplyConfig {
            enabled: true,
            max_concurrent: 2,
            timeout_secs: 120,
            suppress_patterns: Vec::new(),
        };
        let engine = AutoReplyEngine::new(config);
        // Every permit is free before any reply has been dispatched.
        assert_eq!(engine.available_permits(), 2);
    }

    #[test]
    fn test_is_enabled() {
        assert!(AutoReplyEngine::new(test_config(true)).is_enabled());
        assert!(!AutoReplyEngine::new(test_config(false)).is_enabled());
    }

    #[test]
    fn test_default_config() {
        let config = AutoReplyConfig::default();
        // Auto-reply ships off-by-default with conservative limits.
        assert!(!config.enabled);
        assert_eq!(config.max_concurrent, 3);
        assert_eq!(config.timeout_secs, 120);
        assert!(config.suppress_patterns.contains(&"/stop".to_string()));
    }
}
|
||||
457
crates/openfang-kernel/src/background.rs
Normal file
457
crates/openfang-kernel/src/background.rs
Normal file
@@ -0,0 +1,457 @@
|
||||
//! Background agent executor — runs agents autonomously on schedules, timers, and conditions.
|
||||
//!
|
||||
//! Supports three autonomous modes:
|
||||
//! - **Continuous**: Agent self-prompts on a fixed interval.
|
||||
//! - **Periodic**: Agent wakes on a simplified cron schedule (e.g. "every 5m").
|
||||
//! - **Proactive**: Agent wakes when matching events fire (via the trigger engine).
|
||||
|
||||
use crate::triggers::TriggerPattern;
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::{AgentId, ScheduleMode};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Maximum number of concurrent background LLM calls across all agents.
///
/// Shared by every agent loop via `BackgroundExecutor::llm_semaphore`, so a
/// burst of ticking agents cannot flood the underlying model provider.
const MAX_CONCURRENT_BG_LLM: usize = 5;
|
||||
|
||||
/// Manages background task loops for autonomous agents.
///
/// Loops are started/stopped via `start_agent` / `stop_agent`; all loops share
/// a single global semaphore that caps concurrent background LLM dispatches.
pub struct BackgroundExecutor {
    /// Running background task handles, keyed by agent ID.
    tasks: DashMap<AgentId, JoinHandle<()>>,
    /// Shutdown signal receiver (from Supervisor). Cloned into each loop so
    /// every loop exits promptly when the supervisor signals shutdown.
    shutdown_rx: watch::Receiver<bool>,
    /// SECURITY: Global semaphore to limit concurrent background LLM calls.
    llm_semaphore: Arc<tokio::sync::Semaphore>,
}
|
||||
|
||||
impl BackgroundExecutor {
|
||||
/// Create a new executor bound to the supervisor's shutdown signal.
|
||||
pub fn new(shutdown_rx: watch::Receiver<bool>) -> Self {
|
||||
Self {
|
||||
tasks: DashMap::new(),
|
||||
shutdown_rx,
|
||||
llm_semaphore: Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_BG_LLM)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a background loop for an agent based on its schedule mode.
|
||||
///
|
||||
/// For `Continuous` and `Periodic` modes, spawns a tokio task that
|
||||
/// periodically sends a self-prompt message to the agent.
|
||||
/// For `Proactive` mode, registers triggers — no dedicated task needed.
|
||||
///
|
||||
/// `send_message` is a closure that sends a message to the given agent
|
||||
/// and returns a result. It captures an `Arc<OpenFangKernel>` from the caller.
|
||||
pub fn start_agent<F>(
|
||||
&self,
|
||||
agent_id: AgentId,
|
||||
agent_name: &str,
|
||||
schedule: &ScheduleMode,
|
||||
send_message: F,
|
||||
) where
|
||||
F: Fn(AgentId, String) -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
|
||||
{
|
||||
match schedule {
|
||||
ScheduleMode::Reactive => {} // nothing to do
|
||||
ScheduleMode::Continuous {
|
||||
check_interval_secs,
|
||||
} => {
|
||||
let interval = std::time::Duration::from_secs(*check_interval_secs);
|
||||
let name = agent_name.to_string();
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
let busy = Arc::new(AtomicBool::new(false));
|
||||
let semaphore = self.llm_semaphore.clone();
|
||||
|
||||
info!(
|
||||
agent = %name, id = %agent_id,
|
||||
interval_secs = check_interval_secs,
|
||||
"Starting continuous background loop"
|
||||
);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(interval) => {}
|
||||
_ = shutdown.changed() => {
|
||||
info!(agent = %name, "Continuous loop: shutdown signal received");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if previous tick is still running
|
||||
if busy
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_err()
|
||||
{
|
||||
debug!(agent = %name, "Continuous loop: skipping tick (busy)");
|
||||
continue;
|
||||
}
|
||||
|
||||
// SECURITY: Acquire global LLM concurrency permit
|
||||
let permit = match semaphore.clone().acquire_owned().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
busy.store(false, Ordering::SeqCst);
|
||||
break; // Semaphore closed
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format!(
|
||||
"[AUTONOMOUS TICK] You are running in continuous mode. \
|
||||
Check your goals, review shared memory for pending tasks, \
|
||||
and take any necessary actions. Agent: {name}"
|
||||
);
|
||||
debug!(agent = %name, "Continuous loop: sending self-prompt");
|
||||
let busy_clone = busy.clone();
|
||||
let jh = (send_message)(agent_id, prompt);
|
||||
// Spawn a watcher that clears the busy flag and drops permit when done
|
||||
tokio::spawn(async move {
|
||||
let _ = jh.await;
|
||||
drop(permit);
|
||||
busy_clone.store(false, Ordering::SeqCst);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
self.tasks.insert(agent_id, handle);
|
||||
}
|
||||
ScheduleMode::Periodic { cron } => {
|
||||
let interval_secs = parse_cron_to_secs(cron);
|
||||
let interval = std::time::Duration::from_secs(interval_secs);
|
||||
let name = agent_name.to_string();
|
||||
let cron_owned = cron.clone();
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
let busy = Arc::new(AtomicBool::new(false));
|
||||
let semaphore = self.llm_semaphore.clone();
|
||||
|
||||
info!(
|
||||
agent = %name, id = %agent_id,
|
||||
cron = %cron, interval_secs = interval_secs,
|
||||
"Starting periodic background loop"
|
||||
);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(interval) => {}
|
||||
_ = shutdown.changed() => {
|
||||
info!(agent = %name, "Periodic loop: shutdown signal received");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if busy
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_err()
|
||||
{
|
||||
debug!(agent = %name, "Periodic loop: skipping tick (busy)");
|
||||
continue;
|
||||
}
|
||||
|
||||
// SECURITY: Acquire global LLM concurrency permit
|
||||
let permit = match semaphore.clone().acquire_owned().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
busy.store(false, Ordering::SeqCst);
|
||||
break; // Semaphore closed
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format!(
|
||||
"[SCHEDULED TICK] You are running on a periodic schedule ({cron_owned}). \
|
||||
Perform your routine duties. Agent: {name}"
|
||||
);
|
||||
debug!(agent = %name, "Periodic loop: sending scheduled prompt");
|
||||
let busy_clone = busy.clone();
|
||||
let jh = (send_message)(agent_id, prompt);
|
||||
tokio::spawn(async move {
|
||||
let _ = jh.await;
|
||||
drop(permit);
|
||||
busy_clone.store(false, Ordering::SeqCst);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
self.tasks.insert(agent_id, handle);
|
||||
}
|
||||
ScheduleMode::Proactive { .. } => {
|
||||
// Proactive agents rely on triggers, not a dedicated loop.
|
||||
// Triggers are registered by the kernel during spawn_agent / start_background_agents.
|
||||
debug!(agent = %agent_name, "Proactive agent — triggers handle activation");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop the background loop for an agent, if one is running.
|
||||
pub fn stop_agent(&self, agent_id: AgentId) {
|
||||
if let Some((_, handle)) = self.tasks.remove(&agent_id) {
|
||||
handle.abort();
|
||||
info!(id = %agent_id, "Background loop stopped");
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of actively running background loops.
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.tasks.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a proactive condition string into a `TriggerPattern`.
///
/// Supported formats:
/// - `"event:agent_spawned"` → `TriggerPattern::AgentSpawned { name_pattern: "*" }`
/// - `"event:agent_terminated"` → `TriggerPattern::AgentTerminated`
/// - `"event:lifecycle"` → `TriggerPattern::Lifecycle`
/// - `"event:system"` → `TriggerPattern::System`
/// - `"event:memory_update"` → `TriggerPattern::MemoryUpdate`
/// - `"memory:some_key"` → `TriggerPattern::MemoryKeyPattern { key_pattern: "some_key" }`
/// - `"all"` → `TriggerPattern::All`
///
/// Returns `None` (with a warning) for anything unrecognized.
pub fn parse_condition(condition: &str) -> Option<TriggerPattern> {
    let condition = condition.trim();

    // Catch-all pattern, case-insensitive.
    if condition.eq_ignore_ascii_case("all") {
        return Some(TriggerPattern::All);
    }

    // Event-kind conditions: "event:<kind>"; the kind is case-insensitive.
    if let Some(event_kind) = condition.strip_prefix("event:") {
        let kind = event_kind.trim().to_lowercase();
        return match kind.as_str() {
            "agent_spawned" => Some(TriggerPattern::AgentSpawned {
                // "*" matches any agent name.
                name_pattern: "*".to_string(),
            }),
            "agent_terminated" => Some(TriggerPattern::AgentTerminated),
            "lifecycle" => Some(TriggerPattern::Lifecycle),
            "system" => Some(TriggerPattern::System),
            "memory_update" => Some(TriggerPattern::MemoryUpdate),
            other => {
                warn!(condition = %condition, "Unknown event condition: {other}");
                None
            }
        };
    }

    // Memory-key conditions: "memory:<key pattern>"; the pattern is passed
    // through verbatim (trimmed) to the trigger engine.
    if let Some(key) = condition.strip_prefix("memory:") {
        return Some(TriggerPattern::MemoryKeyPattern {
            key_pattern: key.trim().to_string(),
        });
    }

    warn!(condition = %condition, "Unrecognized proactive condition format");
    None
}
|
||||
|
||||
/// Parse a simplified cron expression into seconds.
|
||||
///
|
||||
/// Supported formats:
|
||||
/// - `"every 30s"` → 30
|
||||
/// - `"every 5m"` → 300
|
||||
/// - `"every 1h"` → 3600
|
||||
/// - `"every 2d"` → 172800
|
||||
///
|
||||
/// Falls back to 300 seconds (5 minutes) for unparseable expressions.
|
||||
pub fn parse_cron_to_secs(cron: &str) -> u64 {
|
||||
let cron = cron.trim().to_lowercase();
|
||||
|
||||
// Try "every <N><unit>" format
|
||||
if let Some(rest) = cron.strip_prefix("every ") {
|
||||
let rest = rest.trim();
|
||||
if let Some(num_str) = rest.strip_suffix('s') {
|
||||
if let Ok(n) = num_str.trim().parse::<u64>() {
|
||||
return n;
|
||||
}
|
||||
}
|
||||
if let Some(num_str) = rest.strip_suffix('m') {
|
||||
if let Ok(n) = num_str.trim().parse::<u64>() {
|
||||
return n * 60;
|
||||
}
|
||||
}
|
||||
if let Some(num_str) = rest.strip_suffix('h') {
|
||||
if let Ok(n) = num_str.trim().parse::<u64>() {
|
||||
return n * 3600;
|
||||
}
|
||||
}
|
||||
if let Some(num_str) = rest.strip_suffix('d') {
|
||||
if let Ok(n) = num_str.trim().parse::<u64>() {
|
||||
return n * 86400;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
warn!(cron = %cron, "Unparseable cron expression, defaulting to 300s");
|
||||
300
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_cron_to_secs ---

    #[test]
    fn test_parse_cron_seconds() {
        assert_eq!(parse_cron_to_secs("every 30s"), 30);
        assert_eq!(parse_cron_to_secs("every 1s"), 1);
    }

    #[test]
    fn test_parse_cron_minutes() {
        assert_eq!(parse_cron_to_secs("every 5m"), 300);
        assert_eq!(parse_cron_to_secs("every 1m"), 60);
    }

    #[test]
    fn test_parse_cron_hours() {
        assert_eq!(parse_cron_to_secs("every 1h"), 3600);
        assert_eq!(parse_cron_to_secs("every 2h"), 7200);
    }

    #[test]
    fn test_parse_cron_days() {
        assert_eq!(parse_cron_to_secs("every 1d"), 86400);
    }

    #[test]
    fn test_parse_cron_fallback() {
        // Unparseable → 300. Real cron syntax ("*/5 * * * *") is deliberately
        // unsupported and takes the same fallback path.
        assert_eq!(parse_cron_to_secs("*/5 * * * *"), 300);
        assert_eq!(parse_cron_to_secs("gibberish"), 300);
    }

    // --- parse_condition ---

    #[test]
    fn test_parse_condition_events() {
        assert!(matches!(
            parse_condition("event:agent_spawned"),
            Some(TriggerPattern::AgentSpawned { .. })
        ));
        assert!(matches!(
            parse_condition("event:agent_terminated"),
            Some(TriggerPattern::AgentTerminated)
        ));
        assert!(matches!(
            parse_condition("event:lifecycle"),
            Some(TriggerPattern::Lifecycle)
        ));
        assert!(matches!(
            parse_condition("event:system"),
            Some(TriggerPattern::System)
        ));
        assert!(matches!(
            parse_condition("event:memory_update"),
            Some(TriggerPattern::MemoryUpdate)
        ));
    }

    #[test]
    fn test_parse_condition_memory() {
        // Everything after "memory:" is carried through verbatim as the key pattern.
        match parse_condition("memory:agent.*.status") {
            Some(TriggerPattern::MemoryKeyPattern { key_pattern }) => {
                assert_eq!(key_pattern, "agent.*.status");
            }
            other => panic!("Expected MemoryKeyPattern, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_condition_all() {
        assert!(matches!(parse_condition("all"), Some(TriggerPattern::All)));
    }

    #[test]
    fn test_parse_condition_unknown() {
        assert!(parse_condition("event:unknown_thing").is_none());
        assert!(parse_condition("badprefix:foo").is_none());
    }

    // --- BackgroundExecutor ---

    // NOTE(review): the two async tests below are wall-clock based and may be
    // flaky on a heavily loaded CI runner.

    #[tokio::test]
    async fn test_continuous_shutdown() {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(shutdown_rx);
        let agent_id = AgentId::new();

        let tick_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let tick_clone = tick_count.clone();

        let schedule = ScheduleMode::Continuous {
            check_interval_secs: 1, // 1 second for fast test
        };

        // send_message stub: count each dispatched tick.
        executor.start_agent(agent_id, "test-agent", &schedule, move |_id, _msg| {
            let tc = tick_clone.clone();
            tokio::spawn(async move {
                tc.fetch_add(1, Ordering::SeqCst);
            })
        });

        assert_eq!(executor.active_count(), 1);

        // Wait for at least 1 tick
        tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
        assert!(tick_count.load(Ordering::SeqCst) >= 1);

        // Shutdown
        let _ = shutdown_tx.send(true);
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        // The loop should have exited (handle finished)
        // Active count still shows the entry until stop_agent is called
        executor.stop_agent(agent_id);
        assert_eq!(executor.active_count(), 0);
    }

    #[tokio::test]
    async fn test_skip_if_busy() {
        let (_shutdown_tx, shutdown_rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(shutdown_rx);
        let agent_id = AgentId::new();

        let tick_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let tick_clone = tick_count.clone();

        let schedule = ScheduleMode::Continuous {
            check_interval_secs: 1,
        };

        // Each tick takes 3 seconds — should cause subsequent ticks to be skipped
        executor.start_agent(agent_id, "slow-agent", &schedule, move |_id, _msg| {
            let tc = tick_clone.clone();
            tokio::spawn(async move {
                tc.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_secs(3)).await;
            })
        });

        // Wait 2.5 seconds: 1 tick should fire at t=1s, second at t=2s should be skipped (busy)
        tokio::time::sleep(std::time::Duration::from_millis(2500)).await;
        let ticks = tick_count.load(Ordering::SeqCst);
        // Should be exactly 1 because the first tick is still "busy" when the second arrives
        assert_eq!(ticks, 1, "Expected 1 tick (skip-if-busy), got {ticks}");

        executor.stop_agent(agent_id);
    }

    #[test]
    fn test_executor_active_count() {
        let (_tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(rx);
        assert_eq!(executor.active_count(), 0);

        // Reactive mode → no background task
        let id = AgentId::new();
        executor.start_agent(id, "reactive", &ScheduleMode::Reactive, |_id, _msg| {
            tokio::spawn(async {})
        });
        assert_eq!(executor.active_count(), 0);

        // Proactive mode → no dedicated task
        let id2 = AgentId::new();
        executor.start_agent(
            id2,
            "proactive",
            &ScheduleMode::Proactive {
                conditions: vec!["event:agent_spawned".to_string()],
            },
            |_id, _msg| tokio::spawn(async {}),
        );
        assert_eq!(executor.active_count(), 0);
    }
}
|
||||
95
crates/openfang-kernel/src/capabilities.rs
Normal file
95
crates/openfang-kernel/src/capabilities.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
//! Capability manager — enforces capability-based security.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::capability::{capability_matches, Capability, CapabilityCheck};
|
||||
use tracing::debug;
|
||||
|
||||
/// Manages capability grants for all agents.
///
/// Thread-safe via `DashMap`; a grant call replaces the agent's entire set.
pub struct CapabilityManager {
    /// Granted capabilities per agent.
    grants: DashMap<AgentId, Vec<Capability>>,
}
|
||||
|
||||
impl CapabilityManager {
|
||||
/// Create a new capability manager.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
grants: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant capabilities to an agent.
|
||||
pub fn grant(&self, agent_id: AgentId, capabilities: Vec<Capability>) {
|
||||
self.grants.insert(agent_id, capabilities);
|
||||
}
|
||||
|
||||
/// Check whether an agent has a specific capability.
|
||||
pub fn check(&self, agent_id: AgentId, required: &Capability) -> CapabilityCheck {
|
||||
let grants = match self.grants.get(&agent_id) {
|
||||
Some(g) => g,
|
||||
None => {
|
||||
return CapabilityCheck::Denied(format!(
|
||||
"No capabilities registered for agent {agent_id}"
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
for granted in grants.value() {
|
||||
if capability_matches(granted, required) {
|
||||
debug!(agent = %agent_id, ?required, "Capability granted");
|
||||
return CapabilityCheck::Granted;
|
||||
}
|
||||
}
|
||||
|
||||
CapabilityCheck::Denied(format!(
|
||||
"Agent {agent_id} does not have capability: {required:?}"
|
||||
))
|
||||
}
|
||||
|
||||
/// List all capabilities for an agent.
|
||||
pub fn list(&self, agent_id: AgentId) -> Vec<Capability> {
|
||||
self.grants
|
||||
.get(&agent_id)
|
||||
.map(|g| g.value().clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Remove all capabilities for an agent.
|
||||
pub fn revoke_all(&self, agent_id: AgentId) {
|
||||
self.grants.remove(&agent_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CapabilityManager {
    /// Equivalent to `CapabilityManager::new()` — an empty grant table.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ToolInvoke` capability from a tool name.
    fn tool(name: &str) -> Capability {
        Capability::ToolInvoke(name.to_string())
    }

    #[test]
    fn test_grant_and_check() {
        let mgr = CapabilityManager::new();
        let agent = AgentId::new();
        mgr.grant(agent, vec![tool("file_read")]);

        // Granted tool passes; anything else is denied.
        assert!(mgr.check(agent, &tool("file_read")).is_granted());
        assert!(!mgr.check(agent, &tool("shell_exec")).is_granted());
    }

    #[test]
    fn test_no_grants() {
        let mgr = CapabilityManager::new();
        let agent = AgentId::new();
        // An agent with no registered grants is denied everything.
        let verdict = mgr.check(agent, &tool("anything"));
        assert!(!verdict.is_granted());
    }
}
|
||||
434
crates/openfang-kernel/src/config.rs
Normal file
434
crates/openfang-kernel/src/config.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
//! Configuration loading from `~/.openfang/config.toml` with defaults.
|
||||
//!
|
||||
//! Supports config includes: the `include` field specifies additional TOML files
|
||||
//! to load and deep-merge before the root config (root overrides includes).
|
||||
|
||||
use openfang_types::config::KernelConfig;
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::info;
|
||||
|
||||
/// Maximum include nesting depth.
///
/// Guards `resolve_config_includes` against runaway recursion in pathological
/// (but non-circular) include chains.
const MAX_INCLUDE_DEPTH: u32 = 10;
|
||||
|
||||
/// Load kernel configuration from a TOML file, with defaults.
///
/// If the config contains an `include` field, included files are loaded
/// and deep-merged first, then the root config overrides them.
///
/// This function never fails: every error path (missing file, unreadable
/// file, parse error, include failure, deserialize error) logs a warning and
/// falls through to `KernelConfig::default()`.
pub fn load_config(path: Option<&Path>) -> KernelConfig {
    // Explicit path wins; otherwise use ~/.openfang/config.toml.
    let config_path = path
        .map(|p| p.to_path_buf())
        .unwrap_or_else(default_config_path);

    if config_path.exists() {
        match std::fs::read_to_string(&config_path) {
            Ok(contents) => match toml::from_str::<toml::Value>(&contents) {
                Ok(mut root_value) => {
                    // Process includes before deserializing
                    let config_dir = config_path
                        .parent()
                        .unwrap_or_else(|| Path::new("."))
                        .to_path_buf();
                    // Seed the visited set with the root file itself so an
                    // include pointing back at the root is caught as circular.
                    let mut visited = HashSet::new();
                    if let Ok(canonical) = std::fs::canonicalize(&config_path) {
                        visited.insert(canonical);
                    } else {
                        visited.insert(config_path.clone());
                    }

                    // Include failure is non-fatal: the root config alone is
                    // still used.
                    if let Err(e) =
                        resolve_config_includes(&mut root_value, &config_dir, &mut visited, 0)
                    {
                        tracing::warn!(
                            error = %e,
                            "Config include resolution failed, using root config only"
                        );
                    }

                    // Remove the `include` field before deserializing to avoid confusion
                    if let toml::Value::Table(ref mut tbl) = root_value {
                        tbl.remove("include");
                    }

                    match root_value.try_into::<KernelConfig>() {
                        Ok(config) => {
                            info!(path = %config_path.display(), "Loaded configuration");
                            return config;
                        }
                        Err(e) => {
                            tracing::warn!(
                                error = %e,
                                path = %config_path.display(),
                                "Failed to deserialize merged config, using defaults"
                            );
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        path = %config_path.display(),
                        "Failed to parse config, using defaults"
                    );
                }
            },
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    path = %config_path.display(),
                    "Failed to read config file, using defaults"
                );
            }
        }
    } else {
        info!(
            path = %config_path.display(),
            "Config file not found, using defaults"
        );
    }

    // Every failure path above falls through to the defaults.
    KernelConfig::default()
}
|
||||
|
||||
/// Resolve config includes by deep-merging included files into the root value.
///
/// Included files are loaded first and the root config overrides them; among
/// the includes themselves, later entries override earlier ones.
/// Security: rejects absolute paths, `..` components, paths that escape the
/// config directory after canonicalization, and circular references.
/// `visited` carries the set of canonicalized files already loaded; `depth`
/// is bounded by `MAX_INCLUDE_DEPTH`.
fn resolve_config_includes(
    root_value: &mut toml::Value,
    config_dir: &Path,
    visited: &mut HashSet<PathBuf>,
    depth: u32,
) -> Result<(), String> {
    if depth > MAX_INCLUDE_DEPTH {
        return Err(format!(
            "Config include depth exceeded maximum of {MAX_INCLUDE_DEPTH}"
        ));
    }

    // Extract include list from the current value
    // (non-table values and non-string array entries are silently ignored).
    let includes = match root_value {
        toml::Value::Table(tbl) => {
            if let Some(toml::Value::Array(arr)) = tbl.get("include") {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect::<Vec<_>>()
            } else {
                return Ok(());
            }
        }
        _ => return Ok(()),
    };

    if includes.is_empty() {
        return Ok(());
    }

    // Merge each include (earlier includes are overridden by later ones,
    // and the root config overrides everything).
    let mut merged_base = toml::Value::Table(toml::map::Map::new());

    for include_path_str in &includes {
        // SECURITY: reject absolute paths
        let include_path = Path::new(include_path_str);
        if include_path.is_absolute() {
            return Err(format!(
                "Config include rejects absolute path: {include_path_str}"
            ));
        }
        // SECURITY: reject `..` components
        for component in include_path.components() {
            if let std::path::Component::ParentDir = component {
                return Err(format!(
                    "Config include rejects path traversal: {include_path_str}"
                ));
            }
        }

        let resolved = config_dir.join(include_path);
        // SECURITY: verify resolved path stays within config dir
        // (canonicalize also resolves symlinks, so a symlink pointing outside
        // the config dir is rejected here).
        let canonical = std::fs::canonicalize(&resolved).map_err(|e| {
            format!(
                "Config include '{}' cannot be resolved: {e}",
                include_path_str
            )
        })?;
        let canonical_dir = std::fs::canonicalize(config_dir)
            .map_err(|e| format!("Config dir cannot be canonicalized: {e}"))?;
        if !canonical.starts_with(&canonical_dir) {
            return Err(format!(
                "Config include '{}' escapes config directory",
                include_path_str
            ));
        }

        // SECURITY: circular detection — insert() returns false if the
        // canonical path was already loaded somewhere up the chain.
        if !visited.insert(canonical.clone()) {
            return Err(format!(
                "Circular config include detected: {include_path_str}"
            ));
        }

        info!(include = %include_path_str, "Loading config include");

        let contents = std::fs::read_to_string(&canonical)
            .map_err(|e| format!("Failed to read config include '{}': {e}", include_path_str))?;
        let mut include_value: toml::Value = toml::from_str(&contents)
            .map_err(|e| format!("Failed to parse config include '{}': {e}", include_path_str))?;

        // Recursively resolve includes in the included file
        // (relative to the included file's own directory).
        let include_dir = canonical.parent().unwrap_or(config_dir).to_path_buf();
        resolve_config_includes(&mut include_value, &include_dir, visited, depth + 1)?;

        // Remove include field from the included file
        if let toml::Value::Table(ref mut tbl) = include_value {
            tbl.remove("include");
        }

        // Deep merge: include overrides the base built so far
        deep_merge_toml(&mut merged_base, &include_value);
    }

    // Now deep merge: root overrides the merged includes
    // Save root's current values (minus include), then merge root on top
    let root_without_include = {
        let mut v = root_value.clone();
        if let toml::Value::Table(ref mut tbl) = v {
            tbl.remove("include");
        }
        v
    };
    deep_merge_toml(&mut merged_base, &root_without_include);
    *root_value = merged_base;

    Ok(())
}
|
||||
|
||||
/// Deep-merge two TOML values. `overlay` values override `base` values.
|
||||
/// For tables, recursively merge. For everything else, overlay wins.
|
||||
pub fn deep_merge_toml(base: &mut toml::Value, overlay: &toml::Value) {
|
||||
match (base, overlay) {
|
||||
(toml::Value::Table(base_tbl), toml::Value::Table(overlay_tbl)) => {
|
||||
for (key, overlay_val) in overlay_tbl {
|
||||
if let Some(base_val) = base_tbl.get_mut(key) {
|
||||
deep_merge_toml(base_val, overlay_val);
|
||||
} else {
|
||||
base_tbl.insert(key.clone(), overlay_val.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
(base, overlay) => {
|
||||
*base = overlay.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default config file path.
|
||||
pub fn default_config_path() -> PathBuf {
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(std::env::temp_dir)
|
||||
.join(".openfang")
|
||||
.join("config.toml")
|
||||
}
|
||||
|
||||
/// Get the default OpenFang home directory.
|
||||
pub fn openfang_home() -> PathBuf {
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(std::env::temp_dir)
|
||||
.join(".openfang")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
|
||||
#[test]
|
||||
fn test_load_config_defaults() {
|
||||
let config = load_config(None);
|
||||
assert_eq!(config.log_level, "info");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_config_missing_file() {
|
||||
let config = load_config(Some(Path::new("/nonexistent/config.toml")));
|
||||
assert_eq!(config.log_level, "info");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deep_merge_simple() {
|
||||
let mut base: toml::Value = toml::from_str(
|
||||
r#"
|
||||
log_level = "debug"
|
||||
api_listen = "0.0.0.0:4200"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let overlay: toml::Value = toml::from_str(
|
||||
r#"
|
||||
log_level = "info"
|
||||
network_enabled = true
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
deep_merge_toml(&mut base, &overlay);
|
||||
assert_eq!(base["log_level"].as_str(), Some("info"));
|
||||
assert_eq!(base["api_listen"].as_str(), Some("0.0.0.0:4200"));
|
||||
assert_eq!(base["network_enabled"].as_bool(), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deep_merge_nested_tables() {
|
||||
let mut base: toml::Value = toml::from_str(
|
||||
r#"
|
||||
[memory]
|
||||
decay_rate = 0.1
|
||||
consolidation_threshold = 10000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let overlay: toml::Value = toml::from_str(
|
||||
r#"
|
||||
[memory]
|
||||
decay_rate = 0.5
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
deep_merge_toml(&mut base, &overlay);
|
||||
let mem = base["memory"].as_table().unwrap();
|
||||
assert_eq!(mem["decay_rate"].as_float(), Some(0.5));
|
||||
assert_eq!(mem["consolidation_threshold"].as_integer(), Some(10000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_include() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let base_path = dir.path().join("base.toml");
|
||||
let root_path = dir.path().join("config.toml");
|
||||
|
||||
// Base config
|
||||
let mut f = std::fs::File::create(&base_path).unwrap();
|
||||
writeln!(f, "log_level = \"debug\"").unwrap();
|
||||
writeln!(f, "api_listen = \"0.0.0.0:9999\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
// Root config (includes base, overrides log_level)
|
||||
let mut f = std::fs::File::create(&root_path).unwrap();
|
||||
writeln!(f, "include = [\"base.toml\"]").unwrap();
|
||||
writeln!(f, "log_level = \"warn\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let config = load_config(Some(&root_path));
|
||||
assert_eq!(config.log_level, "warn"); // root overrides
|
||||
assert_eq!(config.api_listen, "0.0.0.0:9999"); // from base
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_include() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let grandchild = dir.path().join("grandchild.toml");
|
||||
let child = dir.path().join("child.toml");
|
||||
let root = dir.path().join("config.toml");
|
||||
|
||||
let mut f = std::fs::File::create(&grandchild).unwrap();
|
||||
writeln!(f, "log_level = \"trace\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let mut f = std::fs::File::create(&child).unwrap();
|
||||
writeln!(f, "include = [\"grandchild.toml\"]").unwrap();
|
||||
writeln!(f, "log_level = \"debug\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let mut f = std::fs::File::create(&root).unwrap();
|
||||
writeln!(f, "include = [\"child.toml\"]").unwrap();
|
||||
writeln!(f, "log_level = \"info\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let config = load_config(Some(&root));
|
||||
assert_eq!(config.log_level, "info"); // root wins
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_include_detected() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let a_path = dir.path().join("a.toml");
|
||||
let b_path = dir.path().join("b.toml");
|
||||
|
||||
let mut f = std::fs::File::create(&a_path).unwrap();
|
||||
writeln!(f, "include = [\"b.toml\"]").unwrap();
|
||||
writeln!(f, "log_level = \"info\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let mut f = std::fs::File::create(&b_path).unwrap();
|
||||
writeln!(f, "include = [\"a.toml\"]").unwrap();
|
||||
drop(f);
|
||||
|
||||
// Should not panic — circular detection triggers, falls back gracefully
|
||||
let config = load_config(Some(&a_path));
|
||||
// Falls back to defaults due to the circular error
|
||||
assert!(!config.log_level.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_traversal_blocked() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let root = dir.path().join("config.toml");
|
||||
|
||||
let mut f = std::fs::File::create(&root).unwrap();
|
||||
writeln!(f, "include = [\"../etc/passwd\"]").unwrap();
|
||||
drop(f);
|
||||
|
||||
// Should not panic — path traversal triggers error, falls back
|
||||
let config = load_config(Some(&root));
|
||||
assert_eq!(config.log_level, "info"); // defaults
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_depth_exceeded() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
|
||||
// Create a chain of 12 files (exceeds MAX_INCLUDE_DEPTH=10)
|
||||
for i in (0..12).rev() {
|
||||
let name = format!("level{i}.toml");
|
||||
let path = dir.path().join(&name);
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
if i < 11 {
|
||||
let next = format!("level{}.toml", i + 1);
|
||||
writeln!(f, "include = [\"{next}\"]").unwrap();
|
||||
}
|
||||
writeln!(f, "log_level = \"level{i}\"").unwrap();
|
||||
drop(f);
|
||||
}
|
||||
|
||||
let root = dir.path().join("level0.toml");
|
||||
let config = load_config(Some(&root));
|
||||
// Falls back due to depth limit — but should not panic
|
||||
assert!(!config.log_level.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absolute_path_rejected() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let root = dir.path().join("config.toml");
|
||||
|
||||
let mut f = std::fs::File::create(&root).unwrap();
|
||||
writeln!(f, "include = [\"/etc/shadow\"]").unwrap();
|
||||
drop(f);
|
||||
|
||||
let config = load_config(Some(&root));
|
||||
assert_eq!(config.log_level, "info"); // defaults
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_includes_works() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let root = dir.path().join("config.toml");
|
||||
|
||||
let mut f = std::fs::File::create(&root).unwrap();
|
||||
writeln!(f, "log_level = \"trace\"").unwrap();
|
||||
drop(f);
|
||||
|
||||
let config = load_config(Some(&root));
|
||||
assert_eq!(config.log_level, "trace");
|
||||
}
|
||||
}
|
||||
674
crates/openfang-kernel/src/config_reload.rs
Normal file
674
crates/openfang-kernel/src/config_reload.rs
Normal file
@@ -0,0 +1,674 @@
|
||||
//! Config hot-reload — diffs two `KernelConfig` instances and produces a `ReloadPlan`.
|
||||
//!
|
||||
//! **Hot-reload safe**: channels, skills, usage footer, web config, browser,
|
||||
//! approval policy, cron settings, webhook triggers, extensions.
|
||||
//!
|
||||
//! **No-op** (informational only): log_level, language, mode.
|
||||
//!
|
||||
//! **Restart required**: api_listen, api_key, network, memory, default_model.
|
||||
|
||||
use openfang_types::config::{KernelConfig, ReloadMode};
|
||||
use tracing::{info, warn};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HotAction — what can be changed at runtime without restart
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// An individual action that can be applied at runtime (hot-reload).
///
/// Variants are collected into [`ReloadPlan::hot_actions`] by
/// [`build_reload_plan`] when the corresponding config section differs
/// between the old and new configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HotAction {
    /// Channel configuration changed — reload channel bridges.
    ReloadChannels,
    /// Skill configuration changed — reload skill registry.
    ///
    /// NOTE(review): this variant is never produced by `build_reload_plan`
    /// in this file — confirm whether a skills-config diff is missing there.
    ReloadSkills,
    /// Usage footer mode changed.
    UpdateUsageFooter,
    /// Web config changed — rebuild web tools context.
    ReloadWebConfig,
    /// Browser config changed.
    ReloadBrowserConfig,
    /// Approval policy changed.
    UpdateApprovalPolicy,
    /// Cron max jobs changed.
    UpdateCronConfig,
    /// Webhook trigger config changed.
    UpdateWebhookConfig,
    /// Extension config changed.
    ReloadExtensions,
    /// MCP server list changed — reconnect MCP clients.
    ReloadMcpServers,
    /// A2A config changed.
    ReloadA2aConfig,
    /// Fallback provider chain changed.
    ReloadFallbackProviders,
    /// Provider base URL overrides changed.
    ReloadProviderUrls,
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ReloadPlan — the output of diffing two configs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A categorized plan for applying config changes.
///
/// After building a plan via [`build_reload_plan`], callers inspect
/// `restart_required` to decide whether a full restart is needed or
/// the `hot_actions` can be applied in-place.
#[derive(Debug, Clone)]
pub struct ReloadPlan {
    /// Whether a full restart is needed.
    pub restart_required: bool,
    /// Human-readable reasons why restart is required.
    pub restart_reasons: Vec<String>,
    /// Actions that can be hot-reloaded without restart.
    ///
    /// May be non-empty even when `restart_required` is set: the actions
    /// then describe what will need re-initialization after the restart
    /// (see [`build_reload_plan`]).
    pub hot_actions: Vec<HotAction>,
    /// Fields that changed but are no-ops (informational only).
    pub noop_changes: Vec<String>,
}
|
||||
|
||||
impl ReloadPlan {
|
||||
/// Whether any changes were detected at all.
|
||||
pub fn has_changes(&self) -> bool {
|
||||
self.restart_required || !self.hot_actions.is_empty() || !self.noop_changes.is_empty()
|
||||
}
|
||||
|
||||
/// Whether the plan can be applied without restart.
|
||||
pub fn is_hot_reloadable(&self) -> bool {
|
||||
!self.restart_required
|
||||
}
|
||||
|
||||
/// Log a human-readable summary of the plan.
|
||||
pub fn log_summary(&self) {
|
||||
if !self.has_changes() {
|
||||
info!("config reload: no changes detected");
|
||||
return;
|
||||
}
|
||||
if self.restart_required {
|
||||
warn!(
|
||||
"config reload: restart required — {}",
|
||||
self.restart_reasons.join("; ")
|
||||
);
|
||||
}
|
||||
for action in &self.hot_actions {
|
||||
info!("config reload: hot-reload action queued — {action:?}");
|
||||
}
|
||||
for noop in &self.noop_changes {
|
||||
info!("config reload: no-op change — {noop}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// build_reload_plan
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compare JSON-serialized forms of a field. Returns `true` when the
|
||||
/// serialized representations differ (or if one side fails to serialize).
|
||||
fn field_changed<T: serde::Serialize>(old: &T, new: &T) -> bool {
|
||||
let old_json = serde_json::to_string(old).ok();
|
||||
let new_json = serde_json::to_string(new).ok();
|
||||
old_json != new_json
|
||||
}
|
||||
|
||||
/// Diff two configurations and produce a reload plan.
///
/// The plan categorizes every detected change into one of three buckets:
///
/// 1. **restart_required** — the change touches something that cannot be
///    patched at runtime (e.g. the listen address or database path).
/// 2. **hot_actions** — the change can be applied without restarting.
/// 3. **noop_changes** — the change is informational; no action needed.
///
/// Scalar fields are compared with `!=`; structured sections are compared
/// via [`field_changed`] (JSON serialization equality).
pub fn build_reload_plan(old: &KernelConfig, new: &KernelConfig) -> ReloadPlan {
    // Start from an empty plan and accumulate findings below.
    let mut plan = ReloadPlan {
        restart_required: false,
        restart_reasons: Vec::new(),
        hot_actions: Vec::new(),
        noop_changes: Vec::new(),
    };

    // ----- Restart-required fields -----

    if old.api_listen != new.api_listen {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "api_listen changed: {} -> {}",
            old.api_listen, new.api_listen
        ));
    }

    if old.api_key != new.api_key {
        plan.restart_required = true;
        plan.restart_reasons.push("api_key changed".to_string());
    }

    if old.network_enabled != new.network_enabled {
        plan.restart_required = true;
        plan.restart_reasons
            .push("network_enabled changed".to_string());
    }

    // Network config (shared_secret, listen_addresses, etc.)
    if field_changed(&old.network, &new.network) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("network config changed".to_string());
    }

    // Memory config (requires restarting SQLite connections)
    if field_changed(&old.memory, &new.memory) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("memory config changed".to_string());
    }

    // Default model (driver needs recreation)
    if field_changed(&old.default_model, &new.default_model) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("default_model changed".to_string());
    }

    // Home/data directory changes
    if old.home_dir != new.home_dir {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "home_dir changed: {:?} -> {:?}",
            old.home_dir, new.home_dir
        ));
    }
    if old.data_dir != new.data_dir {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "data_dir changed: {:?} -> {:?}",
            old.data_dir, new.data_dir
        ));
    }

    // Vault config (encryption key derivation)
    if field_changed(&old.vault, &new.vault) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("vault config changed".to_string());
    }

    // ----- Hot-reloadable fields -----
    // These are collected even when restart_required is already set, so the
    // caller knows what will need re-initialization after the restart.
    //
    // NOTE(review): `HotAction::ReloadSkills` is declared but never emitted
    // here — confirm whether a skills-config diff is missing.

    if field_changed(&old.channels, &new.channels) {
        plan.hot_actions.push(HotAction::ReloadChannels);
    }

    if old.usage_footer != new.usage_footer {
        plan.hot_actions.push(HotAction::UpdateUsageFooter);
    }

    if field_changed(&old.web, &new.web) {
        plan.hot_actions.push(HotAction::ReloadWebConfig);
    }

    if field_changed(&old.browser, &new.browser) {
        plan.hot_actions.push(HotAction::ReloadBrowserConfig);
    }

    if field_changed(&old.approval, &new.approval) {
        plan.hot_actions.push(HotAction::UpdateApprovalPolicy);
    }

    if old.max_cron_jobs != new.max_cron_jobs {
        plan.hot_actions.push(HotAction::UpdateCronConfig);
    }

    if field_changed(&old.webhook_triggers, &new.webhook_triggers) {
        plan.hot_actions.push(HotAction::UpdateWebhookConfig);
    }

    if field_changed(&old.extensions, &new.extensions) {
        plan.hot_actions.push(HotAction::ReloadExtensions);
    }

    if field_changed(&old.mcp_servers, &new.mcp_servers) {
        plan.hot_actions.push(HotAction::ReloadMcpServers);
    }

    if field_changed(&old.a2a, &new.a2a) {
        plan.hot_actions.push(HotAction::ReloadA2aConfig);
    }

    if field_changed(&old.fallback_providers, &new.fallback_providers) {
        plan.hot_actions.push(HotAction::ReloadFallbackProviders);
    }

    if field_changed(&old.provider_urls, &new.provider_urls) {
        plan.hot_actions.push(HotAction::ReloadProviderUrls);
    }

    // ----- No-op fields (informational; recorded but never acted on) -----

    if old.log_level != new.log_level {
        plan.noop_changes
            .push(format!("log_level: {} -> {}", old.log_level, new.log_level));
    }

    if old.language != new.language {
        plan.noop_changes
            .push(format!("language: {} -> {}", old.language, new.language));
    }

    if old.mode != new.mode {
        plan.noop_changes
            .push(format!("mode: {:?} -> {:?}", old.mode, new.mode));
    }

    plan
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// validate_config_for_reload
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate a new config before applying it.
|
||||
///
|
||||
/// Returns `Ok(())` if the config passes basic sanity checks, or `Err` with
|
||||
/// a list of human-readable error messages.
|
||||
pub fn validate_config_for_reload(config: &KernelConfig) -> Result<(), Vec<String>> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
if config.api_listen.is_empty() {
|
||||
errors.push("api_listen cannot be empty".to_string());
|
||||
}
|
||||
|
||||
if config.max_cron_jobs > 10_000 {
|
||||
errors.push("max_cron_jobs exceeds reasonable limit (10000)".to_string());
|
||||
}
|
||||
|
||||
// Validate approval policy
|
||||
if let Err(e) = config.approval.validate() {
|
||||
errors.push(format!("approval policy: {e}"));
|
||||
}
|
||||
|
||||
// Network config: if network is enabled, shared_secret must be set
|
||||
if config.network_enabled && config.network.shared_secret.is_empty() {
|
||||
errors.push("network_enabled is true but network.shared_secret is empty".to_string());
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// should_reload — convenience helper for the reload mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Given the configured [`ReloadMode`] and a [`ReloadPlan`], decide whether
|
||||
/// the caller should apply hot actions.
|
||||
///
|
||||
/// Returns `true` if hot-reload actions should be applied.
|
||||
pub fn should_apply_hot(mode: ReloadMode, plan: &ReloadPlan) -> bool {
|
||||
match mode {
|
||||
ReloadMode::Off => false,
|
||||
ReloadMode::Restart => false, // caller must do a full restart
|
||||
ReloadMode::Hot => !plan.hot_actions.is_empty(),
|
||||
ReloadMode::Hybrid => !plan.hot_actions.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Tests
|
||||
// ===========================================================================
|
||||
|
||||
// Pure in-memory tests: they diff `KernelConfig` values only — no I/O,
// no filesystem fixtures.
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::config::KernelConfig;

    /// Helper: create a default config for diffing.
    fn default_cfg() -> KernelConfig {
        KernelConfig::default()
    }

    // -----------------------------------------------------------------------
    // Plan detection tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_no_changes_detected() {
        let a = default_cfg();
        let b = default_cfg();
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.has_changes());
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.is_empty());
        assert!(plan.noop_changes.is_empty());
    }

    #[test]
    fn test_api_listen_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.api_listen = "0.0.0.0:8080".to_string();
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan
            .restart_reasons
            .iter()
            .any(|r| r.contains("api_listen")));
    }

    #[test]
    fn test_api_key_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.api_key = "super-secret-key".to_string();
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan.restart_reasons.iter().any(|r| r.contains("api_key")));
    }

    #[test]
    fn test_network_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.network_enabled = true;
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan
            .restart_reasons
            .iter()
            .any(|r| r.contains("network_enabled")));
    }

    #[test]
    fn test_network_config_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.network.shared_secret = "new-secret".to_string();
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan
            .restart_reasons
            .iter()
            .any(|r| r.contains("network config")));
    }

    #[test]
    fn test_memory_config_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.memory.consolidation_threshold = 99_999;
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan
            .restart_reasons
            .iter()
            .any(|r| r.contains("memory config")));
    }

    #[test]
    fn test_default_model_requires_restart() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.default_model.model = "gpt-4".to_string();
        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan
            .restart_reasons
            .iter()
            .any(|r| r.contains("default_model")));
    }

    // -----------------------------------------------------------------------
    // Hot-reload tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_channels_hot_reload() {
        let a = default_cfg();
        let mut b = default_cfg();
        // Change the channels config by adding a Telegram config
        b.channels.telegram = Some(openfang_types::config::TelegramConfig {
            bot_token_env: "TG_TOKEN".to_string(),
            ..Default::default()
        });
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.contains(&HotAction::ReloadChannels));
    }

    #[test]
    fn test_usage_footer_hot_reload() {
        use openfang_types::config::UsageFooterMode;
        let a = default_cfg();
        let mut b = default_cfg();
        b.usage_footer = UsageFooterMode::Off;
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.contains(&HotAction::UpdateUsageFooter));
    }

    #[test]
    fn test_max_cron_jobs_hot_reload() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.max_cron_jobs = 1000;
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.contains(&HotAction::UpdateCronConfig));
    }

    #[test]
    fn test_extensions_hot_reload() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.extensions.reconnect_max_attempts = 20;
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.contains(&HotAction::ReloadExtensions));
    }

    #[test]
    fn test_provider_urls_hot_reload() {
        let a = default_cfg();
        let mut b = default_cfg();
        b.provider_urls
            .insert("ollama".to_string(), "http://10.0.0.5:11434/v1".to_string());
        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.contains(&HotAction::ReloadProviderUrls));
    }

    // -----------------------------------------------------------------------
    // Mixed changes
    // -----------------------------------------------------------------------

    #[test]
    fn test_mixed_changes() {
        use openfang_types::config::UsageFooterMode;
        let a = default_cfg();
        let mut b = default_cfg();
        // Restart-required
        b.api_listen = "0.0.0.0:9999".to_string();
        // Hot-reloadable
        b.usage_footer = UsageFooterMode::Tokens;
        b.max_cron_jobs = 100;
        // No-op
        b.log_level = "debug".to_string();

        let plan = build_reload_plan(&a, &b);
        assert!(plan.restart_required);
        assert!(plan.has_changes());
        // Hot actions are still collected even if restart is required,
        // so the caller knows what will need re-initialization after restart.
        assert!(plan.hot_actions.contains(&HotAction::UpdateUsageFooter));
        assert!(plan.hot_actions.contains(&HotAction::UpdateCronConfig));
        assert!(plan.noop_changes.iter().any(|c| c.contains("log_level")));
    }

    // -----------------------------------------------------------------------
    // No-op changes
    // -----------------------------------------------------------------------

    #[test]
    fn test_noop_changes() {
        use openfang_types::config::KernelMode;
        let a = default_cfg();
        let mut b = default_cfg();
        b.log_level = "debug".to_string();
        b.language = "de".to_string();
        b.mode = KernelMode::Dev;

        let plan = build_reload_plan(&a, &b);
        assert!(!plan.restart_required);
        assert!(plan.hot_actions.is_empty());
        assert_eq!(plan.noop_changes.len(), 3);
        assert!(plan.noop_changes.iter().any(|c| c.contains("log_level")));
        assert!(plan.noop_changes.iter().any(|c| c.contains("language")));
        assert!(plan.noop_changes.iter().any(|c| c.contains("mode")));
    }

    // -----------------------------------------------------------------------
    // has_changes / is_hot_reloadable helpers
    // -----------------------------------------------------------------------

    #[test]
    fn test_has_changes() {
        // No changes
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![],
            noop_changes: vec![],
        };
        assert!(!plan.has_changes());

        // Only noop
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![],
            noop_changes: vec!["log_level: info -> debug".to_string()],
        };
        assert!(plan.has_changes());

        // Only hot
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![HotAction::UpdateCronConfig],
            noop_changes: vec![],
        };
        assert!(plan.has_changes());

        // Only restart
        let plan = ReloadPlan {
            restart_required: true,
            restart_reasons: vec!["api_listen changed".to_string()],
            hot_actions: vec![],
            noop_changes: vec![],
        };
        assert!(plan.has_changes());
    }

    #[test]
    fn test_is_hot_reloadable() {
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![HotAction::ReloadChannels],
            noop_changes: vec![],
        };
        assert!(plan.is_hot_reloadable());

        let plan = ReloadPlan {
            restart_required: true,
            restart_reasons: vec!["api_listen changed".to_string()],
            hot_actions: vec![HotAction::ReloadChannels],
            noop_changes: vec![],
        };
        assert!(!plan.is_hot_reloadable());
    }

    // -----------------------------------------------------------------------
    // Validation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_config_for_reload_valid() {
        let config = default_cfg();
        assert!(validate_config_for_reload(&config).is_ok());
    }

    #[test]
    fn test_validate_config_for_reload_invalid() {
        // Empty api_listen
        let mut config = default_cfg();
        config.api_listen = String::new();
        let err = validate_config_for_reload(&config).unwrap_err();
        assert!(err.iter().any(|e| e.contains("api_listen")));

        // Excessive max_cron_jobs
        let mut config = default_cfg();
        config.max_cron_jobs = 100_000;
        let err = validate_config_for_reload(&config).unwrap_err();
        assert!(err.iter().any(|e| e.contains("max_cron_jobs")));
    }

    #[test]
    fn test_validate_network_enabled_no_secret() {
        let mut config = default_cfg();
        config.network_enabled = true;
        config.network.shared_secret = String::new();
        let err = validate_config_for_reload(&config).unwrap_err();
        assert!(err.iter().any(|e| e.contains("shared_secret")));
    }

    // -----------------------------------------------------------------------
    // should_apply_hot
    // -----------------------------------------------------------------------

    #[test]
    fn test_should_apply_hot_off() {
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![HotAction::ReloadChannels],
            noop_changes: vec![],
        };
        assert!(!should_apply_hot(ReloadMode::Off, &plan));
    }

    #[test]
    fn test_should_apply_hot_restart_mode() {
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![HotAction::ReloadChannels],
            noop_changes: vec![],
        };
        assert!(!should_apply_hot(ReloadMode::Restart, &plan));
    }

    #[test]
    fn test_should_apply_hot_hybrid() {
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![HotAction::ReloadChannels],
            noop_changes: vec![],
        };
        assert!(should_apply_hot(ReloadMode::Hybrid, &plan));
        assert!(should_apply_hot(ReloadMode::Hot, &plan));
    }

    #[test]
    fn test_should_apply_hot_empty() {
        let plan = ReloadPlan {
            restart_required: false,
            restart_reasons: vec![],
            hot_actions: vec![],
            noop_changes: vec![],
        };
        assert!(!should_apply_hot(ReloadMode::Hybrid, &plan));
    }
}
|
||||
790
crates/openfang-kernel/src/cron.rs
Normal file
790
crates/openfang-kernel/src/cron.rs
Normal file
@@ -0,0 +1,790 @@
|
||||
//! Cron job scheduler engine for the OpenFang kernel.
|
||||
//!
|
||||
//! Manages scheduled jobs (recurring and one-shot) across all agents.
|
||||
//! This is separate from `scheduler.rs` which handles agent resource tracking.
|
||||
//!
|
||||
//! The scheduler stores jobs in a `DashMap` for concurrent access, persists
|
||||
//! them to a JSON file on disk, and exposes methods for the kernel tick loop
|
||||
//! to query due jobs and record outcomes.
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::error::{OpenFangError, OpenFangResult};
|
||||
use openfang_types::scheduler::{CronJob, CronJobId, CronSchedule};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Maximum consecutive errors before a job is auto-disabled.
///
/// NOTE(review): enforcement is not visible in this chunk — presumably it
/// lives in the failure-recording path; confirm against `record_failure`.
const MAX_CONSECUTIVE_ERRORS: u32 = 5;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JobMeta — extra runtime state not stored in CronJob itself
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Runtime metadata for a cron job that extends the base `CronJob` type.
///
/// The `CronJob` struct in `openfang-types` is intentionally lean (no
/// `one_shot`, `last_status`, or error tracking). The scheduler tracks
/// these operational details separately, and persists them alongside the
/// job itself (the whole `JobMeta` is serialized to disk).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobMeta {
    /// The underlying job definition.
    pub job: CronJob,
    /// Whether this job should be removed after a single successful execution.
    pub one_shot: bool,
    /// Human-readable status of the last execution (e.g. `"ok"` or `"error: ..."`).
    /// `None` until the job has run at least once.
    pub last_status: Option<String>,
    /// Number of consecutive failed executions.
    pub consecutive_errors: u32,
}
|
||||
|
||||
impl JobMeta {
|
||||
/// Wrap a `CronJob` with default metadata.
|
||||
pub fn new(job: CronJob, one_shot: bool) -> Self {
|
||||
Self {
|
||||
job,
|
||||
one_shot,
|
||||
last_status: None,
|
||||
consecutive_errors: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CronScheduler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Cron job scheduler — manages scheduled jobs for all agents.
///
/// Thread-safe via `DashMap`. The kernel should call [`due_jobs`] on a
/// regular interval (e.g. every 10-30 seconds) to discover jobs that need
/// to fire, then call [`record_success`] or [`record_failure`] after
/// execution completes.
pub struct CronScheduler {
    /// All tracked jobs, keyed by their unique ID.
    jobs: DashMap<CronJobId, JobMeta>,
    /// Path to the persistence file (`<home>/cron_jobs.json`).
    persist_path: PathBuf,
    /// Global cap on total jobs across all agents (atomic for hot-reload,
    /// so `set_max_total_jobs` can update it through `&self`).
    max_total_jobs: AtomicUsize,
}
|
||||
|
||||
impl CronScheduler {
    /// Create a new scheduler.
    ///
    /// `home_dir` is the OpenFang data directory; jobs are persisted to
    /// `<home_dir>/cron_jobs.json`. `max_total_jobs` caps the total number
    /// of jobs across all agents.
    ///
    /// The scheduler starts empty; call [`load`] to restore persisted jobs.
    pub fn new(home_dir: &Path, max_total_jobs: usize) -> Self {
        Self {
            jobs: DashMap::new(),
            persist_path: home_dir.join("cron_jobs.json"),
            max_total_jobs: AtomicUsize::new(max_total_jobs),
        }
    }

    /// Update the max total jobs limit (for hot-reload).
    ///
    /// Lowering the limit does not evict existing jobs; it only affects
    /// subsequent [`add_job`] calls.
    pub fn set_max_total_jobs(&self, new_max: usize) {
        self.max_total_jobs.store(new_max, Ordering::Relaxed);
    }

    // -- Persistence --------------------------------------------------------

    /// Load persisted jobs from disk.
    ///
    /// Returns the number of jobs loaded. If the persistence file does not
    /// exist, returns `Ok(0)` without error.
    ///
    /// Jobs are inserted keyed by ID, so an in-memory job with the same ID
    /// is overwritten. NOTE(review): no global/per-agent limit check is
    /// applied here — presumably the file is trusted because `persist()`
    /// wrote it; confirm that assumption.
    pub fn load(&self) -> OpenFangResult<usize> {
        if !self.persist_path.exists() {
            return Ok(0);
        }
        let data = std::fs::read_to_string(&self.persist_path)
            .map_err(|e| OpenFangError::Internal(format!("Failed to read cron jobs: {e}")))?;
        let metas: Vec<JobMeta> = serde_json::from_str(&data)
            .map_err(|e| OpenFangError::Internal(format!("Failed to parse cron jobs: {e}")))?;
        let count = metas.len();
        for meta in metas {
            self.jobs.insert(meta.job.id, meta);
        }
        info!(count, "Loaded cron jobs from disk");
        Ok(count)
    }

    /// Persist all jobs to disk via atomic write (write to `.tmp`, then rename).
    ///
    /// The snapshot is taken up front, so jobs added or removed while this
    /// runs may or may not be included.
    pub fn persist(&self) -> OpenFangResult<()> {
        let metas: Vec<JobMeta> = self.jobs.iter().map(|r| r.value().clone()).collect();
        let data = serde_json::to_string_pretty(&metas)
            .map_err(|e| OpenFangError::Internal(format!("Failed to serialize cron jobs: {e}")))?;
        // Atomic-replace: write the full payload to a sibling temp file
        // (`cron_jobs.json.tmp`), then rename over the real file so readers
        // never observe a partially written file.
        let tmp_path = self.persist_path.with_extension("json.tmp");
        std::fs::write(&tmp_path, data.as_bytes()).map_err(|e| {
            OpenFangError::Internal(format!("Failed to write cron jobs temp file: {e}"))
        })?;
        std::fs::rename(&tmp_path, &self.persist_path).map_err(|e| {
            OpenFangError::Internal(format!("Failed to rename cron jobs file: {e}"))
        })?;
        debug!(count = metas.len(), "Persisted cron jobs");
        Ok(())
    }

    // -- CRUD ---------------------------------------------------------------

    /// Add a new job. Validates fields, computes the initial `next_run`,
    /// and inserts it into the scheduler.
    ///
    /// `one_shot` controls whether the job is removed after a single
    /// successful execution.
    ///
    /// NOTE(review): the limit check and the insert are separate DashMap
    /// operations, so two concurrent `add_job` calls can both pass the check
    /// and briefly exceed the cap by one — confirm this is acceptable.
    pub fn add_job(&self, mut job: CronJob, one_shot: bool) -> OpenFangResult<CronJobId> {
        // Global limit
        let max_jobs = self.max_total_jobs.load(Ordering::Relaxed);
        if self.jobs.len() >= max_jobs {
            return Err(OpenFangError::Internal(format!(
                "Global cron job limit reached ({})",
                max_jobs
            )));
        }

        // Per-agent count — the per-agent cap itself is enforced inside
        // `CronJob::validate`, which receives this count.
        let agent_count = self
            .jobs
            .iter()
            .filter(|r| r.value().job.agent_id == job.agent_id)
            .count();

        // CronJob.validate returns Result<(), String>
        job.validate(agent_count)
            .map_err(OpenFangError::InvalidInput)?;

        // Compute initial next_run
        job.next_run = Some(compute_next_run(&job.schedule));

        let id = job.id;
        self.jobs.insert(id, JobMeta::new(job, one_shot));
        Ok(id)
    }

    /// Remove a job by ID. Returns the removed `CronJob`.
    pub fn remove_job(&self, id: CronJobId) -> OpenFangResult<CronJob> {
        self.jobs
            .remove(&id)
            .map(|(_, meta)| meta.job)
            .ok_or_else(|| OpenFangError::Internal(format!("Cron job {id} not found")))
    }

    /// Enable or disable a job. Re-enabling resets errors and recomputes
    /// `next_run`.
    pub fn set_enabled(&self, id: CronJobId, enabled: bool) -> OpenFangResult<()> {
        match self.jobs.get_mut(&id) {
            Some(mut meta) => {
                meta.job.enabled = enabled;
                if enabled {
                    // Re-enabling gives the job a clean slate so a previously
                    // auto-disabled job is not instantly re-disabled.
                    meta.consecutive_errors = 0;
                    meta.job.next_run = Some(compute_next_run(&meta.job.schedule));
                }
                Ok(())
            }
            None => Err(OpenFangError::Internal(format!("Cron job {id} not found"))),
        }
    }

    // -- Queries ------------------------------------------------------------

    /// Get a single job by ID.
    pub fn get_job(&self, id: CronJobId) -> Option<CronJob> {
        self.jobs.get(&id).map(|r| r.value().job.clone())
    }

    /// Get the full metadata for a job (includes `one_shot`, `last_status`,
    /// `consecutive_errors`).
    pub fn get_meta(&self, id: CronJobId) -> Option<JobMeta> {
        self.jobs.get(&id).map(|r| r.value().clone())
    }

    /// List all jobs for a specific agent.
    pub fn list_jobs(&self, agent_id: AgentId) -> Vec<CronJob> {
        self.jobs
            .iter()
            .filter(|r| r.value().job.agent_id == agent_id)
            .map(|r| r.value().job.clone())
            .collect()
    }

    /// List all jobs across all agents.
    pub fn list_all_jobs(&self) -> Vec<CronJob> {
        self.jobs.iter().map(|r| r.value().job.clone()).collect()
    }

    /// Total number of tracked jobs.
    pub fn total_jobs(&self) -> usize {
        self.jobs.len()
    }

    /// Return jobs whose `next_run` is at or before `now` and are enabled.
    ///
    /// **Important**: This also pre-advances each due job's `next_run` to the
    /// next scheduled time. This prevents the same job from being returned as
    /// "due" on subsequent tick iterations while it's still executing.
    pub fn due_jobs(&self) -> Vec<CronJob> {
        let now = Utc::now();
        let mut due = Vec::new();
        for mut entry in self.jobs.iter_mut() {
            let meta = entry.value_mut();
            // A job with next_run == None is never considered due.
            if meta.job.enabled && meta.job.next_run.map(|t| t <= now).unwrap_or(false) {
                due.push(meta.job.clone());
                // Pre-advance next_run so the job won't fire again on the next
                // tick while it's still executing. Use `now` as the base so the
                // next fire time is computed strictly after the current moment.
                meta.job.next_run = Some(compute_next_run_after(&meta.job.schedule, now));
            }
        }
        due
    }

    // -- Outcome recording --------------------------------------------------

    /// Record a successful execution for a job.
    ///
    /// Updates `last_run`, resets errors, and either removes the job (if
    /// one-shot) or advances `next_run`.
    pub fn record_success(&self, id: CronJobId) {
        // We need to check one_shot first, then potentially remove.
        // The block scope drops the `get_mut` guard before `remove` runs, so
        // `remove` is never called while a mutable entry ref is still held.
        let should_remove = {
            if let Some(mut meta) = self.jobs.get_mut(&id) {
                meta.job.last_run = Some(Utc::now());
                meta.last_status = Some("ok".to_string());
                meta.consecutive_errors = 0;
                // one_shot jobs get removed; recurring jobs keep the next_run
                // already pre-advanced by due_jobs() — no recompute needed.
                meta.one_shot
            } else {
                // Job was removed (or never existed) — nothing to record.
                return;
            }
        };
        if should_remove {
            self.jobs.remove(&id);
        }
    }

    /// Record a failed execution for a job.
    ///
    /// Increments the consecutive error counter. If it reaches
    /// [`MAX_CONSECUTIVE_ERRORS`], the job is automatically disabled.
    pub fn record_failure(&self, id: CronJobId, error_msg: &str) {
        if let Some(mut meta) = self.jobs.get_mut(&id) {
            meta.job.last_run = Some(Utc::now());
            // Keep the stored status bounded by truncating long error text.
            meta.last_status = Some(format!(
                "error: {}",
                openfang_types::truncate_str(error_msg, 256)
            ));
            meta.consecutive_errors += 1;
            if meta.consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
                warn!(
                    job_id = %id,
                    errors = meta.consecutive_errors,
                    "Auto-disabling cron job after repeated failures"
                );
                meta.job.enabled = false;
            } else {
                // Reschedule from the failure time.
                // NOTE(review): for `At` schedules, compute_next_run_after
                // returns the original (possibly past) `at` time, so a failing
                // `At` job becomes due again on the very next tick — confirm
                // that this retry-immediately behavior is intended.
                meta.job.next_run =
                    Some(compute_next_run_after(&meta.job.schedule, Utc::now()));
            }
        }
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// compute_next_run
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute the next fire time for a schedule, based on `now`.
|
||||
///
|
||||
/// - `At { at }` — returns `at` directly.
|
||||
/// - `Every { every_secs }` — returns `now + every_secs`.
|
||||
/// - `Cron { expr, tz }` — parses the cron expression and computes the next
|
||||
/// matching time. Supports standard 5-field (`min hour dom month dow`) and
|
||||
/// 6-field (`sec min hour dom month dow`) formats by converting to the
|
||||
/// 7-field format required by the `cron` crate.
|
||||
pub fn compute_next_run(schedule: &CronSchedule) -> chrono::DateTime<Utc> {
|
||||
compute_next_run_after(schedule, Utc::now())
|
||||
}
|
||||
|
||||
/// Compute the next fire time for a schedule, strictly after `after`.
|
||||
///
|
||||
/// Uses `after + 1 second` as the base time so the `cron` crate's
|
||||
/// inclusive `.after()` always returns a strictly future time. Without
|
||||
/// this offset, calling `compute_next_run` right after a job fires can
|
||||
/// return the same minute (or even the same second), causing the
|
||||
/// scheduler to re-fire immediately.
|
||||
pub fn compute_next_run_after(
|
||||
schedule: &CronSchedule,
|
||||
after: chrono::DateTime<Utc>,
|
||||
) -> chrono::DateTime<Utc> {
|
||||
match schedule {
|
||||
CronSchedule::At { at } => *at,
|
||||
CronSchedule::Every { every_secs } => after + Duration::seconds(*every_secs as i64),
|
||||
CronSchedule::Cron { expr, tz: _ } => {
|
||||
// Convert standard 5/6-field cron to 7-field for the `cron` crate.
|
||||
// Standard 5-field: min hour dom month dow
|
||||
// 6-field: sec min hour dom month dow
|
||||
// cron crate: sec min hour dom month dow year
|
||||
let trimmed = expr.trim();
|
||||
let fields: Vec<&str> = trimmed.split_whitespace().collect();
|
||||
let seven_field = match fields.len() {
|
||||
5 => format!("0 {trimmed} *"),
|
||||
6 => format!("{trimmed} *"),
|
||||
_ => expr.clone(),
|
||||
};
|
||||
|
||||
// Add 1 second so `.after()` (inclusive) skips the current second.
|
||||
let base = after + Duration::seconds(1);
|
||||
|
||||
match seven_field.parse::<cron::Schedule>() {
|
||||
Ok(sched) => sched
|
||||
.after(&base)
|
||||
.next()
|
||||
.unwrap_or_else(|| after + Duration::hours(1)),
|
||||
Err(e) => {
|
||||
warn!("Failed to parse cron expression '{}': {}", expr, e);
|
||||
after + Duration::hours(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use openfang_types::scheduler::{CronAction, CronDelivery};

    /// Build a minimal valid `CronJob` with an `Every` schedule.
    fn make_job(agent_id: AgentId) -> CronJob {
        CronJob {
            id: CronJobId::new(),
            agent_id,
            name: "test-job".into(),
            enabled: true,
            schedule: CronSchedule::Every { every_secs: 3600 },
            action: CronAction::SystemEvent {
                text: "ping".into(),
            },
            delivery: CronDelivery::None,
            created_at: Utc::now(),
            last_run: None,
            next_run: None,
        }
    }

    /// Create a scheduler backed by a temp directory.
    ///
    /// The `TempDir` is returned alongside the scheduler so it stays alive
    /// (and the directory isn't deleted) for the duration of the test.
    fn make_scheduler(max_total: usize) -> (CronScheduler, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let sched = CronScheduler::new(tmp.path(), max_total);
        (sched, tmp)
    }

    // -- test_add_job_and_list ----------------------------------------------

    #[test]
    fn test_add_job_and_list() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);

        let id = sched.add_job(job, false).unwrap();

        // Should appear in agent list
        let jobs = sched.list_jobs(agent);
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].id, id);
        assert_eq!(jobs[0].name, "test-job");

        // Should appear in global list
        let all = sched.list_all_jobs();
        assert_eq!(all.len(), 1);

        // get_job should return it
        let fetched = sched.get_job(id).unwrap();
        assert_eq!(fetched.agent_id, agent);

        // next_run should have been computed
        assert!(fetched.next_run.is_some());
        assert_eq!(sched.total_jobs(), 1);
    }

    // -- test_remove_job ----------------------------------------------------

    #[test]
    fn test_remove_job() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();

        let removed = sched.remove_job(id).unwrap();
        assert_eq!(removed.name, "test-job");
        assert_eq!(sched.total_jobs(), 0);

        // Removing again should fail
        assert!(sched.remove_job(id).is_err());
    }

    // -- test_add_job_global_limit ------------------------------------------

    #[test]
    fn test_add_job_global_limit() {
        let (sched, _tmp) = make_scheduler(2);
        let agent = AgentId::new();

        let j1 = make_job(agent);
        let j2 = make_job(agent);
        let j3 = make_job(agent);

        sched.add_job(j1, false).unwrap();
        sched.add_job(j2, false).unwrap();

        // Third should hit global limit
        let err = sched.add_job(j3, false).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("limit"),
            "Expected global limit error, got: {msg}"
        );
    }

    // -- test_add_job_per_agent_limit ---------------------------------------

    #[test]
    fn test_add_job_per_agent_limit() {
        // MAX_JOBS_PER_AGENT = 50 in openfang-types
        let (sched, _tmp) = make_scheduler(1000);
        let agent = AgentId::new();

        for i in 0..50 {
            let mut job = make_job(agent);
            job.name = format!("job-{i}");
            sched.add_job(job, false).unwrap();
        }

        // 51st should be rejected by validate()
        let mut overflow = make_job(agent);
        overflow.name = "overflow".into();
        let err = sched.add_job(overflow, false).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("50"),
            "Expected per-agent limit error, got: {msg}"
        );
    }

    // -- test_record_success_removes_one_shot --------------------------------

    #[test]
    fn test_record_success_removes_one_shot() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, true).unwrap(); // one_shot = true

        assert_eq!(sched.total_jobs(), 1);

        sched.record_success(id);

        // One-shot job should have been removed
        assert_eq!(sched.total_jobs(), 0);
        assert!(sched.get_job(id).is_none());
    }

    #[test]
    fn test_record_success_keeps_recurring() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap(); // one_shot = false

        sched.record_success(id);

        // Recurring job should still be there, with a clean status
        assert_eq!(sched.total_jobs(), 1);
        let meta = sched.get_meta(id).unwrap();
        assert_eq!(meta.last_status.as_deref(), Some("ok"));
        assert_eq!(meta.consecutive_errors, 0);
        assert!(meta.job.last_run.is_some());
    }

    // -- test_record_failure_auto_disable -----------------------------------

    #[test]
    fn test_record_failure_auto_disable() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();

        // Fail MAX_CONSECUTIVE_ERRORS - 1 times: should still be enabled
        for i in 0..(MAX_CONSECUTIVE_ERRORS - 1) {
            sched.record_failure(id, &format!("error {i}"));
            let meta = sched.get_meta(id).unwrap();
            assert!(
                meta.job.enabled,
                "Job should still be enabled after {} failures",
                i + 1
            );
            assert_eq!(meta.consecutive_errors, i + 1);
        }

        // One more failure should auto-disable
        sched.record_failure(id, "final error");
        let meta = sched.get_meta(id).unwrap();
        assert!(
            !meta.job.enabled,
            "Job should be auto-disabled after {MAX_CONSECUTIVE_ERRORS} failures"
        );
        assert_eq!(meta.consecutive_errors, MAX_CONSECUTIVE_ERRORS);
        assert!(
            meta.last_status.as_ref().unwrap().starts_with("error:"),
            "last_status should record the error"
        );
    }

    // -- test_due_jobs_only_enabled -----------------------------------------

    #[test]
    fn test_due_jobs_only_enabled() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();

        // Job 1: enabled, next_run in the past
        let mut j1 = make_job(agent);
        j1.name = "enabled-due".into();
        let id1 = sched.add_job(j1, false).unwrap();

        // Job 2: disabled
        let mut j2 = make_job(agent);
        j2.name = "disabled-job".into();
        let id2 = sched.add_job(j2, false).unwrap();
        sched.set_enabled(id2, false).unwrap();

        // Force job 1's next_run to the past
        if let Some(mut meta) = sched.jobs.get_mut(&id1) {
            meta.job.next_run = Some(Utc::now() - Duration::seconds(10));
        }

        // Force job 2's next_run to the past too (but it's disabled)
        if let Some(mut meta) = sched.jobs.get_mut(&id2) {
            meta.job.next_run = Some(Utc::now() - Duration::seconds(10));
        }

        // Only the enabled job may come back as due
        let due = sched.due_jobs();
        assert_eq!(due.len(), 1);
        assert_eq!(due[0].name, "enabled-due");
    }

    #[test]
    fn test_due_jobs_future_not_included() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();

        let job = make_job(agent);
        sched.add_job(job, false).unwrap();

        // The job was just added with next_run = now + 3600s, so it should
        // not be due yet.
        let due = sched.due_jobs();
        assert!(due.is_empty());
    }

    // -- test_set_enabled ---------------------------------------------------

    #[test]
    fn test_set_enabled() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();

        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();

        // Disable
        sched.set_enabled(id, false).unwrap();
        let meta = sched.get_meta(id).unwrap();
        assert!(!meta.job.enabled);

        // record_failure still updates metadata even while the job is
        // disabled, so this bumps consecutive_errors to a nonzero value...
        sched.record_failure(id, "ignored because disabled");
        // ...which lets us verify below that re-enabling resets the counter
        // and recomputes next_run.
        sched.set_enabled(id, true).unwrap();
        let meta = sched.get_meta(id).unwrap();
        assert!(meta.job.enabled);
        assert_eq!(meta.consecutive_errors, 0);
        assert!(meta.job.next_run.is_some());

        // Non-existent ID should fail
        let fake_id = CronJobId::new();
        assert!(sched.set_enabled(fake_id, true).is_err());
    }

    // -- test_persist_and_load ----------------------------------------------

    #[test]
    fn test_persist_and_load() {
        let tmp = tempfile::tempdir().unwrap();
        let agent = AgentId::new();

        // Create scheduler, add jobs, persist
        {
            let sched = CronScheduler::new(tmp.path(), 100);
            let mut j1 = make_job(agent);
            j1.name = "persist-a".into();
            let mut j2 = make_job(agent);
            j2.name = "persist-b".into();

            sched.add_job(j1, false).unwrap();
            sched.add_job(j2, true).unwrap(); // one_shot

            sched.persist().unwrap();
        }

        // Create a new scheduler and load from disk
        {
            let sched = CronScheduler::new(tmp.path(), 100);
            let count = sched.load().unwrap();
            assert_eq!(count, 2);
            assert_eq!(sched.total_jobs(), 2);

            let jobs = sched.list_jobs(agent);
            assert_eq!(jobs.len(), 2);

            let names: Vec<&str> = jobs.iter().map(|j| j.name.as_str()).collect();
            assert!(names.contains(&"persist-a"));
            assert!(names.contains(&"persist-b"));

            // Verify one_shot flag was preserved
            let b_id = jobs.iter().find(|j| j.name == "persist-b").unwrap().id;
            let meta = sched.get_meta(b_id).unwrap();
            assert!(meta.one_shot);
        }
    }

    #[test]
    fn test_load_no_file_returns_zero() {
        let tmp = tempfile::tempdir().unwrap();
        let sched = CronScheduler::new(tmp.path(), 100);
        assert_eq!(sched.load().unwrap(), 0);
    }

    // -- compute_next_run ---------------------------------------------------

    #[test]
    fn test_compute_next_run_at() {
        let target = Utc::now() + Duration::hours(2);
        let schedule = CronSchedule::At { at: target };
        let next = compute_next_run(&schedule);
        assert_eq!(next, target);
    }

    #[test]
    fn test_compute_next_run_every() {
        let before = Utc::now();
        let schedule = CronSchedule::Every { every_secs: 300 };
        let next = compute_next_run(&schedule);
        let after = Utc::now();

        // Should be roughly now + 300s (bracketed by before/after timestamps)
        assert!(next >= before + Duration::seconds(300));
        assert!(next <= after + Duration::seconds(300));
    }

    #[test]
    fn test_compute_next_run_cron_daily() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "0 9 * * *".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);

        // Should be within the next 24 hours (next 09:00 UTC)
        assert!(next > now);
        assert!(next <= now + Duration::hours(24));
        assert_eq!(next.format("%M").to_string(), "00");
        assert_eq!(next.format("%H").to_string(), "09");
    }

    #[test]
    fn test_compute_next_run_cron_with_dow() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "30 14 * * 1-5".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);

        // Should be within the next 7 days and at 14:30
        assert!(next > now);
        assert!(next <= now + Duration::days(7));
        assert_eq!(next.format("%H:%M").to_string(), "14:30");
    }

    #[test]
    fn test_compute_next_run_cron_invalid_expr() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "not a cron".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);
        // Invalid expression falls back to 1 hour from now
        assert!(next > now + Duration::minutes(59));
        assert!(next <= now + Duration::minutes(61));
    }

    // -- next_run pre-advance regression (#55) ------------------------------

    #[test]
    fn test_compute_next_run_after_skips_current_second() {
        // A "every 4 hours" cron: next_run should be >= 4 hours from now,
        // not in the same minute (the bug from #55).
        let schedule = CronSchedule::Cron {
            expr: "0 */4 * * *".into(),
            tz: None,
        };
        let now = Utc::now();
        let next = compute_next_run_after(&schedule, now);
        // Must be strictly after `now` and at least ~1 hour away
        // (the closest 4-hourly boundary is at least minutes away).
        assert!(next > now, "next_run should be strictly after now");
        let diff = next - now;
        assert!(
            diff.num_minutes() >= 1,
            "Expected next_run at least 1 min away, got {} seconds",
            diff.num_seconds()
        );
    }

    // -- error message truncation in record_failure -------------------------

    #[test]
    fn test_record_failure_truncates_long_error() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();

        let long_error = "x".repeat(1000);
        sched.record_failure(id, &long_error);

        let meta = sched.get_meta(id).unwrap();
        let status = meta.last_status.unwrap();
        // "error: " is 7 chars + 256 chars of truncated message = 263 max
        assert!(
            status.len() <= 263,
            "Status should be truncated, got {} chars",
            status.len()
        );
    }
}
|
||||
19
crates/openfang-kernel/src/error.rs
Normal file
19
crates/openfang-kernel/src/error.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! Kernel-specific error types.
|
||||
|
||||
use openfang_types::error::OpenFangError;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Kernel error type wrapping OpenFangError with kernel-specific context.
#[derive(Error, Debug)]
pub enum KernelError {
    /// A wrapped OpenFangError.
    ///
    /// `#[from]` lets `?` convert any `OpenFangError` into a `KernelError`
    /// automatically; `transparent` forwards its Display output unchanged.
    #[error(transparent)]
    OpenFang(#[from] OpenFangError),

    /// The kernel failed to boot.
    #[error("Boot failed: {0}")]
    BootFailed(String),
}

/// Alias for kernel results.
pub type KernelResult<T> = Result<T, KernelError>;
|
||||
149
crates/openfang-kernel/src/event_bus.rs
Normal file
149
crates/openfang-kernel/src/event_bus.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
//! Event bus — pub/sub with pattern matching and history ring buffer.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::event::{Event, EventTarget};
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::debug;
|
||||
|
||||
/// Maximum events retained in the history ring buffer.
///
/// Once full, the oldest event is evicted on each publish.
const HISTORY_SIZE: usize = 1000;

/// The central event bus for inter-agent and system communication.
pub struct EventBus {
    /// Broadcast channel for all events.
    ///
    /// Receives broadcast, pattern, and system traffic (see `publish`).
    sender: broadcast::Sender<Event>,
    /// Per-agent event channels.
    ///
    /// Created lazily by `subscribe_agent`, removed by `unsubscribe_agent`.
    agent_channels: DashMap<AgentId, broadcast::Sender<Event>>,
    /// Event history ring buffer, bounded at `HISTORY_SIZE`.
    history: Arc<RwLock<VecDeque<Event>>>,
}
|
||||
|
||||
impl EventBus {
|
||||
/// Create a new event bus.
|
||||
pub fn new() -> Self {
|
||||
let (sender, _) = broadcast::channel(1024);
|
||||
Self {
|
||||
sender,
|
||||
agent_channels: DashMap::new(),
|
||||
history: Arc::new(RwLock::new(VecDeque::with_capacity(HISTORY_SIZE))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish an event to the bus.
|
||||
pub async fn publish(&self, event: Event) {
|
||||
debug!(
|
||||
event_id = %event.id,
|
||||
source = %event.source,
|
||||
"Publishing event"
|
||||
);
|
||||
|
||||
// Store in history
|
||||
{
|
||||
let mut history = self.history.write().await;
|
||||
if history.len() >= HISTORY_SIZE {
|
||||
history.pop_front();
|
||||
}
|
||||
history.push_back(event.clone());
|
||||
}
|
||||
|
||||
// Route to target
|
||||
match &event.target {
|
||||
EventTarget::Agent(agent_id) => {
|
||||
if let Some(sender) = self.agent_channels.get(agent_id) {
|
||||
let _ = sender.send(event.clone());
|
||||
}
|
||||
}
|
||||
EventTarget::Broadcast => {
|
||||
let _ = self.sender.send(event.clone());
|
||||
for entry in self.agent_channels.iter() {
|
||||
let _ = entry.value().send(event.clone());
|
||||
}
|
||||
}
|
||||
EventTarget::Pattern(_pattern) => {
|
||||
// Phase 1: broadcast to all for pattern matching
|
||||
let _ = self.sender.send(event.clone());
|
||||
}
|
||||
EventTarget::System => {
|
||||
let _ = self.sender.send(event.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to events for a specific agent.
|
||||
pub fn subscribe_agent(&self, agent_id: AgentId) -> broadcast::Receiver<Event> {
|
||||
let entry = self.agent_channels.entry(agent_id).or_insert_with(|| {
|
||||
let (tx, _) = broadcast::channel(256);
|
||||
tx
|
||||
});
|
||||
entry.subscribe()
|
||||
}
|
||||
|
||||
/// Subscribe to all broadcast/system events.
|
||||
pub fn subscribe_all(&self) -> broadcast::Receiver<Event> {
|
||||
self.sender.subscribe()
|
||||
}
|
||||
|
||||
/// Get recent event history.
|
||||
pub async fn history(&self, limit: usize) -> Vec<Event> {
|
||||
let history = self.history.read().await;
|
||||
history.iter().rev().take(limit).cloned().collect()
|
||||
}
|
||||
|
||||
/// Remove an agent's channel when it's terminated.
|
||||
pub fn unsubscribe_agent(&self, agent_id: AgentId) {
|
||||
self.agent_channels.remove(&agent_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EventBus {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::event::{EventPayload, SystemEvent};

    /// Publishing any event must land it in the history ring buffer.
    #[tokio::test]
    async fn test_publish_and_history() {
        let bus = EventBus::new();
        let agent_id = AgentId::new();
        let event = Event::new(
            agent_id,
            EventTarget::System,
            EventPayload::System(SystemEvent::KernelStarted),
        );
        bus.publish(event).await;
        let history = bus.history(10).await;
        assert_eq!(history.len(), 1);
    }

    /// An agent-targeted event must be delivered to that agent's channel.
    #[tokio::test]
    async fn test_agent_subscribe() {
        let bus = EventBus::new();
        let agent_id = AgentId::new();
        // Subscribe before publishing so the receiver exists when the
        // event is routed.
        let mut rx = bus.subscribe_agent(agent_id);

        let event = Event::new(
            AgentId::new(),
            EventTarget::Agent(agent_id),
            EventPayload::System(SystemEvent::HealthCheck {
                status: "ok".to_string(),
            }),
        );
        bus.publish(event).await;

        // The payload must arrive intact.
        let received = rx.recv().await.unwrap();
        match received.payload {
            EventPayload::System(SystemEvent::HealthCheck { status }) => {
                assert_eq!(status, "ok");
            }
            _ => panic!("Wrong payload"),
        }
    }
}
|
||||
245
crates/openfang-kernel/src/heartbeat.rs
Normal file
245
crates/openfang-kernel/src/heartbeat.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
//! Heartbeat monitor — detects unresponsive agents for 24/7 autonomous operation.
|
||||
//!
|
||||
//! The heartbeat monitor runs as a background tokio task, periodically checking
|
||||
//! each running agent's `last_active` timestamp. If an agent hasn't been active
|
||||
//! for longer than 2x its heartbeat interval, a `HealthCheckFailed` event is
|
||||
//! published to the event bus.
|
||||
|
||||
use crate::registry::AgentRegistry;
|
||||
use chrono::Utc;
|
||||
use openfang_types::agent::{AgentId, AgentState};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Default heartbeat check interval (seconds).
const DEFAULT_CHECK_INTERVAL_SECS: u64 = 30;

/// Multiplier: agent is considered unresponsive if inactive for this many
/// multiples of its heartbeat interval.
const UNRESPONSIVE_MULTIPLIER: u64 = 2;

/// Result of a heartbeat check.
///
/// One instance is produced per running agent by `check_agents`.
#[derive(Debug, Clone)]
pub struct HeartbeatStatus {
    /// Agent ID.
    pub agent_id: AgentId,
    /// Agent name.
    pub name: String,
    /// Seconds since last activity.
    pub inactive_secs: i64,
    /// Whether the agent is considered unresponsive (inactive longer than
    /// its timeout).
    pub unresponsive: bool,
}
|
||||
|
||||
/// Heartbeat monitor configuration.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// How often to run the heartbeat check (seconds).
    pub check_interval_secs: u64,
    /// Default threshold for unresponsiveness (seconds).
    /// Overridden per-agent by AutonomousConfig.heartbeat_interval_secs.
    pub default_timeout_secs: u64,
}
|
||||
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
check_interval_secs: DEFAULT_CHECK_INTERVAL_SECS,
|
||||
default_timeout_secs: DEFAULT_CHECK_INTERVAL_SECS * UNRESPONSIVE_MULTIPLIER,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check all running agents and return their heartbeat status.
|
||||
///
|
||||
/// This is a pure function — it doesn't start a background task.
|
||||
/// The caller (kernel) can run this periodically or in a background task.
|
||||
pub fn check_agents(registry: &AgentRegistry, config: &HeartbeatConfig) -> Vec<HeartbeatStatus> {
|
||||
let now = Utc::now();
|
||||
let mut statuses = Vec::new();
|
||||
|
||||
for entry_ref in registry.list() {
|
||||
// Only check running agents
|
||||
if entry_ref.state != AgentState::Running {
|
||||
continue;
|
||||
}
|
||||
|
||||
let inactive_secs = (now - entry_ref.last_active).num_seconds();
|
||||
|
||||
// Determine timeout: use agent's autonomous config if set, else default
|
||||
let timeout_secs = entry_ref
|
||||
.manifest
|
||||
.autonomous
|
||||
.as_ref()
|
||||
.map(|a| a.heartbeat_interval_secs * UNRESPONSIVE_MULTIPLIER)
|
||||
.unwrap_or(config.default_timeout_secs) as i64;
|
||||
|
||||
let unresponsive = inactive_secs > timeout_secs;
|
||||
|
||||
if unresponsive {
|
||||
warn!(
|
||||
agent = %entry_ref.name,
|
||||
inactive_secs,
|
||||
timeout_secs,
|
||||
"Agent is unresponsive"
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
agent = %entry_ref.name,
|
||||
inactive_secs,
|
||||
"Agent heartbeat OK"
|
||||
);
|
||||
}
|
||||
|
||||
statuses.push(HeartbeatStatus {
|
||||
agent_id: entry_ref.id,
|
||||
name: entry_ref.name.clone(),
|
||||
inactive_secs,
|
||||
unresponsive,
|
||||
});
|
||||
}
|
||||
|
||||
statuses
|
||||
}
|
||||
|
||||
/// Check if an agent is currently within its quiet hours.
|
||||
///
|
||||
/// Quiet hours format: "HH:MM-HH:MM" (24-hour format, UTC).
|
||||
/// Returns true if the current time falls within the quiet period.
|
||||
pub fn is_quiet_hours(quiet_hours: &str) -> bool {
|
||||
let parts: Vec<&str> = quiet_hours.split('-').collect();
|
||||
if parts.len() != 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
let current_minutes = now.format("%H").to_string().parse::<u32>().unwrap_or(0) * 60
|
||||
+ now.format("%M").to_string().parse::<u32>().unwrap_or(0);
|
||||
|
||||
let parse_time = |s: &str| -> Option<u32> {
|
||||
let hm: Vec<&str> = s.trim().split(':').collect();
|
||||
if hm.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let h = hm[0].parse::<u32>().ok()?;
|
||||
let m = hm[1].parse::<u32>().ok()?;
|
||||
if h > 23 || m > 59 {
|
||||
return None;
|
||||
}
|
||||
Some(h * 60 + m)
|
||||
};
|
||||
|
||||
let start = match parse_time(parts[0]) {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
let end = match parse_time(parts[1]) {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if start <= end {
|
||||
// Same-day range: e.g., 22:00-06:00 would be cross-midnight
|
||||
// This is start <= current < end
|
||||
current_minutes >= start && current_minutes < end
|
||||
} else {
|
||||
// Cross-midnight: e.g., 22:00-06:00
|
||||
current_minutes >= start || current_minutes < end
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregate heartbeat summary across all checked agents.
///
/// Built by [`summarize`] from a slice of [`HeartbeatStatus`].
#[derive(Debug, Clone, Default)]
pub struct HeartbeatSummary {
    /// Total agents checked.
    pub total_checked: usize,
    /// Number of responsive agents (total minus unresponsive).
    pub responsive: usize,
    /// Number of unresponsive agents.
    pub unresponsive: usize,
    /// Details of unresponsive agents (clones of their statuses).
    pub unresponsive_agents: Vec<HeartbeatStatus>,
}
|
||||
|
||||
/// Produce a summary from heartbeat statuses.
|
||||
pub fn summarize(statuses: &[HeartbeatStatus]) -> HeartbeatSummary {
|
||||
let unresponsive_agents: Vec<HeartbeatStatus> = statuses
|
||||
.iter()
|
||||
.filter(|s| s.unresponsive)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
HeartbeatSummary {
|
||||
total_checked: statuses.len(),
|
||||
responsive: statuses.len() - unresponsive_agents.len(),
|
||||
unresponsive: unresponsive_agents.len(),
|
||||
unresponsive_agents,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Malformed quiet-hours strings must be rejected (return false), never panic.
    #[test]
    fn test_quiet_hours_parsing() {
        // We can't easily test time-dependent logic, but we can test format parsing
        assert!(!is_quiet_hours("invalid"));
        assert!(!is_quiet_hours(""));
        assert!(!is_quiet_hours("25:00-06:00")); // Invalid hours handled gracefully
    }

    // Well-formed ranges depend on the wall clock, so only absence of panics
    // is verified here (result discarded).
    #[test]
    fn test_quiet_hours_format_valid() {
        // The function returns true/false based on current time
        // We just verify it doesn't panic on valid input
        let _ = is_quiet_hours("22:00-06:00");
        let _ = is_quiet_hours("00:00-23:59");
        let _ = is_quiet_hours("09:00-17:00");
    }

    // Default timeout is 2x the default check interval (30 s -> 60 s).
    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(config.check_interval_secs, 30);
        assert_eq!(config.default_timeout_secs, 60);
    }

    #[test]
    fn test_summarize_empty() {
        let summary = summarize(&[]);
        assert_eq!(summary.total_checked, 0);
        assert_eq!(summary.responsive, 0);
        assert_eq!(summary.unresponsive, 0);
    }

    // Mixed input: counts split correctly and only the unresponsive agent
    // is carried into `unresponsive_agents`.
    #[test]
    fn test_summarize_mixed() {
        let statuses = vec![
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-1".to_string(),
                inactive_secs: 10,
                unresponsive: false,
            },
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-2".to_string(),
                inactive_secs: 120,
                unresponsive: true,
            },
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-3".to_string(),
                inactive_secs: 5,
                unresponsive: false,
            },
        ];

        let summary = summarize(&statuses);
        assert_eq!(summary.total_checked, 3);
        assert_eq!(summary.responsive, 2);
        assert_eq!(summary.unresponsive, 1);
        assert_eq!(summary.unresponsive_agents.len(), 1);
        assert_eq!(summary.unresponsive_agents[0].name, "agent-2");
    }
}
|
||||
5226
crates/openfang-kernel/src/kernel.rs
Normal file
5226
crates/openfang-kernel/src/kernel.rs
Normal file
File diff suppressed because it is too large
Load Diff
29
crates/openfang-kernel/src/lib.rs
Normal file
29
crates/openfang-kernel/src/lib.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! Core kernel for the OpenFang Agent Operating System.
|
||||
//!
|
||||
//! The kernel manages agent lifecycles, memory, permissions, scheduling,
|
||||
//! and inter-agent communication.
|
||||
|
||||
pub mod approval;
|
||||
pub mod auth;
|
||||
pub mod auto_reply;
|
||||
pub mod background;
|
||||
pub mod capabilities;
|
||||
pub mod config;
|
||||
pub mod config_reload;
|
||||
pub mod cron;
|
||||
pub mod error;
|
||||
pub mod event_bus;
|
||||
pub mod heartbeat;
|
||||
pub mod kernel;
|
||||
pub mod metering;
|
||||
pub mod pairing;
|
||||
pub mod registry;
|
||||
pub mod scheduler;
|
||||
pub mod supervisor;
|
||||
pub mod triggers;
|
||||
pub mod whatsapp_gateway;
|
||||
pub mod wizard;
|
||||
pub mod workflow;
|
||||
|
||||
pub use kernel::DeliveryTracker;
|
||||
pub use kernel::OpenFangKernel;
|
||||
753
crates/openfang-kernel/src/metering.rs
Normal file
753
crates/openfang-kernel/src/metering.rs
Normal file
@@ -0,0 +1,753 @@
|
||||
//! Metering engine — tracks LLM cost and enforces spending quotas.
|
||||
|
||||
use openfang_memory::usage::{ModelUsage, UsageRecord, UsageStore, UsageSummary};
|
||||
use openfang_types::agent::{AgentId, ResourceQuota};
|
||||
use openfang_types::error::{OpenFangError, OpenFangResult};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// The metering engine tracks usage cost and enforces quota limits.
///
/// Thin stateless wrapper around the shared [`UsageStore`]; all persistence
/// and time-window queries are delegated to the store.
pub struct MeteringEngine {
    /// Persistent usage store (SQLite-backed), shared with other components.
    store: Arc<UsageStore>,
}
|
||||
|
||||
impl MeteringEngine {
    /// Create a new metering engine with the given usage store.
    pub fn new(store: Arc<UsageStore>) -> Self {
        Self { store }
    }

    /// Record a usage event (persists to SQLite).
    pub fn record(&self, record: &UsageRecord) -> OpenFangResult<()> {
        self.store.record(record)
    }

    /// Check if an agent is within its spending quotas (hourly, daily, monthly).
    /// Returns Ok(()) if under all quotas, or QuotaExceeded error if over any.
    ///
    /// A limit of `0.0` for any window disables enforcement for that window.
    pub fn check_quota(&self, agent_id: AgentId, quota: &ResourceQuota) -> OpenFangResult<()> {
        // Hourly check
        if quota.max_cost_per_hour_usd > 0.0 {
            let hourly_cost = self.store.query_hourly(agent_id)?;
            if hourly_cost >= quota.max_cost_per_hour_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded hourly cost quota: ${:.4} / ${:.4}",
                    agent_id, hourly_cost, quota.max_cost_per_hour_usd
                )));
            }
        }

        // Daily check
        if quota.max_cost_per_day_usd > 0.0 {
            let daily_cost = self.store.query_daily(agent_id)?;
            if daily_cost >= quota.max_cost_per_day_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded daily cost quota: ${:.4} / ${:.4}",
                    agent_id, daily_cost, quota.max_cost_per_day_usd
                )));
            }
        }

        // Monthly check
        if quota.max_cost_per_month_usd > 0.0 {
            let monthly_cost = self.store.query_monthly(agent_id)?;
            if monthly_cost >= quota.max_cost_per_month_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded monthly cost quota: ${:.4} / ${:.4}",
                    agent_id, monthly_cost, quota.max_cost_per_month_usd
                )));
            }
        }

        Ok(())
    }

    /// Check global budget limits (across all agents).
    ///
    /// Mirrors [`check_quota`] but against store-wide totals; a `0.0` limit
    /// disables that window.
    pub fn check_global_budget(
        &self,
        budget: &openfang_types::config::BudgetConfig,
    ) -> OpenFangResult<()> {
        if budget.max_hourly_usd > 0.0 {
            let cost = self.store.query_global_hourly()?;
            if cost >= budget.max_hourly_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global hourly budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_hourly_usd
                )));
            }
        }

        if budget.max_daily_usd > 0.0 {
            // NOTE(review): daily uses query_today_cost (calendar day) while
            // hourly/monthly use query_global_* — confirm this asymmetry is
            // intentional in UsageStore.
            let cost = self.store.query_today_cost()?;
            if cost >= budget.max_daily_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global daily budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_daily_usd
                )));
            }
        }

        if budget.max_monthly_usd > 0.0 {
            let cost = self.store.query_global_monthly()?;
            if cost >= budget.max_monthly_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global monthly budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_monthly_usd
                )));
            }
        }

        Ok(())
    }

    /// Get budget status — current spend vs limits for all time windows.
    ///
    /// Store query failures degrade to $0 spend rather than erroring, so this
    /// is safe to call from status/UI paths. Percentages are 0.0 when the
    /// corresponding limit is unset (0.0), avoiding division by zero.
    pub fn budget_status(&self, budget: &openfang_types::config::BudgetConfig) -> BudgetStatus {
        let hourly = self.store.query_global_hourly().unwrap_or(0.0);
        let daily = self.store.query_today_cost().unwrap_or(0.0);
        let monthly = self.store.query_global_monthly().unwrap_or(0.0);

        BudgetStatus {
            hourly_spend: hourly,
            hourly_limit: budget.max_hourly_usd,
            hourly_pct: if budget.max_hourly_usd > 0.0 {
                hourly / budget.max_hourly_usd
            } else {
                0.0
            },
            daily_spend: daily,
            daily_limit: budget.max_daily_usd,
            daily_pct: if budget.max_daily_usd > 0.0 {
                daily / budget.max_daily_usd
            } else {
                0.0
            },
            monthly_spend: monthly,
            monthly_limit: budget.max_monthly_usd,
            monthly_pct: if budget.max_monthly_usd > 0.0 {
                monthly / budget.max_monthly_usd
            } else {
                0.0
            },
            alert_threshold: budget.alert_threshold,
        }
    }

    /// Get a usage summary, optionally filtered by agent.
    pub fn get_summary(&self, agent_id: Option<AgentId>) -> OpenFangResult<UsageSummary> {
        self.store.query_summary(agent_id)
    }

    /// Get usage grouped by model.
    pub fn get_by_model(&self) -> OpenFangResult<Vec<ModelUsage>> {
        self.store.query_by_model()
    }

    /// Estimate the cost of an LLM call based on model and token counts.
    ///
    /// Matching is substring-based on the lowercased model name; see
    /// `estimate_cost_rates` for the actual matching order.
    ///
    /// Pricing table (approximate, per million tokens):
    ///
    /// | Model Family          | Input $/M | Output $/M |
    /// |-----------------------|-----------|------------|
    /// | claude-haiku          | 0.25      | 1.25       |
    /// | claude-sonnet-4-6     | 3.00      | 15.00      |
    /// | claude-opus-4-6       | 5.00      | 25.00      |
    /// | claude-opus (legacy)  | 15.00     | 75.00      |
    /// | gpt-5.2(-pro)         | 1.75      | 14.00      |
    /// | gpt-5(.1)             | 1.25      | 10.00      |
    /// | gpt-5-mini            | 0.25      | 2.00       |
    /// | gpt-5-nano            | 0.05      | 0.40       |
    /// | gpt-4o                | 2.50      | 10.00      |
    /// | gpt-4o-mini           | 0.15      | 0.60       |
    /// | gpt-4.1               | 2.00      | 8.00       |
    /// | gpt-4.1-mini          | 0.40      | 1.60       |
    /// | gpt-4.1-nano          | 0.10      | 0.40       |
    /// | o3-mini               | 1.10      | 4.40       |
    /// | gemini-3.1            | 2.50      | 15.00      |
    /// | gemini-3              | 0.50      | 3.00       |
    /// | gemini-2.5-flash-lite | 0.04      | 0.15       |
    /// | gemini-2.5-pro        | 1.25      | 10.00      |
    /// | gemini-2.5-flash      | 0.15      | 0.60       |
    /// | gemini-2.0-flash      | 0.10      | 0.40       |
    /// | deepseek-chat/v3      | 0.27      | 1.10       |
    /// | deepseek-reasoner/r1  | 0.55      | 2.19       |
    /// | llama-4-maverick      | 0.50      | 0.77       |
    /// | llama-4-scout         | 0.11      | 0.34       |
    /// | llama/mixtral (groq)  | 0.05      | 0.10       |
    /// | grok-4.1              | 0.20      | 0.50       |
    /// | grok-4                | 3.00      | 15.00      |
    /// | grok-3                | 3.00      | 15.00      |
    /// | qwen                  | 0.20      | 0.60       |
    /// | mistral-large         | 2.00      | 6.00       |
    /// | mistral-small         | 0.10      | 0.30       |
    /// | command-r-plus        | 2.50      | 10.00      |
    /// | Default (unknown)     | 1.00      | 3.00       |
    pub fn estimate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
        let model_lower = model.to_lowercase();
        let (input_per_m, output_per_m) = estimate_cost_rates(&model_lower);

        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_per_m;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_per_m;
        input_cost + output_cost
    }

    /// Estimate cost using the model catalog as the pricing source.
    ///
    /// Falls back to the default rate ($1/$3 per million) if the model is not
    /// found in the catalog.
    pub fn estimate_cost_with_catalog(
        catalog: &openfang_runtime::model_catalog::ModelCatalog,
        model: &str,
        input_tokens: u64,
        output_tokens: u64,
    ) -> f64 {
        let (input_per_m, output_per_m) = catalog.pricing(model).unwrap_or((1.0, 3.0));
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_per_m;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_per_m;
        input_cost + output_cost
    }

    /// Clean up old usage records.
    ///
    /// `days` is the retention window; records older than that are deleted.
    /// Returns the number of deleted rows.
    pub fn cleanup(&self, days: u32) -> OpenFangResult<usize> {
        self.store.cleanup_old(days)
    }
}
|
||||
|
||||
/// Budget status snapshot — current spend vs limits for all time windows.
///
/// Produced by [`MeteringEngine::budget_status`]. Serialized for API/UI
/// consumers. All monetary values are USD; `*_pct` fields are a fraction of
/// the limit (0.0 when the limit is unset).
#[derive(Debug, Clone, serde::Serialize)]
pub struct BudgetStatus {
    /// USD spent in the current hour window.
    pub hourly_spend: f64,
    /// Configured hourly limit (0.0 = unlimited).
    pub hourly_limit: f64,
    /// hourly_spend / hourly_limit (0.0 when limit unset).
    pub hourly_pct: f64,
    /// USD spent today.
    pub daily_spend: f64,
    /// Configured daily limit (0.0 = unlimited).
    pub daily_limit: f64,
    /// daily_spend / daily_limit (0.0 when limit unset).
    pub daily_pct: f64,
    /// USD spent this month.
    pub monthly_spend: f64,
    /// Configured monthly limit (0.0 = unlimited).
    pub monthly_limit: f64,
    /// monthly_spend / monthly_limit (0.0 when limit unset).
    pub monthly_pct: f64,
    /// Fraction of a limit at which alerts should fire (from BudgetConfig).
    pub alert_threshold: f64,
}
|
||||
|
||||
/// Returns `(input_per_million, output_per_million)` USD pricing for a model.
///
/// Matching is substring-based on the (already lowercased) model name, so
/// order matters: more specific patterns must be tested before generic ones
/// (e.g. "gpt-4o-mini" before "gpt-4o", "grok-3-mini" before "grok-3", and
/// provider prefixes like "cerebras" before the generic "llama" arm).
/// Unknown models fall back to a conservative $1.00 / $3.00 per million.
fn estimate_cost_rates(model: &str) -> (f64, f64) {
    // ── Anthropic ──────────────────────────────────────────────
    if model.contains("haiku") {
        return (0.25, 1.25);
    }
    // "claude-opus-4-6" contains "opus-4-6", so one test covers both
    // spellings (the previous `a || b` form was redundant).
    if model.contains("opus-4-6") {
        return (5.0, 25.0);
    }
    if model.contains("opus") {
        return (15.0, 75.0); // legacy Opus pricing
    }
    // All Sonnet generations currently share one price point.
    if model.contains("sonnet") {
        return (3.0, 15.0);
    }

    // ── OpenAI ─────────────────────────────────────────────────
    // gpt-5.2 and gpt-5.2-pro share pricing, so one test suffices.
    if model.contains("gpt-5.2") {
        return (1.75, 14.0);
    }
    if model.contains("gpt-5-nano") {
        return (0.05, 0.40);
    }
    if model.contains("gpt-5-mini") {
        return (0.25, 2.0);
    }
    // Covers gpt-5 and gpt-5.1 (identical pricing).
    if model.contains("gpt-5") {
        return (1.25, 10.0);
    }
    if model.contains("gpt-4o-mini") {
        return (0.15, 0.60);
    }
    if model.contains("gpt-4o") {
        return (2.50, 10.0);
    }
    if model.contains("gpt-4.1-nano") {
        return (0.10, 0.40);
    }
    if model.contains("gpt-4.1-mini") {
        return (0.40, 1.60);
    }
    if model.contains("gpt-4.1") {
        return (2.00, 8.00);
    }
    // o4-mini and o3-mini share pricing.
    if model.contains("o4-mini") || model.contains("o3-mini") {
        return (1.10, 4.40);
    }
    if model.contains("o3") {
        return (2.00, 8.00);
    }
    // Generic gpt-4 fallback
    if model.contains("gpt-4") {
        return (2.50, 10.0);
    }

    // ── Google Gemini ──────────────────────────────────────────
    if model.contains("gemini-3.1") {
        return (2.50, 15.0);
    }
    if model.contains("gemini-3") {
        return (0.50, 3.0);
    }
    if model.contains("gemini-2.5-flash-lite") {
        return (0.04, 0.15);
    }
    if model.contains("gemini-2.5-pro") {
        return (1.25, 10.0);
    }
    if model.contains("gemini-2.5-flash") {
        return (0.15, 0.60);
    }
    if model.contains("gemini-2.0-flash") || model.contains("gemini-flash") {
        return (0.10, 0.40);
    }
    // Generic gemini fallback
    if model.contains("gemini") {
        return (0.15, 0.60);
    }

    // ── DeepSeek ───────────────────────────────────────────────
    if model.contains("deepseek-reasoner") || model.contains("deepseek-r1") {
        return (0.55, 2.19);
    }
    if model.contains("deepseek") {
        return (0.27, 1.10);
    }

    // ── Cerebras (ultra-fast, cheap) ── must come before llama ─
    if model.contains("cerebras") {
        return (0.06, 0.06);
    }

    // ── SambaNova ── must come before llama ────────────────────
    if model.contains("sambanova") {
        return (0.06, 0.06);
    }

    // ── Replicate ── must come before llama ────────────────────
    if model.contains("replicate") {
        return (0.40, 0.40);
    }

    // ── Open-source (Groq, Together, etc.) ─────────────────────
    if model.contains("llama-4-maverick") {
        return (0.50, 0.77);
    }
    if model.contains("llama-4-scout") {
        return (0.11, 0.34);
    }
    if model.contains("llama") || model.contains("mixtral") {
        return (0.05, 0.10);
    }

    // ── Qwen (Alibaba) ─────────────────────────────────────────
    if model.contains("qwen-max") {
        return (4.00, 12.00);
    }
    if model.contains("qwen-vl") {
        return (1.50, 4.50);
    }
    if model.contains("qwen-plus") {
        return (0.80, 2.00);
    }
    if model.contains("qwen-turbo") {
        return (0.30, 0.60);
    }
    if model.contains("qwen") {
        return (0.20, 0.60);
    }

    // ── MiniMax ────────────────────────────────────────────────
    if model.contains("minimax") {
        return (1.00, 3.00);
    }

    // ── Zhipu / GLM ────────────────────────────────────────────
    if model.contains("glm-4-flash") {
        return (0.10, 0.10);
    }
    if model.contains("glm") {
        return (1.50, 5.00);
    }
    if model.contains("codegeex") {
        return (0.10, 0.10);
    }

    // ── Moonshot / Kimi ────────────────────────────────────────
    if model.contains("moonshot") || model.contains("kimi") {
        return (0.80, 0.80);
    }

    // ── Baidu ERNIE ────────────────────────────────────────────
    if model.contains("ernie") {
        return (2.00, 6.00);
    }

    // ── AWS Bedrock ────────────────────────────────────────────
    if model.contains("nova-pro") {
        return (0.80, 3.20);
    }
    if model.contains("nova-lite") {
        return (0.06, 0.24);
    }

    // ── Mistral ────────────────────────────────────────────────
    if model.contains("mistral-large") {
        return (2.00, 6.00);
    }
    // mistral-small and any other Mistral model share the small-tier rate
    // (the previous `mistral-small || mistral` test was redundant).
    if model.contains("mistral") {
        return (0.10, 0.30);
    }

    // ── Cohere ─────────────────────────────────────────────────
    if model.contains("command-r-plus") {
        return (2.50, 10.0);
    }
    if model.contains("command-r") {
        return (0.15, 0.60);
    }

    // ── Perplexity ─────────────────────────────────────────────
    if model.contains("sonar-pro") {
        return (3.0, 15.0);
    }
    if model.contains("sonar") {
        return (1.0, 5.0);
    }

    // ── xAI / Grok ── order: 4.1, 4, minis, 3, generic ─────────
    if model.contains("grok-4.1") {
        return (0.20, 0.50);
    }
    if model.contains("grok-4") {
        return (3.0, 15.0);
    }
    if model.contains("grok-3-mini") || model.contains("grok-2-mini") || model.contains("grok-mini")
    {
        return (0.30, 0.50);
    }
    if model.contains("grok-3") {
        return (3.0, 15.0);
    }
    if model.contains("grok") {
        return (2.0, 10.0);
    }

    // ── AI21 / Jamba ───────────────────────────────────────────
    if model.contains("jamba") {
        return (2.0, 8.0);
    }

    // ── Default (conservative) ─────────────────────────────────
    (1.0, 3.0)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use openfang_memory::MemorySubstrate;
|
||||
|
||||
fn setup() -> MeteringEngine {
|
||||
let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
|
||||
let store = Arc::new(UsageStore::new(substrate.usage_conn()));
|
||||
MeteringEngine::new(store)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_and_check_quota_under() {
|
||||
let engine = setup();
|
||||
let agent_id = AgentId::new();
|
||||
let quota = ResourceQuota {
|
||||
max_cost_per_hour_usd: 1.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
engine
|
||||
.record(&UsageRecord {
|
||||
agent_id,
|
||||
model: "claude-haiku".to_string(),
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
cost_usd: 0.001,
|
||||
tool_calls: 0,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert!(engine.check_quota(agent_id, "a).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_quota_exceeded() {
|
||||
let engine = setup();
|
||||
let agent_id = AgentId::new();
|
||||
let quota = ResourceQuota {
|
||||
max_cost_per_hour_usd: 0.01,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
engine
|
||||
.record(&UsageRecord {
|
||||
agent_id,
|
||||
model: "claude-sonnet".to_string(),
|
||||
input_tokens: 10000,
|
||||
output_tokens: 5000,
|
||||
cost_usd: 0.05,
|
||||
tool_calls: 0,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = engine.check_quota(agent_id, "a);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("exceeded hourly cost quota"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_quota_zero_limit_skipped() {
|
||||
let engine = setup();
|
||||
let agent_id = AgentId::new();
|
||||
let quota = ResourceQuota {
|
||||
max_cost_per_hour_usd: 0.0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even with high usage, a zero limit means no enforcement
|
||||
engine
|
||||
.record(&UsageRecord {
|
||||
agent_id,
|
||||
model: "claude-opus".to_string(),
|
||||
input_tokens: 100000,
|
||||
output_tokens: 50000,
|
||||
cost_usd: 100.0,
|
||||
tool_calls: 0,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert!(engine.check_quota(agent_id, "a).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_haiku() {
|
||||
let cost = MeteringEngine::estimate_cost("claude-haiku-4-5-20251001", 1_000_000, 1_000_000);
|
||||
assert!((cost - 1.50).abs() < 0.01); // $0.25 + $1.25
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_sonnet() {
|
||||
let cost = MeteringEngine::estimate_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000);
|
||||
assert!((cost - 18.0).abs() < 0.01); // $3.00 + $15.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_opus() {
|
||||
let cost = MeteringEngine::estimate_cost("claude-opus-4-20250514", 1_000_000, 1_000_000);
|
||||
assert!((cost - 90.0).abs() < 0.01); // $15.00 + $75.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gpt4o() {
|
||||
let cost = MeteringEngine::estimate_cost("gpt-4o-2024-11-20", 1_000_000, 1_000_000);
|
||||
assert!((cost - 12.50).abs() < 0.01); // $2.50 + $10.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gpt4o_mini() {
|
||||
let cost = MeteringEngine::estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.75).abs() < 0.01); // $0.15 + $0.60
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gpt41() {
|
||||
let cost = MeteringEngine::estimate_cost("gpt-4.1", 1_000_000, 1_000_000);
|
||||
assert!((cost - 10.0).abs() < 0.01); // $2.00 + $8.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gpt41_mini() {
|
||||
let cost = MeteringEngine::estimate_cost("gpt-4.1-mini", 1_000_000, 1_000_000);
|
||||
assert!((cost - 2.0).abs() < 0.01); // $0.40 + $1.60
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gpt41_nano() {
|
||||
let cost = MeteringEngine::estimate_cost("gpt-4.1-nano", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.50).abs() < 0.01); // $0.10 + $0.40
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_o3_mini() {
|
||||
let cost = MeteringEngine::estimate_cost("o3-mini", 1_000_000, 1_000_000);
|
||||
assert!((cost - 5.50).abs() < 0.01); // $1.10 + $4.40
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gemini_20_flash() {
|
||||
let cost = MeteringEngine::estimate_cost("gemini-2.0-flash", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.50).abs() < 0.01); // $0.10 + $0.40
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gemini_25_pro() {
|
||||
let cost = MeteringEngine::estimate_cost("gemini-2.5-pro", 1_000_000, 1_000_000);
|
||||
assert!((cost - 11.25).abs() < 0.01); // $1.25 + $10.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_gemini_25_flash() {
|
||||
let cost = MeteringEngine::estimate_cost("gemini-2.5-flash", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.75).abs() < 0.01); // $0.15 + $0.60
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_deepseek_chat() {
|
||||
let cost = MeteringEngine::estimate_cost("deepseek-chat", 1_000_000, 1_000_000);
|
||||
assert!((cost - 1.37).abs() < 0.01); // $0.27 + $1.10
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_deepseek_reasoner() {
|
||||
let cost = MeteringEngine::estimate_cost("deepseek-reasoner", 1_000_000, 1_000_000);
|
||||
assert!((cost - 2.74).abs() < 0.01); // $0.55 + $2.19
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_llama() {
|
||||
let cost = MeteringEngine::estimate_cost("llama-3.3-70b-versatile", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.15).abs() < 0.01); // $0.05 + $0.10
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_mixtral() {
|
||||
let cost = MeteringEngine::estimate_cost("mixtral-8x7b", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.15).abs() < 0.01); // $0.05 + $0.10
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_qwen() {
|
||||
let cost = MeteringEngine::estimate_cost("qwen-2.5-72b", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.80).abs() < 0.01); // $0.20 + $0.60
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_mistral_large() {
|
||||
let cost = MeteringEngine::estimate_cost("mistral-large-latest", 1_000_000, 1_000_000);
|
||||
assert!((cost - 8.0).abs() < 0.01); // $2.00 + $6.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_mistral_small() {
|
||||
let cost = MeteringEngine::estimate_cost("mistral-small-latest", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.40).abs() < 0.01); // $0.10 + $0.30
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_command_r_plus() {
|
||||
let cost = MeteringEngine::estimate_cost("command-r-plus", 1_000_000, 1_000_000);
|
||||
assert!((cost - 12.50).abs() < 0.01); // $2.50 + $10.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_unknown() {
|
||||
let cost = MeteringEngine::estimate_cost("my-custom-model", 1_000_000, 1_000_000);
|
||||
assert!((cost - 4.0).abs() < 0.01); // $1.00 + $3.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_grok() {
|
||||
let cost = MeteringEngine::estimate_cost("grok-2", 1_000_000, 1_000_000);
|
||||
assert!((cost - 12.0).abs() < 0.01); // $2.00 + $10.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_grok_mini() {
|
||||
let cost = MeteringEngine::estimate_cost("grok-2-mini", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.80).abs() < 0.01); // $0.30 + $0.50
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_sonar_pro() {
|
||||
let cost = MeteringEngine::estimate_cost("sonar-pro", 1_000_000, 1_000_000);
|
||||
assert!((cost - 18.0).abs() < 0.01); // $3.00 + $15.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_jamba() {
|
||||
let cost = MeteringEngine::estimate_cost("jamba-1.5-large", 1_000_000, 1_000_000);
|
||||
assert!((cost - 10.0).abs() < 0.01); // $2.00 + $8.00
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_cerebras() {
|
||||
let cost = MeteringEngine::estimate_cost("cerebras/llama3.3-70b", 1_000_000, 1_000_000);
|
||||
assert!((cost - 0.12).abs() < 0.01); // $0.06 + $0.06
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_with_catalog() {
|
||||
let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
|
||||
// Sonnet: $3/M input, $15/M output
|
||||
let cost = MeteringEngine::estimate_cost_with_catalog(
|
||||
&catalog,
|
||||
"claude-sonnet-4-20250514",
|
||||
1_000_000,
|
||||
1_000_000,
|
||||
);
|
||||
assert!((cost - 18.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_cost_with_catalog_alias() {
|
||||
let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
|
||||
// "sonnet" alias should resolve to same pricing
|
||||
let cost =
|
||||
MeteringEngine::estimate_cost_with_catalog(&catalog, "sonnet", 1_000_000, 1_000_000);
|
||||
assert!((cost - 18.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
fn test_estimate_cost_with_catalog_unknown_uses_default() {
    // A model the catalog has never heard of falls back to $1/M + $3/M.
    let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
    let estimated = MeteringEngine::estimate_cost_with_catalog(
        &catalog,
        "totally-unknown-model",
        1_000_000,
        1_000_000,
    );
    assert!((estimated - 4.0).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_get_summary() {
    let engine = setup();
    let agent_id = AgentId::new();

    // Record a single call, then check the per-agent rollup reflects it.
    let record = UsageRecord {
        agent_id,
        model: "haiku".to_string(),
        input_tokens: 500,
        output_tokens: 200,
        cost_usd: 0.005,
        tool_calls: 3,
    };
    engine.record(&record).unwrap();

    let summary = engine.get_summary(Some(agent_id)).unwrap();
    assert_eq!(summary.call_count, 1);
    assert_eq!(summary.total_input_tokens, 500);
}
|
||||
}
|
||||
510
crates/openfang-kernel/src/pairing.rs
Normal file
510
crates/openfang-kernel/src/pairing.rs
Normal file
@@ -0,0 +1,510 @@
|
||||
//! Device pairing — QR-code flow for mobile/desktop clients.
|
||||
//!
|
||||
//! Supports pairing via short-lived tokens, device management, and
|
||||
//! push notifications via ntfy.sh or gotify.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::config::PairingConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Maximum concurrent pairing requests (prevent token flooding).
///
/// NOTE(review): this cap is global across all callers, not per client —
/// confirm that is the intended granularity.
const MAX_PENDING_REQUESTS: usize = 5;
|
||||
|
||||
/// A paired device record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PairedDevice {
    /// Stable identifier supplied by the device when it pairs.
    pub device_id: String,
    /// Human-friendly name shown in device listings.
    pub display_name: String,
    /// Platform string (e.g. "android", "ios"); free-form, not validated here.
    pub platform: String,
    /// When the device completed pairing.
    pub paired_at: chrono::DateTime<chrono::Utc>,
    /// Last time the device was observed.
    pub last_seen: chrono::DateTime<chrono::Utc>,
    /// Optional push-notification token. Skipped on serialization so it is
    /// never sent back out to clients; it IS still accepted on deserialization.
    #[serde(skip_serializing)]
    pub push_token: Option<String>,
}
|
||||
|
||||
/// Pairing request (short-lived, for QR code flow).
#[derive(Debug, Clone)]
pub struct PairingRequest {
    /// Random token (32 bytes encoded as 64 hex chars) embedded in the QR code.
    pub token: String,
    /// When the request was issued.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Hard expiry; `complete_pairing` rejects the token after this instant.
    pub expires_at: chrono::DateTime<chrono::Utc>,
}
|
||||
|
||||
/// Persistence callback — kernel injects this so PairingManager can save without
/// taking a direct dependency on openfang-memory.
///
/// Invoked synchronously after a device is paired (`PersistOp::Save`) or
/// removed (`PersistOp::Remove`).
pub type PersistFn = Box<dyn Fn(&PairedDevice, PersistOp) + Send + Sync>;
|
||||
|
||||
/// Persistence operation kind.
#[derive(Debug, Clone, Copy)]
pub enum PersistOp {
    /// Persist (save/update) the device record.
    Save,
    /// Remove the persisted device record.
    Remove,
}
|
||||
|
||||
/// Device pairing manager.
pub struct PairingManager {
    /// Pairing configuration (enabled flag, limits, push provider settings).
    config: PairingConfig,
    /// Outstanding (not yet completed) pairing requests, keyed by token.
    pending: DashMap<String, PairingRequest>,
    /// Currently paired devices, keyed by `device_id`.
    devices: DashMap<String, PairedDevice>,
    /// Optional persistence hook; see [`PersistFn`].
    persist: Option<PersistFn>,
}
|
||||
|
||||
impl PairingManager {
|
||||
pub fn new(config: PairingConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
pending: DashMap::new(),
|
||||
devices: DashMap::new(),
|
||||
persist: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach a persistence callback (called after pair/unpair operations).
|
||||
pub fn set_persist(&mut self, f: PersistFn) {
|
||||
self.persist = Some(f);
|
||||
}
|
||||
|
||||
/// Bulk-load devices from persistence (call once at boot).
|
||||
pub fn load_devices(&self, devices: Vec<PairedDevice>) {
|
||||
for d in devices {
|
||||
self.devices.insert(d.device_id.clone(), d);
|
||||
}
|
||||
debug!(
|
||||
count = self.devices.len(),
|
||||
"Loaded paired devices from database"
|
||||
);
|
||||
}
|
||||
|
||||
/// Generate a new pairing request. Returns token for QR encoding.
|
||||
pub fn create_pairing_request(&self) -> Result<PairingRequest, String> {
|
||||
if !self.config.enabled {
|
||||
return Err("Device pairing is disabled".into());
|
||||
}
|
||||
|
||||
// Enforce max pending limit
|
||||
if self.pending.len() >= MAX_PENDING_REQUESTS {
|
||||
// Clean expired first
|
||||
self.clean_expired();
|
||||
if self.pending.len() >= MAX_PENDING_REQUESTS {
|
||||
return Err("Too many pending pairing requests. Try again later.".into());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate secure random token (32 bytes = 64 hex chars)
|
||||
let mut token_bytes = [0u8; 32];
|
||||
use rand::RngCore;
|
||||
rand::thread_rng().fill_bytes(&mut token_bytes);
|
||||
let token = hex::encode(token_bytes);
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let expires_at = now + chrono::Duration::seconds(self.config.token_expiry_secs as i64);
|
||||
|
||||
let request = PairingRequest {
|
||||
token: token.clone(),
|
||||
created_at: now,
|
||||
expires_at,
|
||||
};
|
||||
|
||||
self.pending.insert(token, request.clone());
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
/// Complete pairing — device submits token + device info.
|
||||
pub fn complete_pairing(
|
||||
&self,
|
||||
token: &str,
|
||||
device_info: PairedDevice,
|
||||
) -> Result<PairedDevice, String> {
|
||||
// SECURITY: Constant-time token comparison
|
||||
let found = self.pending.iter().find(|entry| {
|
||||
use subtle::ConstantTimeEq;
|
||||
let stored = entry.value().token.as_bytes();
|
||||
let provided = token.as_bytes();
|
||||
if stored.len() != provided.len() {
|
||||
return false;
|
||||
}
|
||||
stored.ct_eq(provided).into()
|
||||
});
|
||||
|
||||
let entry = found.ok_or("Invalid or expired pairing token")?;
|
||||
let request = entry.value().clone();
|
||||
let key = entry.key().clone();
|
||||
drop(entry);
|
||||
|
||||
// Check expiry
|
||||
if chrono::Utc::now() > request.expires_at {
|
||||
self.pending.remove(&key);
|
||||
return Err("Pairing token has expired".into());
|
||||
}
|
||||
|
||||
// Check max devices
|
||||
if self.devices.len() >= self.config.max_devices {
|
||||
return Err(format!(
|
||||
"Maximum paired devices ({}) reached. Remove a device first.",
|
||||
self.config.max_devices
|
||||
));
|
||||
}
|
||||
|
||||
// Remove the used token
|
||||
self.pending.remove(&key);
|
||||
|
||||
// Store the device
|
||||
let device_id = device_info.device_id.clone();
|
||||
self.devices.insert(device_id.clone(), device_info.clone());
|
||||
|
||||
// Persist to database
|
||||
if let Some(ref persist) = self.persist {
|
||||
persist(&device_info, PersistOp::Save);
|
||||
}
|
||||
|
||||
debug!(device_id = %device_id, "Device paired successfully");
|
||||
|
||||
Ok(device_info)
|
||||
}
|
||||
|
||||
/// List paired devices.
|
||||
pub fn list_devices(&self) -> Vec<PairedDevice> {
|
||||
self.devices.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// Remove a paired device.
|
||||
pub fn remove_device(&self, device_id: &str) -> Result<(), String> {
|
||||
let removed = self
|
||||
.devices
|
||||
.remove(device_id)
|
||||
.ok_or_else(|| format!("Device '{device_id}' not found"))?;
|
||||
|
||||
// Persist removal to database
|
||||
if let Some(ref persist) = self.persist {
|
||||
persist(&removed.1, PersistOp::Remove);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send push notification to all paired devices.
|
||||
pub async fn notify_devices(
|
||||
&self,
|
||||
title: &str,
|
||||
body: &str,
|
||||
) -> Vec<(String, Result<(), String>)> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
match self.config.push_provider.as_str() {
|
||||
"ntfy" => {
|
||||
let url = self.config.ntfy_url.as_deref().unwrap_or("https://ntfy.sh");
|
||||
let topic = match &self.config.ntfy_topic {
|
||||
Some(t) => t.clone(),
|
||||
None => {
|
||||
results.push(("ntfy".to_string(), Err("ntfy_topic not configured".into())));
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
let full_url = format!("{}/{}", url.trim_end_matches('/'), topic);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
match client
|
||||
.post(&full_url)
|
||||
.header("Title", title)
|
||||
.body(body.to_string())
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
for device in self.devices.iter() {
|
||||
results.push((device.device_id.clone(), Ok(())));
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
results.push((
|
||||
"ntfy".to_string(),
|
||||
Err(format!("ntfy returned HTTP {status}")),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
results
|
||||
.push(("ntfy".to_string(), Err(format!("ntfy request failed: {e}"))));
|
||||
}
|
||||
}
|
||||
}
|
||||
"gotify" => {
|
||||
// Gotify requires an app token
|
||||
let app_token = match std::env::var("GOTIFY_APP_TOKEN") {
|
||||
Ok(t) => t,
|
||||
Err(_) => {
|
||||
results
|
||||
.push(("gotify".to_string(), Err("GOTIFY_APP_TOKEN not set".into())));
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
let server_url = match std::env::var("GOTIFY_SERVER_URL") {
|
||||
Ok(u) => u,
|
||||
Err(_) => {
|
||||
results.push((
|
||||
"gotify".to_string(),
|
||||
Err("GOTIFY_SERVER_URL not set".into()),
|
||||
));
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
let url = format!("{}/message", server_url.trim_end_matches('/'));
|
||||
let body_json = serde_json::json!({
|
||||
"title": title,
|
||||
"message": body,
|
||||
"priority": 5,
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
match client
|
||||
.post(&url)
|
||||
.header("X-Gotify-Key", &app_token)
|
||||
.json(&body_json)
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
for device in self.devices.iter() {
|
||||
results.push((device.device_id.clone(), Ok(())));
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
results.push((
|
||||
"gotify".to_string(),
|
||||
Err(format!("gotify returned HTTP {status}")),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
results.push((
|
||||
"gotify".to_string(),
|
||||
Err(format!("gotify request failed: {e}")),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
"none" | "" => {
|
||||
// No push provider configured — silent
|
||||
}
|
||||
other => {
|
||||
warn!(provider = other, "Unknown push notification provider");
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Clean expired pairing requests.
|
||||
pub fn clean_expired(&self) {
|
||||
let now = chrono::Utc::now();
|
||||
self.pending.retain(|_, req| req.expires_at > now);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Config with pairing DISABLED (the default).
    fn default_config() -> PairingConfig {
        PairingConfig::default()
    }

    // Config with pairing enabled, everything else default.
    fn enabled_config() -> PairingConfig {
        PairingConfig {
            enabled: true,
            ..Default::default()
        }
    }

    // A fresh manager starts with no devices and no pending requests.
    #[test]
    fn test_manager_creation() {
        let mgr = PairingManager::new(default_config());
        assert!(mgr.devices.is_empty());
        assert!(mgr.pending.is_empty());
    }

    // Requests are rejected while the pairing feature is disabled.
    #[test]
    fn test_create_request_disabled() {
        let mgr = PairingManager::new(default_config());
        let result = mgr.create_pairing_request();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("disabled"));
    }

    #[test]
    fn test_create_request_success() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        assert_eq!(req.token.len(), 64); // 32 bytes = 64 hex chars
        assert!(req.expires_at > req.created_at);
    }

    // The MAX_PENDING_REQUESTS cap blocks further requests once reached.
    #[test]
    fn test_max_pending_requests() {
        let mgr = PairingManager::new(enabled_config());
        for _ in 0..MAX_PENDING_REQUESTS {
            mgr.create_pairing_request().unwrap();
        }
        let result = mgr.create_pairing_request();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Too many"));
    }

    // Submitting a token that was never issued must fail.
    #[test]
    fn test_complete_pairing_invalid_token() {
        let mgr = PairingManager::new(enabled_config());
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        let result = mgr.complete_pairing("invalid-token", device);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid"));
    }

    // Happy path: request -> complete -> device stored, token consumed.
    #[test]
    fn test_complete_pairing_success() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();

        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };

        let result = mgr.complete_pairing(&req.token, device);
        assert!(result.is_ok());
        assert_eq!(mgr.devices.len(), 1);
        assert!(mgr.pending.is_empty()); // Token consumed
    }

    // With max_devices = 1, pairing a second distinct device must fail.
    #[test]
    fn test_max_devices_enforced() {
        let config = PairingConfig {
            enabled: true,
            max_devices: 1,
            ..Default::default()
        };
        let mgr = PairingManager::new(config);

        // Pair first device
        let req1 = mgr.create_pairing_request().unwrap();
        let d1 = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "Phone 1".to_string(),
            platform: "ios".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req1.token, d1).unwrap();

        // Try second device
        let req2 = mgr.create_pairing_request().unwrap();
        let d2 = PairedDevice {
            device_id: "dev-2".to_string(),
            display_name: "Phone 2".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        let result = mgr.complete_pairing(&req2.token, d2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Maximum"));
    }

    #[test]
    fn test_list_devices() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req.token, device).unwrap();

        let devices = mgr.list_devices();
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].display_name, "My Phone");
    }

    #[test]
    fn test_remove_device() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req.token, device).unwrap();

        assert!(mgr.remove_device("dev-1").is_ok());
        assert!(mgr.devices.is_empty());
    }

    #[test]
    fn test_remove_nonexistent_device() {
        let mgr = PairingManager::new(enabled_config());
        assert!(mgr.remove_device("nonexistent").is_err());
    }

    // Tokens with a zero-second lifetime are purged by clean_expired.
    #[test]
    fn test_clean_expired() {
        let config = PairingConfig {
            enabled: true,
            token_expiry_secs: 0, // Expire immediately
            ..Default::default()
        };
        let mgr = PairingManager::new(config);
        mgr.create_pairing_request().unwrap();
        assert_eq!(mgr.pending.len(), 1);

        // Wait a tiny bit for expiry
        std::thread::sleep(std::time::Duration::from_millis(10));
        mgr.clean_expired();
        assert!(mgr.pending.is_empty());
    }

    #[test]
    fn test_token_length() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        // 32 random bytes = 64 hex chars
        assert_eq!(req.token.len(), 64);
    }

    // Pins the PairingConfig defaults this module's behavior relies on.
    #[test]
    fn test_config_defaults() {
        let config = PairingConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_devices, 10);
        assert_eq!(config.token_expiry_secs, 300);
        assert_eq!(config.push_provider, "none");
        assert!(config.ntfy_url.is_none());
        assert!(config.ntfy_topic.is_none());
    }
}
|
||||
464
crates/openfang-kernel/src/registry.rs
Normal file
464
crates/openfang-kernel/src/registry.rs
Normal file
@@ -0,0 +1,464 @@
|
||||
//! Agent registry — tracks all agents, their state, and indexes.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::{AgentEntry, AgentId, AgentMode, AgentState};
|
||||
use openfang_types::error::{OpenFangError, OpenFangResult};
|
||||
|
||||
/// Registry of all agents in the kernel.
///
/// The three maps are secondary indexes over the same data and must be kept
/// in sync by the methods in the `impl` block; they are never mutated directly.
pub struct AgentRegistry {
    /// Primary index: agent ID → entry.
    agents: DashMap<AgentId, AgentEntry>,
    /// Name index: human-readable name → agent ID. Names are unique.
    name_index: DashMap<String, AgentId>,
    /// Tag index: tag → list of agent IDs.
    tag_index: DashMap<String, Vec<AgentId>>,
}
|
||||
|
||||
impl AgentRegistry {
|
||||
/// Create a new empty registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
agents: DashMap::new(),
|
||||
name_index: DashMap::new(),
|
||||
tag_index: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new agent.
|
||||
pub fn register(&self, entry: AgentEntry) -> OpenFangResult<()> {
|
||||
if self.name_index.contains_key(&entry.name) {
|
||||
return Err(OpenFangError::AgentAlreadyExists(entry.name.clone()));
|
||||
}
|
||||
let id = entry.id;
|
||||
self.name_index.insert(entry.name.clone(), id);
|
||||
for tag in &entry.tags {
|
||||
self.tag_index.entry(tag.clone()).or_default().push(id);
|
||||
}
|
||||
self.agents.insert(id, entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get an agent entry by ID.
|
||||
pub fn get(&self, id: AgentId) -> Option<AgentEntry> {
|
||||
self.agents.get(&id).map(|e| e.value().clone())
|
||||
}
|
||||
|
||||
/// Find an agent by name.
|
||||
pub fn find_by_name(&self, name: &str) -> Option<AgentEntry> {
|
||||
self.name_index
|
||||
.get(name)
|
||||
.and_then(|id| self.agents.get(id.value()).map(|e| e.value().clone()))
|
||||
}
|
||||
|
||||
/// Update agent state.
|
||||
pub fn set_state(&self, id: AgentId, state: AgentState) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.state = state;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update agent operational mode.
|
||||
pub fn set_mode(&self, id: AgentId, mode: AgentMode) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.mode = mode;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove an agent from the registry.
|
||||
pub fn remove(&self, id: AgentId) -> OpenFangResult<AgentEntry> {
|
||||
let (_, entry) = self
|
||||
.agents
|
||||
.remove(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
self.name_index.remove(&entry.name);
|
||||
for tag in &entry.tags {
|
||||
if let Some(mut ids) = self.tag_index.get_mut(tag) {
|
||||
ids.retain(|&agent_id| agent_id != id);
|
||||
}
|
||||
}
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// List all agents.
|
||||
pub fn list(&self) -> Vec<AgentEntry> {
|
||||
self.agents.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// Add a child agent ID to a parent's children list.
|
||||
pub fn add_child(&self, parent_id: AgentId, child_id: AgentId) {
|
||||
if let Some(mut entry) = self.agents.get_mut(&parent_id) {
|
||||
entry.children.push(child_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Count of registered agents.
|
||||
pub fn count(&self) -> usize {
|
||||
self.agents.len()
|
||||
}
|
||||
|
||||
/// Update an agent's session ID (for session reset).
|
||||
pub fn update_session_id(
|
||||
&self,
|
||||
id: AgentId,
|
||||
new_session_id: openfang_types::agent::SessionId,
|
||||
) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.session_id = new_session_id;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's workspace path.
|
||||
pub fn update_workspace(
|
||||
&self,
|
||||
id: AgentId,
|
||||
workspace: Option<std::path::PathBuf>,
|
||||
) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.workspace = workspace;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's visual identity (emoji, avatar, color).
|
||||
pub fn update_identity(
|
||||
&self,
|
||||
id: AgentId,
|
||||
identity: openfang_types::agent::AgentIdentity,
|
||||
) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.identity = identity;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's model configuration.
|
||||
pub fn update_model(&self, id: AgentId, new_model: String) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.model.model = new_model;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's model AND provider together.
|
||||
pub fn update_model_and_provider(
|
||||
&self,
|
||||
id: AgentId,
|
||||
new_model: String,
|
||||
new_provider: String,
|
||||
) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.model.model = new_model;
|
||||
entry.manifest.model.provider = new_provider;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's skill allowlist.
|
||||
pub fn update_skills(&self, id: AgentId, skills: Vec<String>) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.skills = skills;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's MCP server allowlist.
|
||||
pub fn update_mcp_servers(&self, id: AgentId, servers: Vec<String>) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.mcp_servers = servers;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's system prompt (hot-swap, takes effect on next message).
|
||||
pub fn update_system_prompt(&self, id: AgentId, new_prompt: String) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.model.system_prompt = new_prompt;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's name (also updates the name index).
|
||||
pub fn update_name(&self, id: AgentId, new_name: String) -> OpenFangResult<()> {
|
||||
if self.name_index.contains_key(&new_name) {
|
||||
return Err(OpenFangError::AgentAlreadyExists(new_name));
|
||||
}
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
let old_name = entry.name.clone();
|
||||
entry.name = new_name.clone();
|
||||
entry.manifest.name = new_name.clone();
|
||||
entry.last_active = chrono::Utc::now();
|
||||
// Update name index
|
||||
drop(entry);
|
||||
self.name_index.remove(&old_name);
|
||||
self.name_index.insert(new_name, id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an agent's description.
|
||||
pub fn update_description(&self, id: AgentId, new_desc: String) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.manifest.description = new_desc;
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark an agent's onboarding as complete.
|
||||
pub fn mark_onboarding_complete(&self, id: AgentId) -> OpenFangResult<()> {
|
||||
let mut entry = self
|
||||
.agents
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
|
||||
entry.onboarding_completed = true;
|
||||
entry.onboarding_completed_at = Some(chrono::Utc::now());
|
||||
entry.last_active = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use openfang_types::agent::*;
    use std::collections::HashMap;

    // Build a minimal, fully-defaulted AgentEntry with the given name.
    fn test_entry(name: &str) -> AgentEntry {
        AgentEntry {
            id: AgentId::new(),
            name: name.to_string(),
            manifest: AgentManifest {
                name: name.to_string(),
                version: "0.1.0".to_string(),
                description: "test".to_string(),
                author: "test".to_string(),
                module: "test".to_string(),
                schedule: ScheduleMode::default(),
                model: ModelConfig::default(),
                fallback_models: vec![],
                resources: ResourceQuota::default(),
                priority: Priority::default(),
                capabilities: ManifestCapabilities::default(),
                profile: None,
                tools: HashMap::new(),
                skills: vec![],
                mcp_servers: vec![],
                metadata: HashMap::new(),
                tags: vec![],
                routing: None,
                autonomous: None,
                pinned_model: None,
                workspace: None,
                generate_identity_files: true,
                exec_policy: None,
            },
            state: AgentState::Created,
            mode: AgentMode::default(),
            created_at: Utc::now(),
            last_active: Utc::now(),
            parent: None,
            children: vec![],
            session_id: SessionId::new(),
            tags: vec![],
            identity: Default::default(),
            onboarding_completed: false,
            onboarding_completed_at: None,
        }
    }

    #[test]
    fn test_register_and_get() {
        let registry = AgentRegistry::new();
        let entry = test_entry("test-agent");
        let id = entry.id;
        registry.register(entry).unwrap();
        assert!(registry.get(id).is_some());
    }

    #[test]
    fn test_find_by_name() {
        let registry = AgentRegistry::new();
        let entry = test_entry("my-agent");
        registry.register(entry).unwrap();
        assert!(registry.find_by_name("my-agent").is_some());
    }

    // Two agents may not share a name.
    #[test]
    fn test_duplicate_name() {
        let registry = AgentRegistry::new();
        registry.register(test_entry("dup")).unwrap();
        assert!(registry.register(test_entry("dup")).is_err());
    }

    #[test]
    fn test_remove() {
        let registry = AgentRegistry::new();
        let entry = test_entry("removable");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.remove(id).unwrap();
        assert!(registry.get(id).is_none());
    }

    #[test]
    fn test_set_state() {
        let registry = AgentRegistry::new();
        let entry = test_entry("state-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        registry.set_state(id, AgentState::Running).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.state, AgentState::Running);
    }

    #[test]
    fn test_set_mode() {
        let registry = AgentRegistry::new();
        let entry = test_entry("mode-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        registry.set_mode(id, AgentMode::Autonomous).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.mode, AgentMode::Autonomous);
    }

    #[test]
    fn test_update_identity() {
        let registry = AgentRegistry::new();
        let entry = test_entry("identity-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        let identity = AgentIdentity {
            emoji: Some("🤖".to_string()),
            avatar_url: None,
            color: Some("#ff0000".to_string()),
            archetype: None,
            vibe: None,
            greeting_style: None,
        };
        registry.update_identity(id, identity).unwrap();

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.identity.emoji, Some("🤖".to_string()));
        assert_eq!(updated.identity.color, Some("#ff0000".to_string()));
    }

    // Renaming must update the name index in both directions.
    #[test]
    fn test_update_name() {
        let registry = AgentRegistry::new();
        let entry = test_entry("old-name");
        let id = entry.id;
        registry.register(entry).unwrap();

        registry.update_name(id, "new-name".to_string()).unwrap();

        // Should be findable by new name
        assert!(registry.find_by_name("new-name").is_some());
        // Should not be findable by old name
        assert!(registry.find_by_name("old-name").is_none());

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.name, "new-name");
    }

    #[test]
    fn test_update_description() {
        let registry = AgentRegistry::new();
        let entry = test_entry("desc-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        registry.update_description(id, "New description".to_string()).unwrap();

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.description, "New description");
    }

    #[test]
    fn test_update_model() {
        let registry = AgentRegistry::new();
        let entry = test_entry("model-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        registry.update_model(id, "gpt-4".to_string()).unwrap();

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.model.model, "gpt-4");
    }

    #[test]
    fn test_update_skills() {
        let registry = AgentRegistry::new();
        let entry = test_entry("skills-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        let skills = vec!["web-search".to_string(), "code-execution".to_string()];
        registry.update_skills(id, skills.clone()).unwrap();

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.skills, skills);
    }

    #[test]
    fn test_update_mcp_servers() {
        let registry = AgentRegistry::new();
        let entry = test_entry("mcp-test");
        let id = entry.id;
        registry.register(entry).unwrap();

        let servers = vec!["filesystem".to_string(), "database".to_string()];
        registry.update_mcp_servers(id, servers.clone()).unwrap();

        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.mcp_servers, servers);
    }
}
|
||||
170
crates/openfang-kernel/src/scheduler.rs
Normal file
170
crates/openfang-kernel/src/scheduler.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
//! Agent scheduler — manages agent execution and resource tracking.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::{AgentId, ResourceQuota};
|
||||
use openfang_types::error::{OpenFangError, OpenFangResult};
|
||||
use openfang_types::message::TokenUsage;
|
||||
use std::time::Instant;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::debug;
|
||||
|
||||
/// Tracks resource usage for an agent with a rolling hourly window.
///
/// Counters accumulate until the window expires (see `reset_if_expired`),
/// at which point they are zeroed and the window restarts from "now".
#[derive(Debug)]
pub struct UsageTracker {
    /// Total tokens consumed within the current window.
    pub total_tokens: u64,
    /// Total tool calls made within the current window.
    pub tool_calls: u64,
    /// Start of the current usage window.
    pub window_start: Instant,
}
|
||||
|
||||
impl Default for UsageTracker {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
total_tokens: 0,
|
||||
tool_calls: 0,
|
||||
window_start: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UsageTracker {
|
||||
/// Reset counters if the current window has expired (1 hour).
|
||||
fn reset_if_expired(&mut self) {
|
||||
if self.window_start.elapsed() >= std::time::Duration::from_secs(3600) {
|
||||
self.total_tokens = 0;
|
||||
self.tool_calls = 0;
|
||||
self.window_start = Instant::now();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The agent scheduler manages execution ordering and resource quotas.
///
/// All maps are keyed by agent ID; concurrent maps (`DashMap`) let the
/// scheduler be shared across tasks without an outer lock.
pub struct AgentScheduler {
    /// Resource quotas per agent.
    quotas: DashMap<AgentId, ResourceQuota>,
    /// Rolling-window usage tracking per agent.
    usage: DashMap<AgentId, UsageTracker>,
    /// Active task handles per agent (aborted on `abort_task`/`unregister`).
    tasks: DashMap<AgentId, JoinHandle<()>>,
}
|
||||
|
||||
impl AgentScheduler {
|
||||
/// Create a new scheduler.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
quotas: DashMap::new(),
|
||||
usage: DashMap::new(),
|
||||
tasks: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an agent with its resource quota.
|
||||
pub fn register(&self, agent_id: AgentId, quota: ResourceQuota) {
|
||||
self.quotas.insert(agent_id, quota);
|
||||
self.usage.insert(agent_id, UsageTracker::default());
|
||||
}
|
||||
|
||||
/// Record token usage for an agent.
|
||||
pub fn record_usage(&self, agent_id: AgentId, usage: &TokenUsage) {
|
||||
if let Some(mut tracker) = self.usage.get_mut(&agent_id) {
|
||||
tracker.reset_if_expired();
|
||||
tracker.total_tokens += usage.total();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an agent has exceeded its quota.
|
||||
pub fn check_quota(&self, agent_id: AgentId) -> OpenFangResult<()> {
|
||||
let quota = match self.quotas.get(&agent_id) {
|
||||
Some(q) => q.clone(),
|
||||
None => return Ok(()), // No quota = no limit
|
||||
};
|
||||
let mut tracker = match self.usage.get_mut(&agent_id) {
|
||||
Some(t) => t,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
// Reset the window if an hour has passed
|
||||
tracker.reset_if_expired();
|
||||
|
||||
if quota.max_llm_tokens_per_hour > 0
|
||||
&& tracker.total_tokens > quota.max_llm_tokens_per_hour
|
||||
{
|
||||
return Err(OpenFangError::QuotaExceeded(format!(
|
||||
"Token limit exceeded: {} / {}",
|
||||
tracker.total_tokens, quota.max_llm_tokens_per_hour
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Abort an agent's active task.
|
||||
pub fn abort_task(&self, agent_id: AgentId) {
|
||||
if let Some((_, handle)) = self.tasks.remove(&agent_id) {
|
||||
handle.abort();
|
||||
debug!(agent = %agent_id, "Aborted agent task");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove an agent from the scheduler.
|
||||
pub fn unregister(&self, agent_id: AgentId) {
|
||||
self.abort_task(agent_id);
|
||||
self.quotas.remove(&agent_id);
|
||||
self.usage.remove(&agent_id);
|
||||
}
|
||||
|
||||
/// Get usage stats for an agent.
|
||||
pub fn get_usage(&self, agent_id: AgentId) -> Option<(u64, u64)> {
|
||||
self.usage
|
||||
.get(&agent_id)
|
||||
.map(|t| (t.total_tokens, t.tool_calls))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentScheduler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_usage() {
        let sched = AgentScheduler::new();
        let agent = AgentId::new();
        sched.register(agent, ResourceQuota::default());

        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        sched.record_usage(agent, &usage);

        // input + output should be rolled into one total.
        let (tokens, _calls) = sched.get_usage(agent).unwrap();
        assert_eq!(tokens, 150);
    }

    #[test]
    fn test_quota_check() {
        let sched = AgentScheduler::new();
        let agent = AgentId::new();
        sched.register(
            agent,
            ResourceQuota {
                max_llm_tokens_per_hour: 100,
                ..Default::default()
            },
        );

        // 110 tokens against a 100-token/hour cap must be rejected.
        let usage = TokenUsage {
            input_tokens: 60,
            output_tokens: 50,
        };
        sched.record_usage(agent, &usage);
        assert!(sched.check_quota(agent).is_err());
    }
}
|
||||
227
crates/openfang-kernel/src/supervisor.rs
Normal file
227
crates/openfang-kernel/src/supervisor.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
//! Process supervision — graceful shutdown, signal handling, and health monitoring.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Shutdown signal manager with health monitoring.
///
/// Broadcasts graceful shutdown via a `tokio::sync::watch` channel and keeps
/// process-wide panic/restart counters plus per-agent restart budgets.
pub struct Supervisor {
    /// Send side of the shutdown signal.
    shutdown_tx: watch::Sender<bool>,
    /// Receive side of the shutdown signal (cloned out via `subscribe`; also
    /// keeps the channel alive so `send` cannot observe zero receivers).
    shutdown_rx: watch::Receiver<bool>,
    /// Restart count (how many times agents have been restarted, globally).
    restart_count: AtomicU64,
    /// Total panics caught across all agents.
    panic_count: AtomicU64,
    /// Per-agent restart counts for enforcing `max_restarts`.
    agent_restarts: DashMap<AgentId, u32>,
}
|
||||
|
||||
impl Supervisor {
|
||||
/// Create a new supervisor.
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = watch::channel(false);
|
||||
Self {
|
||||
shutdown_tx: tx,
|
||||
shutdown_rx: rx,
|
||||
restart_count: AtomicU64::new(0),
|
||||
panic_count: AtomicU64::new(0),
|
||||
agent_restarts: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a receiver that will be notified on shutdown.
|
||||
pub fn subscribe(&self) -> watch::Receiver<bool> {
|
||||
self.shutdown_rx.clone()
|
||||
}
|
||||
|
||||
/// Trigger a graceful shutdown.
|
||||
pub fn shutdown(&self) {
|
||||
info!("Supervisor: initiating graceful shutdown");
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
}
|
||||
|
||||
/// Check if shutdown has been requested.
|
||||
pub fn is_shutting_down(&self) -> bool {
|
||||
*self.shutdown_rx.borrow()
|
||||
}
|
||||
|
||||
/// Record that a panic was caught during agent execution.
|
||||
pub fn record_panic(&self) {
|
||||
self.panic_count.fetch_add(1, Ordering::Relaxed);
|
||||
warn!(
|
||||
total_panics = self.panic_count.load(Ordering::Relaxed),
|
||||
"Agent panic recorded"
|
||||
);
|
||||
}
|
||||
|
||||
/// Record that an agent was restarted.
|
||||
pub fn record_restart(&self) {
|
||||
self.restart_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Get the total number of panics caught.
|
||||
pub fn panic_count(&self) -> u64 {
|
||||
self.panic_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get the total number of restarts.
|
||||
pub fn restart_count(&self) -> u64 {
|
||||
self.restart_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Record a restart for a specific agent and check if limit is exceeded.
|
||||
///
|
||||
/// Returns Ok(restart_count) if within limit, or Err(count) if limit exceeded.
|
||||
pub fn record_agent_restart(&self, agent_id: AgentId, max_restarts: u32) -> Result<u32, u32> {
|
||||
let mut count = self.agent_restarts.entry(agent_id).or_insert(0);
|
||||
*count += 1;
|
||||
self.record_restart();
|
||||
|
||||
if max_restarts > 0 && *count > max_restarts {
|
||||
warn!(
|
||||
agent = %agent_id,
|
||||
restarts = *count,
|
||||
max = max_restarts,
|
||||
"Agent exceeded max restart limit"
|
||||
);
|
||||
Err(*count)
|
||||
} else {
|
||||
Ok(*count)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the restart count for a specific agent.
|
||||
pub fn agent_restart_count(&self, agent_id: AgentId) -> u32 {
|
||||
self.agent_restarts.get(&agent_id).map(|r| *r).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Reset restart counter for an agent (e.g., on manual intervention).
|
||||
pub fn reset_agent_restarts(&self, agent_id: AgentId) {
|
||||
self.agent_restarts.remove(&agent_id);
|
||||
}
|
||||
|
||||
/// Get a health summary.
|
||||
pub fn health(&self) -> SupervisorHealth {
|
||||
SupervisorHealth {
|
||||
is_shutting_down: self.is_shutting_down(),
|
||||
panic_count: self.panic_count(),
|
||||
restart_count: self.restart_count(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Supervisor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Health report from the supervisor — a point-in-time snapshot of its
/// counters, produced by `Supervisor::health`.
#[derive(Debug, Clone)]
pub struct SupervisorHealth {
    /// Whether a graceful shutdown has been requested.
    pub is_shutting_down: bool,
    /// Total panics caught across all agents.
    pub panic_count: u64,
    /// Total agent restarts performed.
    pub restart_count: u64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shutdown() {
        let sup = Supervisor::new();
        assert!(!sup.is_shutting_down());
        sup.shutdown();
        assert!(sup.is_shutting_down());
    }

    #[test]
    fn test_subscribe() {
        let sup = Supervisor::new();
        let rx = sup.subscribe();
        assert!(!*rx.borrow());
        sup.shutdown();
        assert!(rx.has_changed().unwrap());
    }

    #[test]
    fn test_panic_tracking() {
        let sup = Supervisor::new();
        assert_eq!(sup.panic_count(), 0);
        for _ in 0..2 {
            sup.record_panic();
        }
        assert_eq!(sup.panic_count(), 2);
    }

    #[test]
    fn test_restart_tracking() {
        let sup = Supervisor::new();
        assert_eq!(sup.restart_count(), 0);
        sup.record_restart();
        assert_eq!(sup.restart_count(), 1);
    }

    #[test]
    fn test_health() {
        let health = Supervisor::new().health();
        assert!(!health.is_shutting_down);
        assert_eq!(health.panic_count, 0);
        assert_eq!(health.restart_count, 0);
    }

    #[test]
    fn test_agent_restart_within_limit() {
        let sup = Supervisor::new();
        let agent = AgentId::new();

        // Three restarts against a budget of three are all allowed.
        assert!(sup.record_agent_restart(agent, 3).is_ok());
        assert_eq!(sup.agent_restart_count(agent), 1);
        assert!(sup.record_agent_restart(agent, 3).is_ok());
        assert!(sup.record_agent_restart(agent, 3).is_ok());
        assert_eq!(sup.agent_restart_count(agent), 3);
    }

    #[test]
    fn test_agent_restart_exceeds_limit() {
        let sup = Supervisor::new();
        let agent = AgentId::new();

        assert!(sup.record_agent_restart(agent, 2).is_ok());
        assert!(sup.record_agent_restart(agent, 2).is_ok());
        // The third restart blows the budget of two.
        assert_eq!(sup.record_agent_restart(agent, 2), Err(3));
    }

    #[test]
    fn test_agent_restart_zero_limit_unlimited() {
        let sup = Supervisor::new();
        let agent = AgentId::new();

        // A budget of zero means "never enforce".
        for _ in 0..100 {
            assert!(sup.record_agent_restart(agent, 0).is_ok());
        }
    }

    #[test]
    fn test_reset_agent_restarts() {
        let sup = Supervisor::new();
        let agent = AgentId::new();

        for _ in 0..2 {
            sup.record_agent_restart(agent, 10).unwrap();
        }
        assert_eq!(sup.agent_restart_count(agent), 2);

        sup.reset_agent_restarts(agent);
        assert_eq!(sup.agent_restart_count(agent), 0);
    }
}
|
||||
511
crates/openfang-kernel/src/triggers.rs
Normal file
511
crates/openfang-kernel/src/triggers.rs
Normal file
@@ -0,0 +1,511 @@
|
||||
//! Event-driven agent triggers — agents auto-activate when events match patterns.
|
||||
//!
|
||||
//! Agents register triggers that describe which events should wake them.
|
||||
//! When a matching event arrives on the EventBus, the trigger system
|
||||
//! sends the event content as a message to the subscribing agent.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::event::{Event, EventPayload, LifecycleEvent, SystemEvent};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for a trigger (newtype over a UUID).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TriggerId(pub Uuid);
|
||||
|
||||
impl TriggerId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TriggerId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TriggerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// What kind of events a trigger matches on.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TriggerPattern {
    /// Match any lifecycle event (agent spawned, started, terminated, etc.).
    Lifecycle,
    /// Match when a specific agent is spawned. `"*"` matches any name;
    /// otherwise the spawned agent's name must contain `name_pattern`.
    AgentSpawned { name_pattern: String },
    /// Match when any agent is terminated (including crashes).
    AgentTerminated,
    /// Match any system event.
    System,
    /// Match a system event whose debug representation contains `keyword`
    /// (case-insensitive).
    SystemKeyword { keyword: String },
    /// Match any memory update event.
    MemoryUpdate,
    /// Match memory updates whose key contains `key_pattern` (`"*"` = any).
    MemoryKeyPattern { key_pattern: String },
    /// Match all events (wildcard).
    All,
    /// Match any event whose description contains `substring`
    /// (case-insensitive).
    ContentMatch { substring: String },
}
|
||||
|
||||
/// A registered trigger definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trigger {
    /// Unique trigger ID.
    pub id: TriggerId,
    /// Which agent owns this trigger (and receives the rendered prompt).
    pub agent_id: AgentId,
    /// The event pattern to match.
    pub pattern: TriggerPattern,
    /// Prompt template to send when triggered. `{{event}}` is replaced with a
    /// human-readable description of the matching event.
    pub prompt_template: String,
    /// Whether this trigger is currently active. Cleared automatically once
    /// `max_fires` is exhausted.
    pub enabled: bool,
    /// When this trigger was created.
    pub created_at: DateTime<Utc>,
    /// How many times this trigger has fired.
    pub fire_count: u64,
    /// Maximum number of times this trigger can fire (0 = unlimited).
    pub max_fires: u64,
}
|
||||
|
||||
/// The trigger engine manages event-to-agent routing.
pub struct TriggerEngine {
    /// All registered triggers, keyed by trigger ID.
    triggers: DashMap<TriggerId, Trigger>,
    /// Index: agent_id → trigger IDs owned by that agent. Kept in sync with
    /// `triggers` by `register` / `remove` / `remove_agent_triggers`.
    agent_triggers: DashMap<AgentId, Vec<TriggerId>>,
}
|
||||
|
||||
impl TriggerEngine {
|
||||
/// Create a new trigger engine.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
triggers: DashMap::new(),
|
||||
agent_triggers: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new trigger.
|
||||
pub fn register(
|
||||
&self,
|
||||
agent_id: AgentId,
|
||||
pattern: TriggerPattern,
|
||||
prompt_template: String,
|
||||
max_fires: u64,
|
||||
) -> TriggerId {
|
||||
let trigger = Trigger {
|
||||
id: TriggerId::new(),
|
||||
agent_id,
|
||||
pattern,
|
||||
prompt_template,
|
||||
enabled: true,
|
||||
created_at: Utc::now(),
|
||||
fire_count: 0,
|
||||
max_fires,
|
||||
};
|
||||
let id = trigger.id;
|
||||
self.triggers.insert(id, trigger);
|
||||
self.agent_triggers.entry(agent_id).or_default().push(id);
|
||||
|
||||
info!(trigger_id = %id, agent_id = %agent_id, "Trigger registered");
|
||||
id
|
||||
}
|
||||
|
||||
/// Remove a trigger.
|
||||
pub fn remove(&self, trigger_id: TriggerId) -> bool {
|
||||
if let Some((_, trigger)) = self.triggers.remove(&trigger_id) {
|
||||
if let Some(mut list) = self.agent_triggers.get_mut(&trigger.agent_id) {
|
||||
list.retain(|id| *id != trigger_id);
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all triggers for an agent.
|
||||
pub fn remove_agent_triggers(&self, agent_id: AgentId) {
|
||||
if let Some((_, trigger_ids)) = self.agent_triggers.remove(&agent_id) {
|
||||
for id in trigger_ids {
|
||||
self.triggers.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable or disable a trigger. Returns true if the trigger was found.
|
||||
pub fn set_enabled(&self, trigger_id: TriggerId, enabled: bool) -> bool {
|
||||
if let Some(mut t) = self.triggers.get_mut(&trigger_id) {
|
||||
t.enabled = enabled;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// List all triggers for an agent.
|
||||
pub fn list_agent_triggers(&self, agent_id: AgentId) -> Vec<Trigger> {
|
||||
self.agent_triggers
|
||||
.get(&agent_id)
|
||||
.map(|ids| {
|
||||
ids.iter()
|
||||
.filter_map(|id| self.triggers.get(id).map(|t| t.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// List all registered triggers.
|
||||
pub fn list_all(&self) -> Vec<Trigger> {
|
||||
self.triggers.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
|
||||
/// Evaluate an event against all triggers. Returns a list of
|
||||
/// (agent_id, message_to_send) pairs for matching triggers.
|
||||
pub fn evaluate(&self, event: &Event) -> Vec<(AgentId, String)> {
|
||||
let event_description = describe_event(event);
|
||||
let mut matches = Vec::new();
|
||||
|
||||
for mut entry in self.triggers.iter_mut() {
|
||||
let trigger = entry.value_mut();
|
||||
|
||||
if !trigger.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check max fires
|
||||
if trigger.max_fires > 0 && trigger.fire_count >= trigger.max_fires {
|
||||
trigger.enabled = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches_pattern(&trigger.pattern, event, &event_description) {
|
||||
let message = trigger
|
||||
.prompt_template
|
||||
.replace("{{event}}", &event_description);
|
||||
matches.push((trigger.agent_id, message));
|
||||
trigger.fire_count += 1;
|
||||
|
||||
debug!(
|
||||
trigger_id = %trigger.id,
|
||||
agent_id = %trigger.agent_id,
|
||||
fire_count = trigger.fire_count,
|
||||
"Trigger fired"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
matches
|
||||
}
|
||||
|
||||
/// Get a trigger by ID.
|
||||
pub fn get(&self, trigger_id: TriggerId) -> Option<Trigger> {
|
||||
self.triggers.get(&trigger_id).map(|t| t.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TriggerEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an event matches a trigger pattern.
|
||||
fn matches_pattern(pattern: &TriggerPattern, event: &Event, description: &str) -> bool {
|
||||
match pattern {
|
||||
TriggerPattern::All => true,
|
||||
TriggerPattern::Lifecycle => {
|
||||
matches!(event.payload, EventPayload::Lifecycle(_))
|
||||
}
|
||||
TriggerPattern::AgentSpawned { name_pattern } => {
|
||||
if let EventPayload::Lifecycle(LifecycleEvent::Spawned { name, .. }) = &event.payload {
|
||||
name.contains(name_pattern.as_str()) || name_pattern == "*"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
TriggerPattern::AgentTerminated => matches!(
|
||||
event.payload,
|
||||
EventPayload::Lifecycle(LifecycleEvent::Terminated { .. })
|
||||
| EventPayload::Lifecycle(LifecycleEvent::Crashed { .. })
|
||||
),
|
||||
TriggerPattern::System => {
|
||||
matches!(event.payload, EventPayload::System(_))
|
||||
}
|
||||
TriggerPattern::SystemKeyword { keyword } => {
|
||||
if let EventPayload::System(se) = &event.payload {
|
||||
let se_str = format!("{:?}", se).to_lowercase();
|
||||
se_str.contains(&keyword.to_lowercase())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
TriggerPattern::MemoryUpdate => {
|
||||
matches!(event.payload, EventPayload::MemoryUpdate(_))
|
||||
}
|
||||
TriggerPattern::MemoryKeyPattern { key_pattern } => {
|
||||
if let EventPayload::MemoryUpdate(delta) = &event.payload {
|
||||
delta.key.contains(key_pattern.as_str()) || key_pattern == "*"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
TriggerPattern::ContentMatch { substring } => description
|
||||
.to_lowercase()
|
||||
.contains(&substring.to_lowercase()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a human-readable description of an event for use in prompts.
|
||||
fn describe_event(event: &Event) -> String {
|
||||
match &event.payload {
|
||||
EventPayload::Message(msg) => {
|
||||
format!("Message from {:?}: {}", msg.role, msg.content)
|
||||
}
|
||||
EventPayload::ToolResult(tr) => {
|
||||
format!(
|
||||
"Tool '{}' {} ({}ms): {}",
|
||||
tr.tool_id,
|
||||
if tr.success { "succeeded" } else { "failed" },
|
||||
tr.execution_time_ms,
|
||||
&tr.content[..tr.content.len().min(200)]
|
||||
)
|
||||
}
|
||||
EventPayload::MemoryUpdate(delta) => {
|
||||
format!(
|
||||
"Memory {:?} on key '{}' for agent {}",
|
||||
delta.operation, delta.key, delta.agent_id
|
||||
)
|
||||
}
|
||||
EventPayload::Lifecycle(le) => match le {
|
||||
LifecycleEvent::Spawned { agent_id, name } => {
|
||||
format!("Agent '{name}' (id: {agent_id}) was spawned")
|
||||
}
|
||||
LifecycleEvent::Started { agent_id } => {
|
||||
format!("Agent {agent_id} started")
|
||||
}
|
||||
LifecycleEvent::Suspended { agent_id } => {
|
||||
format!("Agent {agent_id} suspended")
|
||||
}
|
||||
LifecycleEvent::Resumed { agent_id } => {
|
||||
format!("Agent {agent_id} resumed")
|
||||
}
|
||||
LifecycleEvent::Terminated { agent_id, reason } => {
|
||||
format!("Agent {agent_id} terminated: {reason}")
|
||||
}
|
||||
LifecycleEvent::Crashed { agent_id, error } => {
|
||||
format!("Agent {agent_id} crashed: {error}")
|
||||
}
|
||||
},
|
||||
EventPayload::Network(ne) => {
|
||||
format!("Network event: {:?}", ne)
|
||||
}
|
||||
EventPayload::System(se) => match se {
|
||||
SystemEvent::KernelStarted => "Kernel started".to_string(),
|
||||
SystemEvent::KernelStopping => "Kernel stopping".to_string(),
|
||||
SystemEvent::QuotaWarning {
|
||||
agent_id,
|
||||
resource,
|
||||
usage_percent,
|
||||
} => format!("Quota warning: agent {agent_id}, {resource} at {usage_percent:.1}%"),
|
||||
SystemEvent::HealthCheck { status } => {
|
||||
format!("Health check: {status}")
|
||||
}
|
||||
SystemEvent::QuotaEnforced {
|
||||
agent_id,
|
||||
spent,
|
||||
limit,
|
||||
} => {
|
||||
format!("Quota enforced: agent {agent_id}, spent ${spent:.4} / ${limit:.4}")
|
||||
}
|
||||
SystemEvent::ModelRouted {
|
||||
agent_id,
|
||||
complexity,
|
||||
model,
|
||||
} => {
|
||||
format!("Model routed: agent {agent_id}, complexity={complexity}, model={model}")
|
||||
}
|
||||
SystemEvent::UserAction {
|
||||
user_id,
|
||||
action,
|
||||
result,
|
||||
} => {
|
||||
format!("User action: {user_id} {action} -> {result}")
|
||||
}
|
||||
SystemEvent::HealthCheckFailed {
|
||||
agent_id,
|
||||
unresponsive_secs,
|
||||
} => {
|
||||
format!(
|
||||
"Health check failed: agent {agent_id}, unresponsive for {unresponsive_secs}s"
|
||||
)
|
||||
}
|
||||
},
|
||||
EventPayload::Custom(data) => {
|
||||
format!("Custom event ({} bytes)", data.len())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::event::*;

    /// Build a broadcast `Spawned` lifecycle event for an agent named `name`.
    fn spawned(name: &str) -> Event {
        Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::Lifecycle(LifecycleEvent::Spawned {
                agent_id: AgentId::new(),
                name: name.to_string(),
            }),
        )
    }

    #[test]
    fn test_register_trigger() {
        let engine = TriggerEngine::new();
        let trigger_id = engine.register(
            AgentId::new(),
            TriggerPattern::All,
            "Event occurred: {{event}}".to_string(),
            0,
        );
        assert!(engine.get(trigger_id).is_some());
    }

    #[test]
    fn test_evaluate_lifecycle() {
        let engine = TriggerEngine::new();
        let watcher = AgentId::new();
        engine.register(
            watcher,
            TriggerPattern::Lifecycle,
            "Lifecycle: {{event}}".to_string(),
            0,
        );

        let fired = engine.evaluate(&spawned("new-agent"));
        assert_eq!(fired.len(), 1);
        assert_eq!(fired[0].0, watcher);
        assert!(fired[0].1.contains("new-agent"));
    }

    #[test]
    fn test_evaluate_agent_spawned_pattern() {
        let engine = TriggerEngine::new();
        engine.register(
            AgentId::new(),
            TriggerPattern::AgentSpawned {
                name_pattern: "coder".to_string(),
            },
            "Coder spawned: {{event}}".to_string(),
            0,
        );

        // A name containing the pattern matches...
        assert_eq!(engine.evaluate(&spawned("coder")).len(), 1);
        // ...any other name does not.
        assert_eq!(engine.evaluate(&spawned("researcher")).len(), 0);
    }

    #[test]
    fn test_max_fires() {
        let engine = TriggerEngine::new();
        engine.register(
            AgentId::new(),
            TriggerPattern::All,
            "Event: {{event}}".to_string(),
            2, // fire budget of two
        );

        let event = Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::System(SystemEvent::HealthCheck {
                status: "ok".to_string(),
            }),
        );

        // The first two evaluations fire; the third finds the budget spent.
        assert_eq!(engine.evaluate(&event).len(), 1);
        assert_eq!(engine.evaluate(&event).len(), 1);
        assert_eq!(engine.evaluate(&event).len(), 0);
    }

    #[test]
    fn test_remove_trigger() {
        let engine = TriggerEngine::new();
        let trigger_id =
            engine.register(AgentId::new(), TriggerPattern::All, "msg".to_string(), 0);
        assert!(engine.remove(trigger_id));
        assert!(engine.get(trigger_id).is_none());
    }

    #[test]
    fn test_remove_agent_triggers() {
        let engine = TriggerEngine::new();
        let owner = AgentId::new();
        engine.register(owner, TriggerPattern::All, "a".to_string(), 0);
        engine.register(owner, TriggerPattern::System, "b".to_string(), 0);
        assert_eq!(engine.list_agent_triggers(owner).len(), 2);

        engine.remove_agent_triggers(owner);
        assert!(engine.list_agent_triggers(owner).is_empty());
    }

    #[test]
    fn test_content_match() {
        let engine = TriggerEngine::new();
        engine.register(
            AgentId::new(),
            TriggerPattern::ContentMatch {
                substring: "quota".to_string(),
            },
            "Alert: {{event}}".to_string(),
            0,
        );

        let event = Event::new(
            AgentId::new(),
            EventTarget::System,
            EventPayload::System(SystemEvent::QuotaWarning {
                agent_id: AgentId::new(),
                resource: "tokens".to_string(),
                usage_percent: 85.0,
            }),
        );
        assert_eq!(engine.evaluate(&event).len(), 1);
    }
}
|
||||
345
crates/openfang-kernel/src/whatsapp_gateway.rs
Normal file
345
crates/openfang-kernel/src/whatsapp_gateway.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
//! WhatsApp Web gateway — embedded Node.js process management.
|
||||
//!
|
||||
//! Embeds the gateway JS at compile time, extracts it to `~/.openfang/whatsapp-gateway/`,
|
||||
//! runs `npm install` if needed, and spawns `node index.js` as a managed child process
|
||||
//! that auto-restarts on crash.
|
||||
|
||||
use crate::config::openfang_home;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Gateway source files embedded at compile time.
///
/// `include_str!` bakes the JS sources into the binary so the gateway can be
/// (re-)extracted on any machine without a checkout of this repository.
const GATEWAY_INDEX_JS: &str =
    include_str!("../../../packages/whatsapp-gateway/index.js");
const GATEWAY_PACKAGE_JSON: &str =
    include_str!("../../../packages/whatsapp-gateway/package.json");

/// Default port for the WhatsApp Web gateway.
const DEFAULT_GATEWAY_PORT: u16 = 3009;

/// Maximum restart attempts before giving up.
const MAX_RESTARTS: u32 = 3;

/// Restart backoff delays in seconds: 5s, 10s, 20s (one per attempt).
const RESTART_DELAYS: [u64; 3] = [5, 10, 20];
|
||||
|
||||
/// Get the gateway installation directory.
|
||||
fn gateway_dir() -> PathBuf {
|
||||
openfang_home().join("whatsapp-gateway")
|
||||
}
|
||||
|
||||
/// Compute a 64-bit FNV-1a hash of `content`, rendered as 16 hex digits.
///
/// This is for change detection only — no cryptographic strength is
/// needed or implied.
fn content_hash(content: &str) -> String {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    let hash = content
        .bytes()
        .fold(FNV_OFFSET, |acc, b| (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME));
    format!("{hash:016x}")
}
|
||||
|
||||
/// Write a file only if its content hash differs from the existing file.
|
||||
/// Returns `true` if the file was written (content changed).
|
||||
fn write_if_changed(path: &std::path::Path, content: &str) -> std::io::Result<bool> {
|
||||
let hash_path = path.with_extension("hash");
|
||||
let new_hash = content_hash(content);
|
||||
|
||||
// Check existing hash
|
||||
if let Ok(existing_hash) = std::fs::read_to_string(&hash_path) {
|
||||
if existing_hash.trim() == new_hash {
|
||||
return Ok(false); // No change
|
||||
}
|
||||
}
|
||||
|
||||
std::fs::write(path, content)?;
|
||||
std::fs::write(&hash_path, &new_hash)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Ensure the gateway files are extracted and npm dependencies installed.
|
||||
///
|
||||
/// Returns the gateway directory path on success, or an error message.
|
||||
async fn ensure_gateway_installed() -> Result<PathBuf, String> {
|
||||
let dir = gateway_dir();
|
||||
std::fs::create_dir_all(&dir).map_err(|e| format!("Failed to create gateway dir: {e}"))?;
|
||||
|
||||
let index_path = dir.join("index.js");
|
||||
let package_path = dir.join("package.json");
|
||||
|
||||
// Write files only if content changed (avoids unnecessary npm install)
|
||||
let index_changed =
|
||||
write_if_changed(&index_path, GATEWAY_INDEX_JS).map_err(|e| format!("Write index.js: {e}"))?;
|
||||
let package_changed = write_if_changed(&package_path, GATEWAY_PACKAGE_JSON)
|
||||
.map_err(|e| format!("Write package.json: {e}"))?;
|
||||
|
||||
let node_modules = dir.join("node_modules");
|
||||
let needs_install = !node_modules.exists() || package_changed;
|
||||
|
||||
if needs_install {
|
||||
info!("Installing WhatsApp gateway npm dependencies...");
|
||||
|
||||
// Determine npm command (npm.cmd on Windows, npm elsewhere)
|
||||
let npm_cmd = if cfg!(windows) { "npm.cmd" } else { "npm" };
|
||||
|
||||
let output = tokio::process::Command::new(npm_cmd)
|
||||
.arg("install")
|
||||
.arg("--production")
|
||||
.current_dir(&dir)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("npm install failed to start: {e}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(format!("npm install failed: {stderr}"));
|
||||
}
|
||||
|
||||
info!("WhatsApp gateway npm dependencies installed");
|
||||
} else if index_changed {
|
||||
info!("WhatsApp gateway index.js updated (binary upgrade)");
|
||||
}
|
||||
|
||||
Ok(dir)
|
||||
}
|
||||
|
||||
/// Check if Node.js is available on the system.
|
||||
async fn node_available() -> bool {
|
||||
let node_cmd = if cfg!(windows) { "node.exe" } else { "node" };
|
||||
tokio::process::Command::new(node_cmd)
|
||||
.arg("--version")
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.status()
|
||||
.await
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Start the WhatsApp Web gateway as a managed child process.
///
/// This function:
/// 1. Checks if Node.js is available
/// 2. Extracts and installs the gateway files
/// 3. Spawns `node index.js` with appropriate env vars
/// 4. Sets `WHATSAPP_WEB_GATEWAY_URL` so the daemon finds it
/// 5. Monitors the process and restarts on crash (up to 3 times)
///
/// The PID is stored in the kernel's `whatsapp_gateway_pid` for shutdown cleanup.
///
/// Returns immediately after spawning the monitor task; the supervision loop
/// runs detached and holds only a `Weak` kernel reference so it cannot keep
/// the kernel alive past shutdown.
pub async fn start_whatsapp_gateway(kernel: &Arc<super::kernel::OpenFangKernel>) {
    // Only start if WhatsApp is configured
    let wa_config = match &kernel.config.channels.whatsapp {
        Some(cfg) => cfg.clone(),
        None => return,
    };

    // Check for Node.js
    if !node_available().await {
        warn!(
            "WhatsApp Web gateway requires Node.js >= 18 but `node` was not found. \
            Install Node.js to enable WhatsApp Web integration."
        );
        return;
    }

    // Extract and install
    let gateway_path = match ensure_gateway_installed().await {
        Ok(p) => p,
        Err(e) => {
            warn!("WhatsApp Web gateway setup failed: {e}");
            return;
        }
    };

    let port = DEFAULT_GATEWAY_PORT;
    let api_listen = &kernel.config.api_listen;
    let openfang_url = format!("http://{api_listen}");
    let default_agent = wa_config
        .default_agent
        .as_deref()
        .unwrap_or("assistant")
        .to_string();

    // Auto-set the env var so the rest of the system finds the gateway.
    // NOTE(review): `set_var` mutates process-global state and is not safe if
    // other threads read the environment concurrently — confirm this runs
    // before worker threads that read env vars are started.
    std::env::set_var("WHATSAPP_WEB_GATEWAY_URL", format!("http://127.0.0.1:{port}"));
    info!("WHATSAPP_WEB_GATEWAY_URL set to http://127.0.0.1:{port}");

    // Spawn with crash monitoring. A Weak ref lets the monitor detect kernel
    // teardown; the PID slot is shared so shutdown code can kill the child.
    let kernel_weak = Arc::downgrade(kernel);
    let gateway_pid = Arc::clone(&kernel.whatsapp_gateway_pid);

    tokio::spawn(async move {
        // NOTE(review): `restarts` only ever increases — it is never reset
        // after a long stable run, so 3 crashes spread over weeks still
        // exhaust the budget. Confirm this is intended.
        let mut restarts = 0u32;

        loop {
            let node_cmd = if cfg!(windows) { "node.exe" } else { "node" };

            info!("Starting WhatsApp Web gateway (attempt {})", restarts + 1);

            // Inherit stdio so gateway logs appear alongside daemon logs.
            let child = tokio::process::Command::new(node_cmd)
                .arg("index.js")
                .current_dir(&gateway_path)
                .env("WHATSAPP_GATEWAY_PORT", port.to_string())
                .env("OPENFANG_URL", &openfang_url)
                .env("OPENFANG_DEFAULT_AGENT", &default_agent)
                .stdout(std::process::Stdio::inherit())
                .stderr(std::process::Stdio::inherit())
                .spawn();

            let mut child = match child {
                Ok(c) => c,
                Err(e) => {
                    // Spawn failure is not retried: the binary is missing or
                    // unrunnable, so a retry would fail identically.
                    warn!("Failed to spawn WhatsApp gateway: {e}");
                    return;
                }
            };

            // Store PID for shutdown cleanup (lock poisoning is tolerated by
            // simply skipping the update).
            if let Some(pid) = child.id() {
                if let Ok(mut guard) = gateway_pid.lock() {
                    *guard = Some(pid);
                }
                info!("WhatsApp Web gateway started (PID {pid})");
            }

            // Wait for process exit
            match child.wait().await {
                Ok(status) => {
                    // Clear stored PID
                    if let Ok(mut guard) = gateway_pid.lock() {
                        *guard = None;
                    }

                    // Check if kernel is still alive (not shutting down)
                    let kernel = match kernel_weak.upgrade() {
                        Some(k) => k,
                        None => {
                            info!("WhatsApp gateway exited (kernel dropped)");
                            return;
                        }
                    };

                    if kernel.supervisor.is_shutting_down() {
                        info!("WhatsApp gateway stopped (daemon shutting down)");
                        return;
                    }

                    // A clean exit is treated as deliberate — no restart.
                    if status.success() {
                        info!("WhatsApp gateway exited cleanly");
                        return;
                    }

                    warn!(
                        "WhatsApp gateway crashed (exit: {status}), restart {}/{MAX_RESTARTS}",
                        restarts + 1
                    );
                }
                Err(e) => {
                    // wait() itself failed; clear the PID and fall through to
                    // the restart/backoff logic below.
                    if let Ok(mut guard) = gateway_pid.lock() {
                        *guard = None;
                    }
                    warn!("WhatsApp gateway wait error: {e}");
                }
            }

            restarts += 1;
            if restarts >= MAX_RESTARTS {
                warn!(
                    "WhatsApp gateway exceeded max restarts ({MAX_RESTARTS}), giving up"
                );
                return;
            }

            // Backoff before restart: 5s, 10s, 20s (20s for anything beyond).
            let delay = RESTART_DELAYS
                .get(restarts as usize - 1)
                .copied()
                .unwrap_or(20);
            info!("Restarting WhatsApp gateway in {delay}s...");
            tokio::time::sleep(std::time::Duration::from_secs(delay)).await;
        }
    });
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded gateway sources must be non-empty and contain their
    /// expected identifying markers.
    #[test]
    fn test_embedded_files_not_empty() {
        for (content, marker) in [
            (GATEWAY_INDEX_JS, "WhatsApp"),
            (GATEWAY_PACKAGE_JSON, "@openfang/whatsapp-gateway"),
        ] {
            assert!(!content.is_empty());
            assert!(content.contains(marker));
        }
    }

    /// Hashing the same input twice must give identical digests.
    #[test]
    fn test_content_hash_deterministic() {
        assert_eq!(content_hash("hello world"), content_hash("hello world"));
    }

    /// Different inputs must produce different digests.
    #[test]
    fn test_content_hash_changes_on_different_input() {
        assert_ne!(content_hash("version 1"), content_hash("version 2"));
    }

    /// The gateway directory lives under the `.openfang` home directory.
    #[test]
    fn test_gateway_dir_under_openfang_home() {
        let dir = gateway_dir();
        assert!(dir.ends_with("whatsapp-gateway"));
        let parent = dir.parent().unwrap().to_string_lossy().into_owned();
        assert!(parent.contains(".openfang"));
    }

    /// Full write-if-changed cycle: new file, unchanged content, changed content.
    #[test]
    fn test_write_if_changed_creates_new_file() {
        let tmp = std::env::temp_dir().join("openfang_test_gateway");
        let _ = std::fs::create_dir_all(&tmp);
        let path = tmp.join("test_write.js");
        let hash_path = path.with_extension("hash");

        // Start from a clean slate in case a previous run left files behind.
        for p in [&path, &hash_path] {
            let _ = std::fs::remove_file(p);
        }

        // A brand-new file counts as changed.
        assert!(write_if_changed(&path, "console.log('v1')").unwrap());
        assert!(path.exists());
        assert!(hash_path.exists());

        // Re-writing identical content is a no-op.
        assert!(!write_if_changed(&path, "console.log('v1')").unwrap());

        // New content triggers a write again.
        assert!(write_if_changed(&path, "console.log('v2')").unwrap());

        // Clean up
        for p in [&path, &hash_path] {
            let _ = std::fs::remove_file(p);
        }
        let _ = std::fs::remove_dir(&tmp);
    }

    /// Pin the default port so accidental changes are caught.
    #[test]
    fn test_default_gateway_port() {
        assert_eq!(DEFAULT_GATEWAY_PORT, 3009);
    }

    /// Pin the restart policy constants.
    #[test]
    fn test_restart_backoff_delays() {
        assert_eq!(MAX_RESTARTS, 3);
        assert_eq!(RESTART_DELAYS, [5, 10, 20]);
    }
}
|
||||
435
crates/openfang-kernel/src/wizard.rs
Normal file
435
crates/openfang-kernel/src/wizard.rs
Normal file
@@ -0,0 +1,435 @@
|
||||
//! NL Auto-Bootstrap Wizard — generates agent configs from natural language.
|
||||
//!
|
||||
//! The wizard takes a user's natural language description of what they want
|
||||
//! an agent to do, extracts structured intent, and generates a complete
|
||||
//! agent manifest (TOML config) ready to spawn.
|
||||
|
||||
use openfang_types::agent::{
|
||||
AgentManifest, ManifestCapabilities, ModelConfig, Priority, ResourceQuota, ScheduleMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// The extracted intent from a user's natural language description.
///
/// Typically deserialized from LLM-produced JSON; field names form the JSON
/// schema (see `SetupWizard::parse_intent`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentIntent {
    /// Agent name (slug-style).
    pub name: String,
    /// Short description.
    pub description: String,
    /// What the agent should do (summarized task).
    pub task: String,
    /// What skills/tools it needs.
    pub skills: Vec<String>,
    /// Suggested model tier (simple, medium, complex).
    /// Unrecognized values fall back to the medium default in `build_plan`.
    pub model_tier: String,
    /// Whether it runs on a schedule.
    pub scheduled: bool,
    /// Schedule expression (cron or interval). Only consulted when
    /// `scheduled` is true.
    pub schedule: Option<String>,
    /// Suggested capabilities (e.g. "web", "file", "shell", "memory",
    /// "browser"); unknown strings are passed through as tool names.
    pub capabilities: Vec<String>,
}
|
||||
|
||||
/// A generated setup plan from the wizard.
///
/// Bundles the extracted intent with everything needed to materialize the
/// agent: a complete manifest plus the list of skills to install.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetupPlan {
    /// The extracted intent.
    pub intent: AgentIntent,
    /// Generated agent manifest (ready to write as TOML).
    pub manifest: AgentManifest,
    /// Skills to install (if not already installed).
    pub skills_to_install: Vec<String>,
    /// Human-readable summary of what will be created.
    pub summary: String,
}
|
||||
|
||||
/// The setup wizard builds agent configurations from natural language.
///
/// Stateless: all functionality is exposed as associated functions.
pub struct SetupWizard;
|
||||
|
||||
impl SetupWizard {
    /// Build a setup plan from an extracted intent.
    ///
    /// This maps the intent into a concrete agent manifest with appropriate
    /// model configuration, capabilities, and schedule.
    pub fn build_plan(intent: AgentIntent) -> SetupPlan {
        // Map model tier to provider/model. Anything other than
        // "simple"/"complex" (including "medium") gets the Groq default.
        let (provider, model) = match intent.model_tier.as_str() {
            "simple" => ("groq", "llama-3.3-70b-versatile"),
            "complex" => ("anthropic", "claude-sonnet-4-20250514"),
            _ => ("groq", "llama-3.3-70b-versatile"), // medium default
        };

        // Build capabilities from intent. Known capability words expand into
        // network/shell/memory grants plus tool names; anything unrecognized
        // is passed through verbatim as a tool name (last arm).
        let mut caps = ManifestCapabilities::default();
        for cap in &intent.capabilities {
            match cap.as_str() {
                "web" | "network" => caps.network.push("*".to_string()),
                // NOTE(review): these two arms push without a contains()
                // check, so listing both "file_read" and "file" can produce
                // a duplicate tool entry — confirm duplicates are harmless.
                "file_read" => caps.tools.push("file_read".to_string()),
                "file_write" => caps.tools.push("file_write".to_string()),
                "file" | "files" => {
                    for t in &["file_read", "file_write", "file_list"] {
                        let s = t.to_string();
                        if !caps.tools.contains(&s) {
                            caps.tools.push(s);
                        }
                    }
                }
                // NOTE(review): "shell" grants caps.shell but does not add a
                // "shell_exec" tool, so tool_hints_for's shell hint only
                // fires when "shell_exec" is given explicitly (falls into
                // the catch-all arm below) — confirm this is intended.
                "shell" => caps.shell.push("*".to_string()),
                "memory" => {
                    caps.memory_read.push("*".to_string());
                    caps.memory_write.push("*".to_string());
                    for t in &["memory_store", "memory_recall"] {
                        let s = t.to_string();
                        if !caps.tools.contains(&s) {
                            caps.tools.push(s);
                        }
                    }
                }
                "browser" | "browse" => {
                    caps.network.push("*".to_string());
                    for t in &[
                        "browser_navigate",
                        "browser_click",
                        "browser_type",
                        "browser_read_page",
                        "browser_screenshot",
                        "browser_close",
                    ] {
                        let s = t.to_string();
                        if !caps.tools.contains(&s) {
                            caps.tools.push(s);
                        }
                    }
                }
                other => caps.tools.push(other.to_string()),
            }
        }

        // Add web_search + web_fetch if web/network capability is needed
        if caps.network.contains(&"*".to_string()) {
            for t in &["web_search", "web_fetch"] {
                let s = t.to_string();
                if !caps.tools.contains(&s) {
                    caps.tools.push(s);
                }
            }
        }

        // Build schedule: periodic only when both `scheduled` is set and an
        // expression was provided; otherwise the default mode.
        let schedule = if intent.scheduled {
            if let Some(ref cron) = intent.schedule {
                ScheduleMode::Periodic { cron: cron.clone() }
            } else {
                ScheduleMode::default()
            }
        } else {
            ScheduleMode::default()
        };

        // Build system prompt — rich enough to guide the agent on its task.
        // The prompt_builder will wrap this with tool descriptions, memory protocol,
        // safety guidelines, etc. at execution time.
        let tool_hints = Self::tool_hints_for(&caps.tools);
        let system_prompt = format!(
            "You are {name}, an AI agent running inside the OpenFang Agent OS.\n\
            \n\
            YOUR TASK: {task}\n\
            \n\
            APPROACH:\n\
            - Understand the request fully before acting.\n\
            - Use your tools to accomplish the task rather than just describing what to do.\n\
            - If you need information, search for it. If you need to read a file, read it.\n\
            - Be concise in your responses. Lead with results, not process narration.\n\
            {tool_hints}",
            name = intent.name,
            task = intent.task,
            tool_hints = tool_hints,
        );

        let manifest = AgentManifest {
            name: intent.name.clone(),
            version: "0.1.0".to_string(),
            description: intent.description.clone(),
            author: "wizard".to_string(),
            module: "builtin:chat".to_string(),
            schedule,
            model: ModelConfig {
                provider: provider.to_string(),
                model: model.to_string(),
                max_tokens: 4096,
                temperature: 0.7,
                system_prompt,
                api_key_env: None,
                base_url: None,
            },
            resources: ResourceQuota::default(),
            priority: Priority::default(),
            capabilities: caps,
            tools: HashMap::new(),
            skills: intent.skills.clone(),
            mcp_servers: vec![],
            metadata: HashMap::new(),
            tags: vec![],
            routing: None,
            autonomous: None,
            pinned_model: None,
            workspace: None,
            generate_identity_files: true,
            profile: None,
            fallback_models: vec![],
            exec_policy: None,
        };

        // Empty skill names are dropped; everything else is queued for install.
        let skills_to_install: Vec<String> = intent
            .skills
            .iter()
            .filter(|s| !s.is_empty())
            .cloned()
            .collect();

        let summary = format!(
            "Agent '{}': {}\n  Model: {}/{}\n  Skills: {}\n  Schedule: {}",
            intent.name,
            intent.description,
            provider,
            model,
            if skills_to_install.is_empty() {
                "none".to_string()
            } else {
                skills_to_install.join(", ")
            },
            if intent.scheduled {
                intent.schedule.as_deref().unwrap_or("on-demand")
            } else {
                "on-demand"
            }
        );

        SetupPlan {
            intent,
            manifest,
            skills_to_install,
            summary,
        }
    }

    /// Build a short tool usage hint block for the system prompt based on granted tools.
    ///
    /// Returns an empty string when none of the hinted tools are present, so
    /// the prompt gains no trailing "KEY TOOLS" section.
    fn tool_hints_for(tools: &[String]) -> String {
        let mut hints = Vec::new();
        let has = |name: &str| tools.iter().any(|t| t == name);

        if has("web_search") {
            hints.push("- Use web_search to find current information on any topic.");
        }
        if has("web_fetch") {
            hints.push("- Use web_fetch to read the full content of a specific URL as markdown.");
        }
        if has("browser_navigate") {
            hints.push("- Use browser_navigate/click/type/read_page to interact with websites.");
        }
        if has("file_read") {
            hints.push("- Use file_read to examine files before modifying them.");
        }
        if has("shell_exec") {
            hints.push(
                "- Use shell_exec to run commands. Explain destructive commands before running.",
            );
        }
        if has("memory_store") {
            hints.push(
                "- Use memory_store/memory_recall to persist and retrieve important context.",
            );
        }

        if hints.is_empty() {
            String::new()
        } else {
            format!("\nKEY TOOLS:\n{}", hints.join("\n"))
        }
    }

    /// Generate a TOML string from an agent manifest.
    ///
    /// # Errors
    /// Returns the serializer error if the manifest cannot be represented as TOML.
    pub fn manifest_to_toml(manifest: &AgentManifest) -> Result<String, toml::ser::Error> {
        toml::to_string_pretty(manifest)
    }

    /// Parse an intent from a JSON string (typically LLM output).
    ///
    /// # Errors
    /// Returns the deserializer error when the JSON does not match the
    /// `AgentIntent` schema.
    pub fn parse_intent(json: &str) -> Result<AgentIntent, serde_json::Error> {
        serde_json::from_str(json)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A representative "medium tier, web + memory" intent shared by tests.
    fn sample_intent() -> AgentIntent {
        AgentIntent {
            name: "research-bot".to_string(),
            description: "Researches topics and provides summaries".to_string(),
            task: "Search the web for information and provide concise summaries".to_string(),
            skills: vec!["web-summarizer".to_string()],
            model_tier: "medium".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["web".to_string(), "memory".to_string()],
        }
    }

    /// Medium tier maps to groq and "web" grants wildcard network access.
    #[test]
    fn test_build_plan_basic() {
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);

        assert_eq!(plan.manifest.name, "research-bot");
        assert_eq!(plan.manifest.model.provider, "groq");
        assert!(plan
            .manifest
            .capabilities
            .network
            .contains(&"*".to_string()));
        assert!(plan.summary.contains("research-bot"));
    }

    /// Complex tier selects the Anthropic sonnet model.
    #[test]
    fn test_build_plan_complex_tier() {
        let mut intent = sample_intent();
        intent.model_tier = "complex".to_string();
        let plan = SetupWizard::build_plan(intent);

        assert_eq!(plan.manifest.model.provider, "anthropic");
        assert!(plan.manifest.model.model.contains("sonnet"));
    }

    /// A scheduled intent with a cron string becomes a Periodic schedule.
    #[test]
    fn test_build_plan_scheduled() {
        let mut intent = sample_intent();
        intent.scheduled = true;
        intent.schedule = Some("0 */6 * * *".to_string());
        let plan = SetupWizard::build_plan(intent);

        match &plan.manifest.schedule {
            ScheduleMode::Periodic { cron } => {
                assert_eq!(cron, "0 */6 * * *");
            }
            _ => panic!("Expected periodic schedule mode"),
        }
    }

    /// Well-formed LLM JSON round-trips into an AgentIntent.
    #[test]
    fn test_parse_intent_json() {
        let json = r#"{
            "name": "code-reviewer",
            "description": "Reviews code and suggests improvements",
            "task": "Analyze pull requests and provide feedback",
            "skills": [],
            "model_tier": "complex",
            "scheduled": false,
            "schedule": null,
            "capabilities": ["file_read"]
        }"#;

        let intent = SetupWizard::parse_intent(json).unwrap();
        assert_eq!(intent.name, "code-reviewer");
        assert_eq!(intent.model_tier, "complex");
    }

    /// A built manifest serializes to TOML containing the agent name.
    #[test]
    fn test_manifest_to_toml() {
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);
        let toml = SetupWizard::manifest_to_toml(&plan.manifest);
        assert!(toml.is_ok());
        let toml_str = toml.unwrap();
        assert!(toml_str.contains("research-bot"));
    }

    /// "web" capability auto-adds web_search and web_fetch tools.
    #[test]
    fn test_web_tools_auto_added() {
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["web".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"web_fetch".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"web_search".to_string()));
    }

    /// "memory" capability auto-adds memory_store and memory_recall tools.
    #[test]
    fn test_memory_tools_auto_added() {
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["memory".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"memory_store".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"memory_recall".to_string()));
    }

    /// "browser" capability auto-adds the browser tool family.
    #[test]
    fn test_browser_tools_auto_added() {
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["browser".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_navigate".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_click".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_read_page".to_string()));
    }

    /// The generated system prompt embeds the task section and the task text.
    #[test]
    fn test_wizard_system_prompt_has_task() {
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);
        assert!(plan.manifest.model.system_prompt.contains("YOUR TASK:"));
        assert!(plan.manifest.model.system_prompt.contains("Search the web"));
    }
}
|
||||
1367
crates/openfang-kernel/src/workflow.rs
Normal file
1367
crates/openfang-kernel/src/workflow.rs
Normal file
File diff suppressed because it is too large
Load Diff
163
crates/openfang-kernel/tests/integration_test.rs
Normal file
163
crates/openfang-kernel/tests/integration_test.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
//! Integration test: boot kernel -> spawn agent -> send message via Groq API.
|
||||
//!
|
||||
//! Run with: GROQ_API_KEY=gsk_... cargo test -p openfang-kernel --test integration_test -- --nocapture
|
||||
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::agent::AgentManifest;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
|
||||
fn test_config() -> KernelConfig {
|
||||
let tmp = std::env::temp_dir().join("openfang-integration-test");
|
||||
let _ = std::fs::remove_dir_all(&tmp);
|
||||
std::fs::create_dir_all(&tmp).unwrap();
|
||||
|
||||
KernelConfig {
|
||||
home_dir: tmp.clone(),
|
||||
data_dir: tmp.join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "groq".to_string(),
|
||||
model: "llama-3.3-70b-versatile".to_string(),
|
||||
api_key_env: "GROQ_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// End-to-end smoke test: boot kernel -> spawn one agent -> round-trip a
/// message through the live Groq API -> kill agent and shut down.
///
/// Skips (returns early) when GROQ_API_KEY is not set so unconfigured CI
/// stays green rather than failing on missing credentials.
#[tokio::test]
async fn test_full_pipeline_with_groq() {
    if std::env::var("GROQ_API_KEY").is_err() {
        eprintln!("GROQ_API_KEY not set, skipping integration test");
        return;
    }

    // Boot kernel
    let config = test_config();
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Spawn agent from an inline TOML manifest.
    let manifest: AgentManifest = toml::from_str(
        r#"
name = "test-agent"
version = "0.1.0"
description = "Integration test agent"
author = "test"
module = "builtin:chat"

[model]
provider = "groq"
model = "llama-3.3-70b-versatile"
system_prompt = "You are a test agent. Reply concisely in one sentence."

[capabilities]
tools = ["file_read"]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();

    let agent_id = kernel.spawn_agent(manifest).expect("Agent should spawn");

    // Send message through the full pipeline (live API call).
    let result = kernel
        .send_message(agent_id, "Say hello in exactly 5 words.")
        .await
        .expect("Message should get a response");

    println!("\n=== AGENT RESPONSE ===");
    println!("{}", result.response);
    println!(
        "=== USAGE: {} tokens in, {} tokens out, {} iterations ===",
        result.total_usage.input_tokens, result.total_usage.output_tokens, result.iterations
    );

    // The model output is nondeterministic, so only assert that a non-empty
    // response came back and tokens were actually consumed.
    assert!(!result.response.is_empty(), "Response should not be empty");
    assert!(
        result.total_usage.input_tokens > 0,
        "Should have used tokens"
    );

    // Kill agent and shut down cleanly.
    kernel.kill_agent(agent_id).expect("Agent should be killed");
    kernel.shutdown();
}
|
||||
|
||||
/// Two agents on different Groq models coexist in one kernel and each
/// answers its own message.
///
/// Skips when GROQ_API_KEY is not set.
#[tokio::test]
async fn test_multiple_agents_different_models() {
    if std::env::var("GROQ_API_KEY").is_err() {
        eprintln!("GROQ_API_KEY not set, skipping integration test");
        return;
    }

    let config = test_config();
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Spawn agent 1: llama 70b
    let manifest1: AgentManifest = toml::from_str(
        r#"
name = "agent-llama70b"
version = "0.1.0"
description = "Llama 70B agent"
author = "test"
module = "builtin:chat"

[model]
provider = "groq"
model = "llama-3.3-70b-versatile"
system_prompt = "You are Agent A. Always start your reply with 'A:'."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();

    // Spawn agent 2: llama 8b (faster, smaller)
    let manifest2: AgentManifest = toml::from_str(
        r#"
name = "agent-llama8b"
version = "0.1.0"
description = "Llama 8B agent"
author = "test"
module = "builtin:chat"

[model]
provider = "groq"
model = "llama-3.1-8b-instant"
system_prompt = "You are Agent B. Always start your reply with 'B:'."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();

    let id1 = kernel.spawn_agent(manifest1).expect("Agent 1 should spawn");
    let id2 = kernel.spawn_agent(manifest2).expect("Agent 2 should spawn");

    // Send messages to both (sequentially; each is a live API call).
    let r1 = kernel
        .send_message(id1, "What model are you?")
        .await
        .expect("Agent 1 response");
    let r2 = kernel
        .send_message(id2, "What model are you?")
        .await
        .expect("Agent 2 response");

    println!("\n=== AGENT 1 (llama-70b) ===");
    println!("{}", r1.response);
    println!("\n=== AGENT 2 (llama-8b) ===");
    println!("{}", r2.response);

    // Nondeterministic model output: only require non-empty replies.
    assert!(!r1.response.is_empty());
    assert!(!r2.response.is_empty());

    // Cleanup
    kernel.kill_agent(id1).unwrap();
    kernel.kill_agent(id2).unwrap();
    kernel.shutdown();
}
|
||||
201
crates/openfang-kernel/tests/multi_agent_test.rs
Normal file
201
crates/openfang-kernel/tests/multi_agent_test.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Multi-agent integration test: spawn 6 agents, send messages, verify all respond.
|
||||
//!
|
||||
//! Run with: GROQ_API_KEY=gsk_... cargo test -p openfang-kernel --test multi_agent_test -- --nocapture
|
||||
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::agent::AgentManifest;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
|
||||
/// Build a throwaway kernel config rooted in a fresh temp directory.
///
/// The directory is wiped and recreated on every call so state from a
/// previous run cannot leak into the test.
///
/// NOTE(review): near-duplicate of `test_config` in integration_test.rs —
/// consider extracting a shared test-support helper.
fn test_config() -> KernelConfig {
    let tmp = std::env::temp_dir().join("openfang-multi-agent-test");
    let _ = std::fs::remove_dir_all(&tmp);
    std::fs::create_dir_all(&tmp).unwrap();

    KernelConfig {
        home_dir: tmp.clone(),
        data_dir: tmp.join("data"),
        default_model: DefaultModelConfig {
            provider: "groq".to_string(),
            model: "llama-3.3-70b-versatile".to_string(),
            api_key_env: "GROQ_API_KEY".to_string(),
            base_url: None,
        },
        ..KernelConfig::default()
    }
}
|
||||
|
||||
fn load_manifest(toml_str: &str) -> AgentManifest {
|
||||
toml::from_str(toml_str).expect("Should parse manifest")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_six_agent_fleet() {
|
||||
if std::env::var("GROQ_API_KEY").is_err() {
|
||||
eprintln!("GROQ_API_KEY not set, skipping multi-agent test");
|
||||
return;
|
||||
}
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(test_config()).expect("Kernel should boot");
|
||||
|
||||
// Define all 6 agents with different roles and models
|
||||
let agents = vec![
|
||||
(
|
||||
"coder",
|
||||
r#"
|
||||
name = "coder"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.3-70b-versatile"
|
||||
system_prompt = "You are Coder. Reply with 'CODER:' prefix. Be concise."
|
||||
[capabilities]
|
||||
tools = ["file_read", "file_write"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"Write a one-line Rust function that adds two numbers.",
|
||||
),
|
||||
(
|
||||
"researcher",
|
||||
r#"
|
||||
name = "researcher"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.3-70b-versatile"
|
||||
system_prompt = "You are Researcher. Reply with 'RESEARCHER:' prefix. Be concise."
|
||||
[capabilities]
|
||||
tools = ["web_fetch"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"What is Rust's primary advantage over C++? One sentence.",
|
||||
),
|
||||
(
|
||||
"writer",
|
||||
r#"
|
||||
name = "writer"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.3-70b-versatile"
|
||||
system_prompt = "You are Writer. Reply with 'WRITER:' prefix. Be concise."
|
||||
[capabilities]
|
||||
tools = ["file_read", "file_write"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"Write a one-sentence tagline for an Agent Operating System.",
|
||||
),
|
||||
(
|
||||
"ops",
|
||||
r#"
|
||||
name = "ops"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.1-8b-instant"
|
||||
system_prompt = "You are Ops. Reply with 'OPS:' prefix. Be concise."
|
||||
[capabilities]
|
||||
tools = ["shell_exec"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"What would you check first if a server is running slowly?",
|
||||
),
|
||||
(
|
||||
"analyst",
|
||||
r#"
|
||||
name = "analyst"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.3-70b-versatile"
|
||||
system_prompt = "You are Analyst. Reply with 'ANALYST:' prefix. Be concise."
|
||||
[capabilities]
|
||||
tools = ["file_read"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"What are the top 3 metrics to track for an API service?",
|
||||
),
|
||||
(
|
||||
"hello-world",
|
||||
r#"
|
||||
name = "hello-world"
|
||||
module = "builtin:chat"
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.1-8b-instant"
|
||||
system_prompt = "You are a friendly greeter. Reply with 'HELLO:' prefix. Be concise."
|
||||
[capabilities]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#,
|
||||
"Greet the user in a fun way.",
|
||||
),
|
||||
];
|
||||
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!(" OPENFANG MULTI-AGENT FLEET TEST");
|
||||
println!(" Spawning {} agents...", agents.len());
|
||||
println!("{}\n", "=".repeat(60));
|
||||
|
||||
// Spawn all agents
|
||||
let mut agent_ids = Vec::new();
|
||||
for (name, manifest_str, _) in &agents {
|
||||
let manifest = load_manifest(manifest_str);
|
||||
let id = kernel
|
||||
.spawn_agent(manifest)
|
||||
.unwrap_or_else(|e| panic!("Failed to spawn {name}: {e}"));
|
||||
println!(" Spawned: {name:<12} -> {id}");
|
||||
agent_ids.push(id);
|
||||
}
|
||||
|
||||
assert_eq!(kernel.registry.count(), 6, "Should have 6 agents");
|
||||
println!(
|
||||
"\n All {} agents spawned. Sending messages...\n",
|
||||
agents.len()
|
||||
);
|
||||
|
||||
// Send messages to each agent sequentially (to respect Groq rate limits)
|
||||
let mut results = Vec::new();
|
||||
for (i, (name, _, message)) in agents.iter().enumerate() {
|
||||
let result = kernel
|
||||
.send_message(agent_ids[i], message)
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("Failed to message {name}: {e}"));
|
||||
|
||||
println!("--- {name} ---");
|
||||
println!(" Q: {message}");
|
||||
println!(" A: {}", result.response);
|
||||
println!(
|
||||
" [{} tokens in, {} tokens out, {} iters]",
|
||||
result.total_usage.input_tokens, result.total_usage.output_tokens, result.iterations
|
||||
);
|
||||
println!();
|
||||
|
||||
assert!(
|
||||
!result.response.is_empty(),
|
||||
"{name} response should not be empty"
|
||||
);
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
// Summary
|
||||
let total_input: u64 = results.iter().map(|r| r.total_usage.input_tokens).sum();
|
||||
let total_output: u64 = results.iter().map(|r| r.total_usage.output_tokens).sum();
|
||||
println!("============================================================");
|
||||
println!(" FLEET SUMMARY");
|
||||
println!(" Agents: {}", agents.len());
|
||||
println!(" Total input: {} tokens", total_input);
|
||||
println!(" Total output: {} tokens", total_output);
|
||||
println!(" All responded: YES");
|
||||
println!("============================================================");
|
||||
|
||||
// Cleanup
|
||||
for id in agent_ids {
|
||||
kernel.kill_agent(id).unwrap();
|
||||
}
|
||||
kernel.shutdown();
|
||||
}
|
||||
410
crates/openfang-kernel/tests/wasm_agent_integration_test.rs
Normal file
410
crates/openfang-kernel/tests/wasm_agent_integration_test.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
//! WASM agent integration tests.
|
||||
//!
|
||||
//! Tests the full pipeline: boot kernel → spawn agent with `module = "wasm:..."`
|
||||
//! → send message → verify WASM module executes and returns response.
|
||||
//!
|
||||
//! These tests use real WASM execution — no mocks.
|
||||
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::agent::AgentManifest;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Minimal echo module: returns input JSON wrapped as `{"response": "..."}`.
///
/// Reads the "message" field from input and echoes it back as the response.
/// Since WAT can't do real string manipulation, this module echoes the
/// entire input JSON as-is (which the kernel extracts via serde).
///
/// ABI notes (shared by all test modules in this file):
/// - `alloc(size) -> ptr` is a trivial bump allocator starting at offset 1024;
///   memory is never freed (fine for one-shot test calls).
/// - `execute(ptr, len) -> i64` returns a packed value: result pointer in the
///   high 32 bits, result length in the low 32 bits (see the shl/or below).
const ECHO_WAT: &str = r#"
(module
(memory (export "memory") 1)
(global $bump (mut i32) (i32.const 1024))

(func (export "alloc") (param $size i32) (result i32)
(local $ptr i32)
(local.set $ptr (global.get $bump))
(global.set $bump (i32.add (global.get $bump) (local.get $size)))
(local.get $ptr)
)

(func (export "execute") (param $ptr i32) (param $len i32) (result i64)
;; Echo: return the input as-is (kernel will extract from JSON)
(i64.or
(i64.shl
(i64.extend_i32_u (local.get $ptr))
(i64.const 32)
)
(i64.extend_i32_u (local.get $len))
)
)
)
"#;
|
||||
|
||||
/// Module that always returns a fixed JSON response.
/// Writes `{"response":"hello from wasm"}` at offset 0 and returns it.
///
/// The data segment places the 30-byte JSON payload at memory offset 0, so
/// `execute` can ignore its input and return ptr=0 / len=30 packed into a
/// single i64 (high 32 bits = pointer, low 32 bits = length); with ptr=0 the
/// packed value is simply the length. The bump allocator starts at 4096 to
/// stay clear of the data segment.
const HELLO_WAT: &str = r#"
(module
(memory (export "memory") 1)
(global $bump (mut i32) (i32.const 4096))

;; Fixed response bytes: {"response":"hello from wasm"}
(data (i32.const 0) "{\"response\":\"hello from wasm\"}")

(func (export "alloc") (param $size i32) (result i32)
(local $ptr i32)
(local.set $ptr (global.get $bump))
(global.set $bump (i32.add (global.get $bump) (local.get $size)))
(local.get $ptr)
)

(func (export "execute") (param $ptr i32) (param $len i32) (result i64)
;; Return pointer=0, length=30 (the fixed response)
(i64.const 30) ;; low 32 = len=30, high 32 = ptr=0
)
)
"#;
|
||||
|
||||
/// Module with infinite loop — tests fuel exhaustion enforcement.
///
/// `execute` spins forever (`br $inf` branches back to the loop head), so the
/// runtime's fuel metering must abort the call. The trailing `i64.const 0` is
/// unreachable and exists only to satisfy the function's declared result type.
const INFINITE_LOOP_WAT: &str = r#"
(module
(memory (export "memory") 1)
(global $bump (mut i32) (i32.const 1024))

(func (export "alloc") (param $size i32) (result i32)
(local $ptr i32)
(local.set $ptr (global.get $bump))
(global.set $bump (i32.add (global.get $bump) (local.get $size)))
(local.get $ptr)
)

(func (export "execute") (param $ptr i32) (param $len i32) (result i64)
(loop $inf
(br $inf)
)
(i64.const 0)
)
)
"#;
|
||||
|
||||
/// Host-call proxy: forwards input to host_call and returns the response.
///
/// Imports `openfang.host_call` and forwards `execute`'s (ptr, len) arguments
/// to it unchanged, returning the host's packed i64 (ptr/len) result directly.
/// This exercises the guest → host → guest round trip without any guest-side
/// data manipulation.
const HOST_CALL_PROXY_WAT: &str = r#"
(module
(import "openfang" "host_call" (func $host_call (param i32 i32) (result i64)))
(memory (export "memory") 2)
(global $bump (mut i32) (i32.const 1024))

(func (export "alloc") (param $size i32) (result i32)
(local $ptr i32)
(local.set $ptr (global.get $bump))
(global.set $bump (i32.add (global.get $bump) (local.get $size)))
(local.get $ptr)
)

(func (export "execute") (param $input_ptr i32) (param $input_len i32) (result i64)
(call $host_call (local.get $input_ptr) (local.get $input_len))
)
)
"#;
|
||||
|
||||
fn test_config(tmp: &tempfile::TempDir) -> KernelConfig {
|
||||
KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "test".to_string(),
|
||||
api_key_env: "OLLAMA_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn wasm_manifest(name: &str, module: &str) -> AgentManifest {
|
||||
let toml_str = format!(
|
||||
r#"
|
||||
name = "{name}"
|
||||
version = "0.1.0"
|
||||
description = "WASM test agent"
|
||||
author = "test"
|
||||
module = "wasm:{module}"
|
||||
|
||||
[model]
|
||||
provider = "ollama"
|
||||
model = "test"
|
||||
system_prompt = "WASM agent."
|
||||
|
||||
[capabilities]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#
|
||||
);
|
||||
toml::from_str(&toml_str).unwrap()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Test that a WASM agent can be spawned and returns a response.
#[tokio::test]
async fn test_wasm_agent_hello_response() {
    let tmp = tempfile::tempdir().unwrap();
    // Write the module inside the kernel home dir so `wasm:hello.wat`
    // resolves — presumably relative to home_dir; confirm against the loader.
    std::fs::write(tmp.path().join("hello.wat"), HELLO_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    let manifest = wasm_manifest("wasm-hello", "hello.wat");
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    let result = kernel
        .send_message(agent_id, "Hi there!")
        .await
        .expect("WASM agent should execute");

    // HELLO_WAT ignores input and always returns the fixed payload; a WASM
    // agent has no tool-call loop, so exactly one iteration is expected.
    assert_eq!(result.response, "hello from wasm");
    assert_eq!(result.iterations, 1);

    kernel.shutdown();
}
|
||||
|
||||
/// Test that a WASM echo module returns input data.
#[tokio::test]
async fn test_wasm_agent_echo() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("echo.wat"), ECHO_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    let manifest = wasm_manifest("wasm-echo", "echo.wat");
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    let result = kernel
        .send_message(agent_id, "test message")
        .await
        .expect("Echo agent should execute");

    // Echo returns the entire input JSON, so the response should contain our message
    // (a substring check — the kernel wraps the message in a JSON envelope).
    assert!(
        result.response.contains("test message"),
        "Response should contain the input message, got: {}",
        result.response
    );

    kernel.shutdown();
}
|
||||
|
||||
/// Test that WASM fuel exhaustion is caught and reported as an error.
#[tokio::test]
async fn test_wasm_agent_fuel_exhaustion() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("loop.wat"), INFINITE_LOOP_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    let manifest = wasm_manifest("wasm-loop", "loop.wat");
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    // The module spins forever; fuel metering must turn this into an Err
    // rather than hanging the test.
    let result = kernel.send_message(agent_id, "go").await;
    assert!(
        result.is_err(),
        "Infinite loop should fail with fuel exhaustion"
    );
    // Accept several phrasings so the test doesn't over-fit the exact error
    // message produced by the runtime.
    let err_msg = format!("{}", result.unwrap_err());
    assert!(
        err_msg.contains("Fuel exhausted") || err_msg.contains("fuel") || err_msg.contains("WASM"),
        "Error should mention fuel exhaustion, got: {err_msg}"
    );

    kernel.shutdown();
}
|
||||
|
||||
/// Test that a missing WASM module produces a clear error.
#[tokio::test]
async fn test_wasm_agent_missing_module() {
    let tmp = tempfile::tempdir().unwrap();
    // Don't write any .wat file

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Spawning succeeds — module resolution is deferred until first message.
    let manifest = wasm_manifest("wasm-missing", "nonexistent.wasm");
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    let result = kernel.send_message(agent_id, "hello").await;
    assert!(result.is_err(), "Missing module should fail");
    let err_msg = format!("{}", result.unwrap_err());
    assert!(
        err_msg.contains("Failed to read") || err_msg.contains("nonexistent"),
        "Error should mention the missing file, got: {err_msg}"
    );

    kernel.shutdown();
}
|
||||
|
||||
/// Test that host_call time_now works end-to-end through the kernel.
#[tokio::test]
async fn test_wasm_agent_host_call_time() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("proxy.wat"), HOST_CALL_PROXY_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Proxy module forwards input to host_call — send a time_now request.
    // Manifest is written inline (not via wasm_manifest) to keep this test
    // self-describing.
    let toml_str = r#"
name = "wasm-proxy"
version = "0.1.0"
description = "Host call proxy"
author = "test"
module = "wasm:proxy.wat"

[model]
provider = "ollama"
model = "test"
system_prompt = "Proxy."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#;
    let manifest: AgentManifest = toml::from_str(toml_str).unwrap();
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    // The proxy module expects JSON like {"method":"time_now","params":{}}
    // But our kernel wraps it as {"message":"...", "agent_id":"...", "agent_name":"..."}
    // So the proxy will try to dispatch with method=null which returns "Unknown"
    // This still proves the full pipeline works end-to-end
    let result = kernel
        .send_message(agent_id, r#"{"method":"time_now","params":{}}"#)
        .await
        .expect("Proxy agent should execute");

    // The response will contain the host_call dispatch result — only
    // non-emptiness is asserted, per the caveat above.
    assert!(!result.response.is_empty(), "Response should not be empty");

    kernel.shutdown();
}
|
||||
|
||||
/// Test WASM agent with streaming (falls back to single event).
#[tokio::test]
async fn test_wasm_agent_streaming_fallback() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("hello.wat"), HELLO_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
    // Wrapped in Arc — presumably send_message_streaming spawns a task that
    // needs a shared kernel handle; confirm against its signature.
    let kernel = Arc::new(kernel);

    let manifest = wasm_manifest("wasm-stream", "hello.wat");
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    let (mut rx, handle) = kernel
        .send_message_streaming(agent_id, "Hi!", None)
        .expect("Streaming should start");

    // Collect all stream events (channel closes when the agent finishes).
    let mut events = vec![];
    while let Some(event) = rx.recv().await {
        events.push(event);
    }

    // Should have gotten a TextDelta + ContentComplete
    assert!(
        events.len() >= 2,
        "Expected at least 2 stream events, got {}",
        events.len()
    );

    // The join handle yields the same final result as non-streaming send.
    let final_result = handle.await.unwrap().expect("Task should complete");
    assert_eq!(final_result.response, "hello from wasm");

    kernel.shutdown();
}
|
||||
|
||||
/// Test that spawning multiple WASM agents works concurrently.
#[tokio::test]
async fn test_multiple_wasm_agents() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("hello.wat"), HELLO_WAT).unwrap();
    std::fs::write(tmp.path().join("echo.wat"), ECHO_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Two agents backed by two different modules in the same kernel.
    let hello_id = kernel
        .spawn_agent(wasm_manifest("hello-agent", "hello.wat"))
        .unwrap();
    let echo_id = kernel
        .spawn_agent(wasm_manifest("echo-agent", "echo.wat"))
        .unwrap();

    // Execute both
    let hello_result = kernel.send_message(hello_id, "hi").await.unwrap();
    let echo_result = kernel.send_message(echo_id, "test data").await.unwrap();

    // Each agent must run its own module — no cross-talk between instances.
    assert_eq!(hello_result.response, "hello from wasm");
    assert!(echo_result.response.contains("test data"));

    // Verify agent list shows both
    let agents = kernel.registry.list();
    assert_eq!(agents.len(), 2);

    kernel.shutdown();
}
|
||||
|
||||
/// Test WASM agent alongside LLM agent (mixed fleet).
#[tokio::test]
async fn test_mixed_wasm_and_llm_agents() {
    let tmp = tempfile::tempdir().unwrap();
    std::fs::write(tmp.path().join("hello.wat"), HELLO_WAT).unwrap();

    let config = test_config(&tmp);
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    // Spawn a WASM agent
    let wasm_id = kernel
        .spawn_agent(wasm_manifest("wasm-agent", "hello.wat"))
        .unwrap();

    // Spawn a regular LLM agent (won't actually call LLM since ollama isn't running,
    // but it should spawn fine and coexist)
    let llm_toml = r#"
name = "llm-agent"
version = "0.1.0"
description = "LLM test agent"
author = "test"
module = "builtin:chat"

[model]
provider = "ollama"
model = "test"
system_prompt = "You are a test agent."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#;
    let llm_manifest: AgentManifest = toml::from_str(llm_toml).unwrap();
    let llm_id = kernel.spawn_agent(llm_manifest).unwrap();

    // Verify both agents exist
    let agents = kernel.registry.list();
    assert_eq!(agents.len(), 2);

    // WASM agent should work
    let result = kernel.send_message(wasm_id, "hello").await.unwrap();
    assert_eq!(result.response, "hello from wasm");

    // LLM agent exists but we won't send it a message (no real LLM)
    assert!(kernel.registry.get(llm_id).is_some());

    // Kill WASM agent — the LLM agent must survive, leaving exactly one.
    kernel.kill_agent(wasm_id).unwrap();
    assert_eq!(kernel.registry.list().len(), 1);

    kernel.shutdown();
}
|
||||
404
crates/openfang-kernel/tests/workflow_integration_test.rs
Normal file
404
crates/openfang-kernel/tests/workflow_integration_test.rs
Normal file
@@ -0,0 +1,404 @@
|
||||
//! End-to-end workflow integration tests.
|
||||
//!
|
||||
//! Tests the full pipeline: boot kernel → spawn agents → create workflow →
|
||||
//! execute workflow → verify outputs flow through the pipeline.
|
||||
//!
|
||||
//! LLM tests require GROQ_API_KEY. Non-LLM tests verify the kernel-level
|
||||
//! workflow wiring without making real API calls.
|
||||
|
||||
use openfang_kernel::workflow::{
|
||||
ErrorMode, StepAgent, StepMode, Workflow, WorkflowId, WorkflowStep,
|
||||
};
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::agent::AgentManifest;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn test_config(provider: &str, model: &str, api_key_env: &str) -> KernelConfig {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: provider.to_string(),
|
||||
model: model.to_string(),
|
||||
api_key_env: api_key_env.to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a `builtin:chat` agent backed by Groq with the given name and
/// system prompt, returning its id. Panics if the manifest fails to parse or
/// the spawn fails (acceptable in test helpers).
///
/// NOTE(review): `name` and `system_prompt` are spliced into the TOML string
/// unescaped — quotes or backslashes in them would break parsing. Fine for
/// the fixed inputs used by these tests.
fn spawn_test_agent(
    kernel: &OpenFangKernel,
    name: &str,
    system_prompt: &str,
) -> openfang_types::agent::AgentId {
    let manifest_str = format!(
        r#"
name = "{name}"
version = "0.1.0"
description = "Workflow test agent: {name}"
author = "test"
module = "builtin:chat"

[model]
provider = "groq"
model = "llama-3.3-70b-versatile"
system_prompt = "{system_prompt}"

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#
    );
    let manifest: AgentManifest = toml::from_str(&manifest_str).unwrap();
    kernel.spawn_agent(manifest).expect("Agent should spawn")
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Kernel-level workflow wiring tests (no LLM needed)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Test that workflow registration and agent resolution work at the kernel level.
///
/// No LLM is contacted: the workflow is registered and a run is created, but
/// never executed, so the dummy Ollama provider is never exercised.
#[tokio::test]
async fn test_workflow_register_and_resolve() {
    let config = test_config("ollama", "test-model", "OLLAMA_API_KEY");
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
    let kernel = Arc::new(kernel);

    // Spawn agents
    let manifest: AgentManifest = toml::from_str(
        r#"
name = "agent-alpha"
version = "0.1.0"
description = "Alpha"
author = "test"
module = "builtin:chat"

[model]
provider = "ollama"
model = "test"
system_prompt = "Alpha."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();
    let alpha_id = kernel.spawn_agent(manifest).unwrap();

    let manifest2: AgentManifest = toml::from_str(
        r#"
name = "agent-beta"
version = "0.1.0"
description = "Beta"
author = "test"
module = "builtin:chat"

[model]
provider = "ollama"
model = "test"
system_prompt = "Beta."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();
    let beta_id = kernel.spawn_agent(manifest2).unwrap();

    // Create a 2-step workflow referencing agents by name
    let workflow = Workflow {
        id: WorkflowId::new(),
        name: "alpha-beta-pipeline".to_string(),
        description: "Tests agent resolution by name".to_string(),
        steps: vec![
            WorkflowStep {
                name: "step-alpha".to_string(),
                agent: StepAgent::ByName {
                    name: "agent-alpha".to_string(),
                },
                prompt_template: "Analyze: {{input}}".to_string(),
                mode: StepMode::Sequential,
                timeout_secs: 30,
                error_mode: ErrorMode::Fail,
                // Exposes step-1 output to later templates as {{alpha_out}}.
                output_var: Some("alpha_out".to_string()),
            },
            WorkflowStep {
                name: "step-beta".to_string(),
                agent: StepAgent::ByName {
                    name: "agent-beta".to_string(),
                },
                prompt_template: "Summarize: {{input}} (alpha said: {{alpha_out}})".to_string(),
                mode: StepMode::Sequential,
                timeout_secs: 30,
                error_mode: ErrorMode::Fail,
                output_var: None,
            },
        ],
        created_at: chrono::Utc::now(),
    };

    let wf_id = kernel.register_workflow(workflow).await;

    // Verify workflow is registered
    let workflows = kernel.workflows.list_workflows().await;
    assert_eq!(workflows.len(), 1);
    assert_eq!(workflows[0].name, "alpha-beta-pipeline");

    // Verify agents can be found by name
    let alpha = kernel.registry.find_by_name("agent-alpha");
    assert!(alpha.is_some());
    assert_eq!(alpha.unwrap().id, alpha_id);

    let beta = kernel.registry.find_by_name("agent-beta");
    assert!(beta.is_some());
    assert_eq!(beta.unwrap().id, beta_id);

    // Verify workflow run can be created
    let run_id = kernel
        .workflows
        .create_run(wf_id, "test input".to_string())
        .await;
    assert!(run_id.is_some());

    let run = kernel.workflows.get_run(run_id.unwrap()).await.unwrap();
    assert_eq!(run.input, "test input");

    kernel.shutdown();
}
|
||||
|
||||
/// Test workflow with agent referenced by ID.
#[tokio::test]
async fn test_workflow_agent_by_id() {
    let config = test_config("ollama", "test-model", "OLLAMA_API_KEY");
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    let manifest: AgentManifest = toml::from_str(
        r#"
name = "id-agent"
version = "0.1.0"
description = "Test"
author = "test"
module = "builtin:chat"

[model]
provider = "ollama"
model = "test"
system_prompt = "Test."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    // Single-step workflow addressing the agent by its stringified id rather
    // than by name.
    let workflow = Workflow {
        id: WorkflowId::new(),
        name: "by-id-test".to_string(),
        description: "".to_string(),
        steps: vec![WorkflowStep {
            name: "step1".to_string(),
            agent: StepAgent::ById {
                id: agent_id.to_string(),
            },
            prompt_template: "{{input}}".to_string(),
            mode: StepMode::Sequential,
            timeout_secs: 30,
            error_mode: ErrorMode::Fail,
            output_var: None,
        }],
        created_at: chrono::Utc::now(),
    };

    let wf_id = kernel.register_workflow(workflow).await;

    // Can create run (agent resolution happens at execute time)
    let run_id = kernel
        .workflows
        .create_run(wf_id, "hello".to_string())
        .await;
    assert!(run_id.is_some());

    kernel.shutdown();
}
|
||||
|
||||
/// Test trigger registration and listing at kernel level.
///
/// Covers register → list (all and per-agent) → remove; triggers are never
/// fired, so no LLM is involved.
#[tokio::test]
async fn test_trigger_registration_with_kernel() {
    use openfang_kernel::triggers::TriggerPattern;

    let config = test_config("ollama", "test-model", "OLLAMA_API_KEY");
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");

    let manifest: AgentManifest = toml::from_str(
        r#"
name = "trigger-agent"
version = "0.1.0"
description = "Trigger test"
author = "test"
module = "builtin:chat"

[model]
provider = "ollama"
model = "test"
system_prompt = "Test."

[capabilities]
memory_read = ["*"]
memory_write = ["self.*"]
"#,
    )
    .unwrap();
    let agent_id = kernel.spawn_agent(manifest).unwrap();

    // Register triggers — a lifecycle trigger (priority 0) and a keyword
    // trigger (priority 5) on the same agent.
    let t1 = kernel
        .register_trigger(
            agent_id,
            TriggerPattern::Lifecycle,
            "Lifecycle event: {{event}}".to_string(),
            0,
        )
        .unwrap();

    let t2 = kernel
        .register_trigger(
            agent_id,
            TriggerPattern::SystemKeyword {
                keyword: "deploy".to_string(),
            },
            "Deploy event: {{event}}".to_string(),
            5,
        )
        .unwrap();

    // List all triggers
    let all = kernel.list_triggers(None);
    assert_eq!(all.len(), 2);

    // List triggers for specific agent
    let agent_triggers = kernel.list_triggers(Some(agent_id));
    assert_eq!(agent_triggers.len(), 2);

    // Remove one — the other must survive and keep its id.
    assert!(kernel.remove_trigger(t1));
    let remaining = kernel.list_triggers(None);
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].id, t2);

    kernel.shutdown();
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Full E2E with real LLM (skip if no GROQ_API_KEY)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// End-to-end: boot kernel → spawn 2 agents → create 2-step workflow →
/// run it through the real Groq LLM → verify output flows from step 1 to step 2.
///
/// Skipped (returns early) when GROQ_API_KEY is unset, so CI without
/// credentials still passes.
#[tokio::test]
async fn test_workflow_e2e_with_groq() {
    if std::env::var("GROQ_API_KEY").is_err() {
        eprintln!("GROQ_API_KEY not set, skipping E2E workflow test");
        return;
    }

    let config = test_config("groq", "llama-3.3-70b-versatile", "GROQ_API_KEY");
    let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
    let kernel = Arc::new(kernel);
    // Required before run_workflow — presumably installs a self-referencing
    // Arc used by workflow execution; confirm against set_self_handle docs.
    kernel.set_self_handle();

    // Spawn two agents with distinct roles
    let _analyst_id = spawn_test_agent(
        &kernel,
        "wf-analyst",
        "You are an analyst. When given text, respond with exactly: ANALYSIS: followed by a one-sentence analysis.",
    );
    let _writer_id = spawn_test_agent(
        &kernel,
        "wf-writer",
        "You are a writer. When given text, respond with exactly: SUMMARY: followed by a one-sentence summary.",
    );

    // Create a 2-step pipeline: analyst → writer
    let workflow = Workflow {
        id: WorkflowId::new(),
        name: "analyst-writer-pipeline".to_string(),
        description: "E2E integration test workflow".to_string(),
        steps: vec![
            WorkflowStep {
                name: "analyze".to_string(),
                agent: StepAgent::ByName {
                    name: "wf-analyst".to_string(),
                },
                prompt_template: "Analyze the following: {{input}}".to_string(),
                mode: StepMode::Sequential,
                timeout_secs: 60,
                error_mode: ErrorMode::Fail,
                output_var: None,
            },
            WorkflowStep {
                name: "summarize".to_string(),
                agent: StepAgent::ByName {
                    name: "wf-writer".to_string(),
                },
                prompt_template: "Summarize this analysis: {{input}}".to_string(),
                mode: StepMode::Sequential,
                timeout_secs: 60,
                error_mode: ErrorMode::Fail,
                output_var: None,
            },
        ],
        created_at: chrono::Utc::now(),
    };

    let wf_id = kernel.register_workflow(workflow).await;

    // Run the workflow
    let result = kernel
        .run_workflow(
            wf_id,
            "The Rust programming language is growing rapidly.".to_string(),
        )
        .await;

    assert!(
        result.is_ok(),
        "Workflow should complete: {:?}",
        result.err()
    );
    let (run_id, output) = result.unwrap();

    println!("\n=== WORKFLOW OUTPUT ===");
    println!("{output}");
    println!("======================\n");

    assert!(!output.is_empty(), "Workflow output should not be empty");

    // Verify the workflow run record
    let run = kernel.workflows.get_run(run_id).await.unwrap();
    assert!(matches!(
        run.state,
        openfang_kernel::workflow::WorkflowRunState::Completed
    ));
    assert_eq!(run.step_results.len(), 2);
    assert_eq!(run.step_results[0].step_name, "analyze");
    assert_eq!(run.step_results[1].step_name, "summarize");

    // Both steps should have used tokens
    assert!(run.step_results[0].input_tokens > 0);
    assert!(run.step_results[0].output_tokens > 0);
    assert!(run.step_results[1].input_tokens > 0);
    assert!(run.step_results[1].output_tokens > 0);

    // List runs
    let runs = kernel.workflows.list_runs(None).await;
    assert_eq!(runs.len(), 1);

    kernel.shutdown();
}
|
||||
Reference in New Issue
Block a user