初始化提交
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled

This commit is contained in:
iven
2026-03-01 16:24:24 +08:00
commit 92e5def702
492 changed files with 211343 additions and 0 deletions

View File

@@ -0,0 +1,405 @@
//! Execution approval manager — gates dangerous operations behind human approval.
use chrono::Utc;
use dashmap::DashMap;
use openfang_types::approval::{
ApprovalDecision, ApprovalPolicy, ApprovalRequest, ApprovalResponse, RiskLevel,
};
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Max pending requests per agent; requests beyond this are denied immediately.
const MAX_PENDING_PER_AGENT: usize = 5;
/// Manages approval requests with oneshot channels for blocking resolution.
///
/// Agents block in [`ApprovalManager::request_approval`] until a human (via
/// API/UI) calls [`ApprovalManager::resolve`] or the request times out.
pub struct ApprovalManager {
    /// In-flight requests keyed by request id, each holding the oneshot
    /// sender used to wake the waiting agent.
    pending: DashMap<Uuid, PendingRequest>,
    /// Current policy; hot-swappable via `update_policy`. A std (not tokio)
    /// RwLock — it is never held across an await point.
    policy: std::sync::RwLock<ApprovalPolicy>,
}
/// A request awaiting resolution, paired with the channel that unblocks the
/// waiting agent once a decision is made.
struct PendingRequest {
    /// The original request (kept for listing and per-agent counting).
    request: ApprovalRequest,
    /// One-shot decision channel back to the blocked `request_approval` call.
    sender: tokio::sync::oneshot::Sender<ApprovalDecision>,
}
impl ApprovalManager {
    /// Create a manager with the given initial policy and no pending requests.
    pub fn new(policy: ApprovalPolicy) -> Self {
        Self {
            pending: DashMap::new(),
            policy: std::sync::RwLock::new(policy),
        }
    }
    /// Check if a tool requires approval based on current policy.
    ///
    /// NOTE(review): only the `require_approval` list is consulted here; the
    /// policy's auto-approve flags are not inspected — confirm that callers
    /// apply those separately.
    pub fn requires_approval(&self, tool_name: &str) -> bool {
        // Recover the policy even if a writer panicked (poisoned lock).
        let policy = self.policy.read().unwrap_or_else(|e| e.into_inner());
        policy.require_approval.iter().any(|t| t == tool_name)
    }
    /// Submit an approval request. Returns a future that resolves when approved/denied/timed out.
    ///
    /// Waits (asynchronously) until [`ApprovalManager::resolve`] is called for
    /// this request's id, or until `req.timeout_secs` elapses. A request is
    /// denied up front when the agent already has `MAX_PENDING_PER_AGENT`
    /// requests waiting.
    pub async fn request_approval(&self, req: ApprovalRequest) -> ApprovalDecision {
        // Check per-agent pending limit
        // NOTE(review): this count-then-insert is not atomic; two concurrent
        // requests for the same agent could both pass the check and briefly
        // exceed the limit.
        let agent_pending = self
            .pending
            .iter()
            .filter(|r| r.value().request.agent_id == req.agent_id)
            .count();
        if agent_pending >= MAX_PENDING_PER_AGENT {
            warn!(agent_id = %req.agent_id, "Approval request rejected: too many pending");
            return ApprovalDecision::Denied;
        }
        let timeout = std::time::Duration::from_secs(req.timeout_secs);
        let id = req.id;
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.pending.insert(
            id,
            PendingRequest {
                request: req,
                sender: tx,
            },
        );
        info!(request_id = %id, "Approval request submitted, waiting for resolution");
        match tokio::time::timeout(timeout, rx).await {
            Ok(Ok(decision)) => {
                // `resolve` removed the map entry before sending, so no
                // cleanup is needed on this path.
                debug!(request_id = %id, ?decision, "Approval resolved");
                decision
            }
            _ => {
                // Both the timeout and a dropped sender land here; either way
                // remove our own entry and report a timeout.
                self.pending.remove(&id);
                warn!(request_id = %id, "Approval request timed out");
                ApprovalDecision::TimedOut
            }
        }
    }
    /// Resolve a pending request (called by API/UI).
    ///
    /// Returns the recorded response on success, or an error string when no
    /// request with `request_id` is pending (unknown id, already resolved,
    /// or already timed out).
    pub fn resolve(
        &self,
        request_id: Uuid,
        decision: ApprovalDecision,
        decided_by: Option<String>,
    ) -> Result<ApprovalResponse, String> {
        match self.pending.remove(&request_id) {
            Some((_, pending)) => {
                let response = ApprovalResponse {
                    request_id,
                    decision,
                    decided_at: Utc::now(),
                    decided_by,
                };
                // Send decision to waiting agent (ignore error if receiver dropped)
                let _ = pending.sender.send(decision);
                info!(request_id = %request_id, ?decision, "Approval request resolved");
                Ok(response)
            }
            None => Err(format!("No pending approval request with id {request_id}")),
        }
    }
    /// List all pending requests (for API/dashboard display).
    pub fn list_pending(&self) -> Vec<ApprovalRequest> {
        self.pending
            .iter()
            .map(|r| r.value().request.clone())
            .collect()
    }
    /// Number of pending requests.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }
    /// Update the approval policy (for hot-reload).
    pub fn update_policy(&self, policy: ApprovalPolicy) {
        *self.policy.write().unwrap_or_else(|e| e.into_inner()) = policy;
    }
    /// Get a copy of the current policy.
    pub fn policy(&self) -> ApprovalPolicy {
        self.policy
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .clone()
    }
    /// Classify the risk level of a tool invocation.
    ///
    /// Static name-based mapping; unknown tools default to `RiskLevel::Low`.
    pub fn classify_risk(tool_name: &str) -> RiskLevel {
        match tool_name {
            "shell_exec" => RiskLevel::Critical,
            "file_write" | "file_delete" => RiskLevel::High,
            "web_fetch" | "browser_navigate" => RiskLevel::Medium,
            _ => RiskLevel::Low,
        }
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    //! Unit tests for the approval manager: policy checks, risk
    //! classification, and the full request/resolve/timeout lifecycle.
    use super::*;
    use openfang_types::approval::ApprovalPolicy;
    use std::sync::Arc;
    /// Manager built from the default approval policy.
    fn default_manager() -> ApprovalManager {
        ApprovalManager::new(ApprovalPolicy::default())
    }
    /// Build a High-risk request for `agent_id`/`tool_name` with the given timeout.
    fn make_request(agent_id: &str, tool_name: &str, timeout_secs: u64) -> ApprovalRequest {
        ApprovalRequest {
            id: Uuid::new_v4(),
            agent_id: agent_id.to_string(),
            tool_name: tool_name.to_string(),
            description: "test operation".to_string(),
            action_summary: "test action".to_string(),
            risk_level: RiskLevel::High,
            requested_at: Utc::now(),
            timeout_secs,
        }
    }
    // -----------------------------------------------------------------------
    // requires_approval
    // -----------------------------------------------------------------------
    #[test]
    fn test_requires_approval_default() {
        let mgr = default_manager();
        assert!(mgr.requires_approval("shell_exec"));
        assert!(!mgr.requires_approval("file_read"));
    }
    #[test]
    fn test_requires_approval_custom_policy() {
        let policy = ApprovalPolicy {
            require_approval: vec!["file_write".to_string(), "file_delete".to_string()],
            timeout_secs: 30,
            auto_approve_autonomous: false,
            auto_approve: false,
        };
        let mgr = ApprovalManager::new(policy);
        assert!(mgr.requires_approval("file_write"));
        assert!(mgr.requires_approval("file_delete"));
        assert!(!mgr.requires_approval("shell_exec"));
        assert!(!mgr.requires_approval("file_read"));
    }
    // -----------------------------------------------------------------------
    // classify_risk
    // -----------------------------------------------------------------------
    #[test]
    fn test_classify_risk() {
        assert_eq!(
            ApprovalManager::classify_risk("shell_exec"),
            RiskLevel::Critical
        );
        assert_eq!(
            ApprovalManager::classify_risk("file_write"),
            RiskLevel::High
        );
        assert_eq!(
            ApprovalManager::classify_risk("file_delete"),
            RiskLevel::High
        );
        assert_eq!(
            ApprovalManager::classify_risk("web_fetch"),
            RiskLevel::Medium
        );
        assert_eq!(
            ApprovalManager::classify_risk("browser_navigate"),
            RiskLevel::Medium
        );
        assert_eq!(ApprovalManager::classify_risk("file_read"), RiskLevel::Low);
        assert_eq!(
            ApprovalManager::classify_risk("unknown_tool"),
            RiskLevel::Low
        );
    }
    // -----------------------------------------------------------------------
    // resolve nonexistent
    // -----------------------------------------------------------------------
    #[test]
    fn test_resolve_nonexistent() {
        let mgr = default_manager();
        let result = mgr.resolve(Uuid::new_v4(), ApprovalDecision::Approved, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No pending approval request"));
    }
    // -----------------------------------------------------------------------
    // list_pending empty
    // -----------------------------------------------------------------------
    #[test]
    fn test_list_pending_empty() {
        let mgr = default_manager();
        assert!(mgr.list_pending().is_empty());
    }
    // -----------------------------------------------------------------------
    // update_policy
    // -----------------------------------------------------------------------
    #[test]
    fn test_update_policy() {
        let mgr = default_manager();
        assert!(mgr.requires_approval("shell_exec"));
        assert!(!mgr.requires_approval("file_write"));
        let new_policy = ApprovalPolicy {
            require_approval: vec!["file_write".to_string()],
            timeout_secs: 120,
            auto_approve_autonomous: true,
            auto_approve: false,
        };
        mgr.update_policy(new_policy);
        // Hot-reload fully replaces the old policy.
        assert!(!mgr.requires_approval("shell_exec"));
        assert!(mgr.requires_approval("file_write"));
        let policy = mgr.policy();
        assert_eq!(policy.timeout_secs, 120);
        assert!(policy.auto_approve_autonomous);
    }
    // -----------------------------------------------------------------------
    // pending_count
    // -----------------------------------------------------------------------
    #[test]
    fn test_pending_count() {
        let mgr = default_manager();
        assert_eq!(mgr.pending_count(), 0);
    }
    // -----------------------------------------------------------------------
    // request_approval — timeout
    // -----------------------------------------------------------------------
    #[tokio::test]
    async fn test_request_approval_timeout() {
        // NOTE(review): this waits the full 10s of real time — consider
        // `#[tokio::test(start_paused = true)]` to keep the suite fast.
        let mgr = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 10);
        let decision = mgr.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::TimedOut);
        // After timeout, pending map should be cleaned up
        assert_eq!(mgr.pending_count(), 0);
    }
    // -----------------------------------------------------------------------
    // request_approval — approve
    // -----------------------------------------------------------------------
    #[tokio::test]
    async fn test_request_approval_approve() {
        let mgr = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 60);
        let request_id = req.id;
        let mgr2 = Arc::clone(&mgr);
        // Resolver runs concurrently with the blocking request below.
        tokio::spawn(async move {
            // Small delay to let the request register
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let result = mgr2.resolve(
                request_id,
                ApprovalDecision::Approved,
                Some("admin".to_string()),
            );
            assert!(result.is_ok());
            let resp = result.unwrap();
            assert_eq!(resp.decision, ApprovalDecision::Approved);
            assert_eq!(resp.decided_by, Some("admin".to_string()));
        });
        let decision = mgr.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::Approved);
    }
    // -----------------------------------------------------------------------
    // request_approval — deny
    // -----------------------------------------------------------------------
    #[tokio::test]
    async fn test_request_approval_deny() {
        let mgr = Arc::new(default_manager());
        let req = make_request("agent-1", "shell_exec", 60);
        let request_id = req.id;
        let mgr2 = Arc::clone(&mgr);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let result = mgr2.resolve(request_id, ApprovalDecision::Denied, None);
            assert!(result.is_ok());
        });
        let decision = mgr.request_approval(req).await;
        assert_eq!(decision, ApprovalDecision::Denied);
    }
    // -----------------------------------------------------------------------
    // max pending per agent
    // -----------------------------------------------------------------------
    #[tokio::test]
    async fn test_max_pending_per_agent() {
        let mgr = Arc::new(default_manager());
        // Fill up 5 pending requests for agent-1 (they will all be waiting)
        let mut ids = Vec::new();
        for _ in 0..MAX_PENDING_PER_AGENT {
            let req = make_request("agent-1", "shell_exec", 300);
            ids.push(req.id);
            let mgr_clone = Arc::clone(&mgr);
            tokio::spawn(async move {
                mgr_clone.request_approval(req).await;
            });
        }
        // Give spawned tasks time to register
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(mgr.pending_count(), MAX_PENDING_PER_AGENT);
        // 6th request for the same agent should be immediately denied
        let req6 = make_request("agent-1", "shell_exec", 300);
        let decision = mgr.request_approval(req6).await;
        assert_eq!(decision, ApprovalDecision::Denied);
        // A different agent should still be able to submit
        let req_other = make_request("agent-2", "shell_exec", 300);
        let other_id = req_other.id;
        let mgr2 = Arc::clone(&mgr);
        tokio::spawn(async move {
            mgr2.request_approval(req_other).await;
        });
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        assert_eq!(mgr.pending_count(), MAX_PENDING_PER_AGENT + 1);
        // Cleanup: resolve all pending to avoid hanging tasks
        for id in &ids {
            let _ = mgr.resolve(*id, ApprovalDecision::Denied, None);
        }
        let _ = mgr.resolve(other_id, ApprovalDecision::Denied, None);
    }
    // -----------------------------------------------------------------------
    // policy defaults
    // -----------------------------------------------------------------------
    #[test]
    fn test_policy_defaults() {
        let mgr = default_manager();
        let policy = mgr.policy();
        assert_eq!(policy.require_approval, vec!["shell_exec".to_string()]);
        assert_eq!(policy.timeout_secs, 60);
        assert!(!policy.auto_approve_autonomous);
    }
}

View File

@@ -0,0 +1,316 @@
//! RBAC authentication and authorization for multi-user access control.
//!
//! The AuthManager maps platform user identities (Telegram ID, Discord ID, etc.)
//! to OpenFang users with roles, then enforces permission checks on actions.
use dashmap::DashMap;
use openfang_types::agent::UserId;
use openfang_types::config::UserConfig;
use openfang_types::error::{OpenFangError, OpenFangResult};
use std::fmt;
use tracing::info;
/// User roles with hierarchical permissions.
///
/// The explicit discriminants make the derived `Ord` meaningful:
/// `Viewer < User < Admin < Owner`, so authorization can use
/// `role >= required_role`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum UserRole {
    /// Read-only access — can view agent output but cannot interact.
    Viewer = 0,
    /// Standard user — can chat with agents.
    User = 1,
    /// Admin — can spawn/kill agents, install skills, view usage.
    Admin = 2,
    /// Owner — full access including user management and config changes.
    Owner = 3,
}
impl fmt::Display for UserRole {
    /// Render the role as its lowercase config-file spelling.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            UserRole::Viewer => "viewer",
            UserRole::User => "user",
            UserRole::Admin => "admin",
            UserRole::Owner => "owner",
        };
        write!(f, "{label}")
    }
}
impl UserRole {
    /// Parse a role from a string.
    ///
    /// Matching is case-insensitive; anything unrecognized falls back to the
    /// least-privileged interactive role, `User`.
    pub fn from_str_role(s: &str) -> Self {
        let normalized = s.to_lowercase();
        if normalized == "owner" {
            UserRole::Owner
        } else if normalized == "admin" {
            UserRole::Admin
        } else if normalized == "viewer" {
            UserRole::Viewer
        } else {
            UserRole::User
        }
    }
}
/// Actions that can be authorized.
///
/// Each variant maps to a minimum [`UserRole`] via `Action::required_role`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Action {
    /// Chat with an agent.
    ChatWithAgent,
    /// Spawn a new agent.
    SpawnAgent,
    /// Kill a running agent.
    KillAgent,
    /// Install a skill.
    InstallSkill,
    /// View kernel configuration.
    ViewConfig,
    /// Modify kernel configuration.
    ModifyConfig,
    /// View usage/billing data.
    ViewUsage,
    /// Manage users (create, delete, change roles).
    ManageUsers,
}
impl Action {
    /// Minimum role required for this action.
    ///
    /// Arms are grouped by privilege tier; the match is exhaustive, so adding
    /// a new `Action` variant forces a decision here.
    fn required_role(&self) -> UserRole {
        match self {
            Action::ChatWithAgent | Action::ViewConfig => UserRole::User,
            Action::ViewUsage | Action::SpawnAgent | Action::KillAgent | Action::InstallSkill => {
                UserRole::Admin
            }
            Action::ModifyConfig | Action::ManageUsers => UserRole::Owner,
        }
    }
}
/// A resolved user identity.
///
/// Cloned out of the `AuthManager` map on lookup, so it can be held without
/// borrowing the manager.
#[derive(Debug, Clone)]
pub struct UserIdentity {
    /// OpenFang user ID.
    pub id: UserId,
    /// Display name.
    pub name: String,
    /// Role.
    pub role: UserRole,
}
/// RBAC authentication and authorization manager.
///
/// Both maps are populated once in `AuthManager::new` and read concurrently
/// afterwards.
pub struct AuthManager {
    /// Known users by their OpenFang user ID.
    users: DashMap<UserId, UserIdentity>,
    /// Channel binding index: "channel_type:platform_id" → UserId.
    channel_index: DashMap<String, UserId>,
}
impl AuthManager {
    /// Create a new AuthManager from kernel user configuration.
    ///
    /// Every configured user gets a freshly generated `UserId`, and each of
    /// their channel bindings is indexed for constant-time identification.
    /// NOTE(review): ids are regenerated on every construction, so they are
    /// not stable across restarts — confirm nothing persists them.
    pub fn new(user_configs: &[UserConfig]) -> Self {
        let manager = Self {
            users: DashMap::new(),
            channel_index: DashMap::new(),
        };
        for config in user_configs {
            let id = UserId::new();
            let role = UserRole::from_str_role(&config.role);
            manager.users.insert(
                id,
                UserIdentity {
                    id,
                    name: config.name.clone(),
                    role,
                },
            );
            // Index channel bindings for identify()
            for (channel_type, platform_id) in &config.channel_bindings {
                manager
                    .channel_index
                    .insert(format!("{channel_type}:{platform_id}"), id);
            }
            info!(
                user = %config.name,
                role = %role,
                bindings = config.channel_bindings.len(),
                "Registered user"
            );
        }
        manager
    }
    /// Identify a user from a channel identity.
    ///
    /// Returns the OpenFang UserId if a matching channel binding exists,
    /// or None for unrecognized users.
    pub fn identify(&self, channel_type: &str, platform_id: &str) -> Option<UserId> {
        self.channel_index
            .get(&format!("{channel_type}:{platform_id}"))
            .map(|entry| *entry.value())
    }
    /// Get a user's identity by their UserId.
    pub fn get_user(&self, user_id: UserId) -> Option<UserIdentity> {
        self.users.get(&user_id).map(|entry| entry.value().clone())
    }
    /// Authorize a user for an action.
    ///
    /// Returns Ok(()) if the user has sufficient permissions, or AuthDenied error.
    pub fn authorize(&self, user_id: UserId, action: &Action) -> OpenFangResult<()> {
        let identity = match self.users.get(&user_id) {
            Some(entry) => entry,
            None => return Err(OpenFangError::AuthDenied("Unknown user".to_string())),
        };
        let required = action.required_role();
        // Role ordering is hierarchical: any role >= the minimum passes.
        if identity.role < required {
            return Err(OpenFangError::AuthDenied(format!(
                "User '{}' (role: {}) lacks permission for {:?} (requires: {})",
                identity.name, identity.role, action, required
            )));
        }
        Ok(())
    }
    /// Check if RBAC is configured (any users registered).
    pub fn is_enabled(&self) -> bool {
        !self.users.is_empty()
    }
    /// Get the count of registered users.
    pub fn user_count(&self) -> usize {
        self.users.len()
    }
    /// List all registered users.
    pub fn list_users(&self) -> Vec<UserIdentity> {
        self.users
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    /// Fixture: an owner with two channel bindings, a regular user with one,
    /// and a viewer with none.
    fn test_configs() -> Vec<UserConfig> {
        let alice = UserConfig {
            name: "Alice".to_string(),
            role: "owner".to_string(),
            channel_bindings: HashMap::from([
                ("telegram".to_string(), "123456".to_string()),
                ("discord".to_string(), "987654".to_string()),
            ]),
            api_key_hash: None,
        };
        let guest = UserConfig {
            name: "Guest".to_string(),
            role: "user".to_string(),
            channel_bindings: HashMap::from([("telegram".to_string(), "999999".to_string())]),
            api_key_hash: None,
        };
        let readonly = UserConfig {
            name: "ReadOnly".to_string(),
            role: "viewer".to_string(),
            channel_bindings: HashMap::new(),
            api_key_hash: None,
        };
        vec![alice, guest, readonly]
    }
    #[test]
    fn test_user_registration() {
        let mgr = AuthManager::new(&test_configs());
        assert!(mgr.is_enabled());
        assert_eq!(mgr.user_count(), 3);
    }
    #[test]
    fn test_identify_from_channel() {
        let mgr = AuthManager::new(&test_configs());
        let via_telegram = mgr.identify("telegram", "123456");
        let via_discord = mgr.identify("discord", "987654");
        // Both of Alice's bindings resolve, and to the same underlying user.
        assert!(via_telegram.is_some());
        assert!(via_discord.is_some());
        assert_eq!(via_telegram.unwrap(), via_discord.unwrap());
        // Unbound platform ids resolve to nobody.
        assert!(mgr.identify("telegram", "unknown").is_none());
    }
    #[test]
    fn test_owner_can_do_everything() {
        let mgr = AuthManager::new(&test_configs());
        let owner = mgr.identify("telegram", "123456").unwrap();
        for action in [
            Action::ChatWithAgent,
            Action::SpawnAgent,
            Action::KillAgent,
            Action::ManageUsers,
            Action::ModifyConfig,
        ] {
            assert!(mgr.authorize(owner, &action).is_ok());
        }
    }
    #[test]
    fn test_user_limited_access() {
        let mgr = AuthManager::new(&test_configs());
        let guest = mgr.identify("telegram", "999999").unwrap();
        // A standard user may chat and view config...
        assert!(mgr.authorize(guest, &Action::ChatWithAgent).is_ok());
        assert!(mgr.authorize(guest, &Action::ViewConfig).is_ok());
        // ...but not manage agents or users.
        assert!(mgr.authorize(guest, &Action::SpawnAgent).is_err());
        assert!(mgr.authorize(guest, &Action::KillAgent).is_err());
        assert!(mgr.authorize(guest, &Action::ManageUsers).is_err());
    }
    #[test]
    fn test_viewer_read_only() {
        let mgr = AuthManager::new(&test_configs());
        let viewer_id = mgr
            .list_users()
            .into_iter()
            .find(|u| u.name == "ReadOnly")
            .unwrap()
            .id;
        // Viewers cannot even chat.
        assert!(mgr.authorize(viewer_id, &Action::ChatWithAgent).is_err());
    }
    #[test]
    fn test_unknown_user_denied() {
        let mgr = AuthManager::new(&test_configs());
        assert!(mgr
            .authorize(UserId::new(), &Action::ChatWithAgent)
            .is_err());
    }
    #[test]
    fn test_no_users_means_disabled() {
        let mgr = AuthManager::new(&[]);
        assert!(!mgr.is_enabled());
        assert_eq!(mgr.user_count(), 0);
    }
    #[test]
    fn test_role_parsing() {
        let cases = [
            ("owner", UserRole::Owner),
            ("admin", UserRole::Admin),
            ("viewer", UserRole::Viewer),
            ("user", UserRole::User),
            ("OWNER", UserRole::Owner),
            ("unknown", UserRole::User),
        ];
        for (input, expected) in cases {
            assert_eq!(UserRole::from_str_role(input), expected);
        }
    }
}

View File

@@ -0,0 +1,211 @@
//! Auto-reply background engine — trigger-driven background replies with concurrency control.
use openfang_types::agent::AgentId;
use openfang_types::config::AutoReplyConfig;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tracing::{debug, info, warn};
/// Where to deliver the auto-reply result.
///
/// Passed by value into the background reply task and handed to `send_fn`
/// together with the agent's response.
#[derive(Debug, Clone)]
pub struct AutoReplyChannel {
    /// Channel type string (e.g., "telegram", "discord").
    pub channel_type: String,
    /// Peer/user ID to send the reply to.
    pub peer_id: String,
    /// Optional thread ID for threaded replies.
    pub thread_id: Option<String>,
}
/// Auto-reply engine with concurrency limits and suppression patterns.
pub struct AutoReplyEngine {
    /// Engine configuration (enabled flag, limits, suppression patterns).
    config: AutoReplyConfig,
    /// Bounds concurrent background replies; sized from
    /// `config.max_concurrent` (minimum 1 permit).
    semaphore: Arc<Semaphore>,
}
impl AutoReplyEngine {
    /// Create a new auto-reply engine from configuration.
    pub fn new(config: AutoReplyConfig) -> Self {
        // Guard against a misconfigured `max_concurrent` of 0, which would
        // make every reply fail to acquire a permit.
        let permits = config.max_concurrent.max(1);
        Self {
            semaphore: Arc::new(Semaphore::new(permits)),
            config,
        }
    }
    /// Check if a message should trigger auto-reply.
    /// Returns `None` if suppressed or disabled, `Some(agent_id)` if should auto-reply.
    ///
    /// Suppression is a case-insensitive substring match against each
    /// configured pattern. `_channel_type` is currently unused.
    pub fn should_reply(
        &self,
        message: &str,
        _channel_type: &str,
        agent_id: AgentId,
    ) -> Option<AgentId> {
        if !self.config.enabled {
            return None;
        }
        // Check suppression patterns
        let lower = message.to_lowercase();
        for pattern in &self.config.suppress_patterns {
            if lower.contains(&pattern.to_lowercase()) {
                debug!(pattern = %pattern, "Auto-reply suppressed by pattern");
                return None;
            }
        }
        Some(agent_id)
    }
    /// Execute an auto-reply in the background.
    /// Returns a JoinHandle for the spawned task.
    ///
    /// The `send_fn` is called with the agent response to deliver it back to the channel.
    ///
    /// Fails fast with `Err` when all permits are in use (no queueing). The
    /// acquired permit is moved into the spawned task and released when the
    /// task finishes — including on agent error or timeout.
    pub async fn execute_reply<F>(
        &self,
        kernel_handle: Arc<dyn openfang_runtime::kernel_handle::KernelHandle>,
        agent_id: AgentId,
        message: String,
        reply_channel: AutoReplyChannel,
        send_fn: F,
    ) -> Result<tokio::task::JoinHandle<()>, String>
    where
        F: Fn(String, AutoReplyChannel) -> futures::future::BoxFuture<'static, ()>
            + Send
            + Sync
            + 'static,
    {
        // Try to acquire a semaphore permit
        let permit = match self.semaphore.clone().try_acquire_owned() {
            Ok(p) => p,
            Err(_) => {
                return Err(format!(
                    "Auto-reply concurrency limit reached ({} max)",
                    self.config.max_concurrent
                ));
            }
        };
        let timeout_secs = self.config.timeout_secs;
        let handle = tokio::spawn(async move {
            let _permit = permit; // Hold permit until task completes
            info!(
                agent = %agent_id,
                channel = %reply_channel.channel_type,
                peer = %reply_channel.peer_id,
                "Starting auto-reply"
            );
            // Bound the agent round-trip so a hung agent cannot pin a permit
            // forever.
            let result = tokio::time::timeout(
                std::time::Duration::from_secs(timeout_secs),
                kernel_handle.send_to_agent(&agent_id.to_string(), &message),
            )
            .await;
            match result {
                Ok(Ok(response)) => {
                    // Success: deliver the response back to the channel.
                    send_fn(response, reply_channel).await;
                }
                Ok(Err(e)) => {
                    warn!(agent = %agent_id, error = %e, "Auto-reply agent error");
                }
                Err(_) => {
                    warn!(agent = %agent_id, timeout = timeout_secs, "Auto-reply timed out");
                }
            }
        });
        Ok(handle)
    }
    /// Check if auto-reply is enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
    /// Get the current configuration (read-only).
    pub fn config(&self) -> &AutoReplyConfig {
        &self.config
    }
    /// Get available permits (for monitoring).
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Engine config with two suppression patterns and up to 3 concurrent tasks.
    fn test_config(enabled: bool) -> AutoReplyConfig {
        AutoReplyConfig {
            enabled,
            max_concurrent: 3,
            timeout_secs: 120,
            suppress_patterns: ["/stop", "/pause"].iter().map(|p| p.to_string()).collect(),
        }
    }
    #[test]
    fn test_disabled_engine() {
        let engine = AutoReplyEngine::new(test_config(false));
        let id = AgentId::new();
        // A disabled engine never proposes a reply.
        assert!(engine.should_reply("hello", "telegram", id).is_none());
    }
    #[test]
    fn test_enabled_engine_allows() {
        let engine = AutoReplyEngine::new(test_config(true));
        let id = AgentId::new();
        assert_eq!(engine.should_reply("hello there", "telegram", id), Some(id));
    }
    #[test]
    fn test_suppression_patterns() {
        let engine = AutoReplyEngine::new(test_config(true));
        let id = AgentId::new();
        // Exact and embedded pattern matches are both suppressed.
        assert!(engine.should_reply("/stop", "telegram", id).is_none());
        assert!(engine
            .should_reply("please /pause this", "telegram", id)
            .is_none());
        // Ordinary messages pass through.
        assert!(engine.should_reply("hello", "telegram", id).is_some());
    }
    #[test]
    fn test_concurrency_limit() {
        let engine = AutoReplyEngine::new(AutoReplyConfig {
            enabled: true,
            max_concurrent: 2,
            timeout_secs: 120,
            suppress_patterns: Vec::new(),
        });
        assert_eq!(engine.available_permits(), 2);
    }
    #[test]
    fn test_is_enabled() {
        assert!(AutoReplyEngine::new(test_config(true)).is_enabled());
        assert!(!AutoReplyEngine::new(test_config(false)).is_enabled());
    }
    #[test]
    fn test_default_config() {
        let config = AutoReplyConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_concurrent, 3);
        assert_eq!(config.timeout_secs, 120);
        assert!(config.suppress_patterns.contains(&"/stop".to_string()));
    }
}

View File

@@ -0,0 +1,457 @@
//! Background agent executor — runs agents autonomously on schedules, timers, and conditions.
//!
//! Supports three autonomous modes:
//! - **Continuous**: Agent self-prompts on a fixed interval.
//! - **Periodic**: Agent wakes on a simplified cron schedule (e.g. "every 5m").
//! - **Proactive**: Agent wakes when matching events fire (via the trigger engine).
use crate::triggers::TriggerPattern;
use dashmap::DashMap;
use openfang_types::agent::{AgentId, ScheduleMode};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
/// Maximum number of concurrent background LLM calls across all agents
/// (enforced via the executor's global semaphore).
const MAX_CONCURRENT_BG_LLM: usize = 5;
/// Manages background task loops for autonomous agents.
pub struct BackgroundExecutor {
    /// Running background task handles, keyed by agent ID.
    tasks: DashMap<AgentId, JoinHandle<()>>,
    /// Shutdown signal receiver (from Supervisor); cloned into each loop.
    shutdown_rx: watch::Receiver<bool>,
    /// SECURITY: Global semaphore to limit concurrent background LLM calls.
    llm_semaphore: Arc<tokio::sync::Semaphore>,
}
impl BackgroundExecutor {
    /// Create a new executor bound to the supervisor's shutdown signal.
    pub fn new(shutdown_rx: watch::Receiver<bool>) -> Self {
        Self {
            tasks: DashMap::new(),
            shutdown_rx,
            llm_semaphore: Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_BG_LLM)),
        }
    }
    /// Start a background loop for an agent based on its schedule mode.
    ///
    /// For `Continuous` and `Periodic` modes, spawns a tokio task that
    /// periodically sends a self-prompt message to the agent.
    /// For `Proactive` mode, registers triggers — no dedicated task needed.
    ///
    /// `send_message` is a closure that sends a message to the given agent
    /// and returns a result. It captures an `Arc<OpenFangKernel>` from the caller.
    pub fn start_agent<F>(
        &self,
        agent_id: AgentId,
        agent_name: &str,
        schedule: &ScheduleMode,
        send_message: F,
    ) where
        F: Fn(AgentId, String) -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
    {
        match schedule {
            ScheduleMode::Reactive => {} // nothing to do
            ScheduleMode::Continuous {
                check_interval_secs,
            } => {
                let interval = std::time::Duration::from_secs(*check_interval_secs);
                let name = agent_name.to_string();
                let mut shutdown = self.shutdown_rx.clone();
                // `busy` ensures at most one in-flight tick per agent.
                let busy = Arc::new(AtomicBool::new(false));
                let semaphore = self.llm_semaphore.clone();
                info!(
                    agent = %name, id = %agent_id,
                    interval_secs = check_interval_secs,
                    "Starting continuous background loop"
                );
                let handle = tokio::spawn(async move {
                    loop {
                        // Sleep until the next tick, or exit on shutdown.
                        tokio::select! {
                            _ = tokio::time::sleep(interval) => {}
                            _ = shutdown.changed() => {
                                info!(agent = %name, "Continuous loop: shutdown signal received");
                                break;
                            }
                        }
                        // Skip if previous tick is still running
                        if busy
                            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                            .is_err()
                        {
                            debug!(agent = %name, "Continuous loop: skipping tick (busy)");
                            continue;
                        }
                        // SECURITY: Acquire global LLM concurrency permit
                        let permit = match semaphore.clone().acquire_owned().await {
                            Ok(p) => p,
                            Err(_) => {
                                busy.store(false, Ordering::SeqCst);
                                break; // Semaphore closed
                            }
                        };
                        let prompt = format!(
                            "[AUTONOMOUS TICK] You are running in continuous mode. \
                             Check your goals, review shared memory for pending tasks, \
                             and take any necessary actions. Agent: {name}"
                        );
                        debug!(agent = %name, "Continuous loop: sending self-prompt");
                        let busy_clone = busy.clone();
                        let jh = (send_message)(agent_id, prompt);
                        // Spawn a watcher that clears the busy flag and drops permit when done
                        tokio::spawn(async move {
                            let _ = jh.await;
                            drop(permit);
                            busy_clone.store(false, Ordering::SeqCst);
                        });
                    }
                });
                self.tasks.insert(agent_id, handle);
            }
            ScheduleMode::Periodic { cron } => {
                // NOTE(review): this is a fixed sleep interval, not wall-clock
                // cron — ticks drift by the time spent waiting on the permit.
                let interval_secs = parse_cron_to_secs(cron);
                let interval = std::time::Duration::from_secs(interval_secs);
                let name = agent_name.to_string();
                let cron_owned = cron.clone();
                let mut shutdown = self.shutdown_rx.clone();
                let busy = Arc::new(AtomicBool::new(false));
                let semaphore = self.llm_semaphore.clone();
                info!(
                    agent = %name, id = %agent_id,
                    cron = %cron, interval_secs = interval_secs,
                    "Starting periodic background loop"
                );
                let handle = tokio::spawn(async move {
                    loop {
                        tokio::select! {
                            _ = tokio::time::sleep(interval) => {}
                            _ = shutdown.changed() => {
                                info!(agent = %name, "Periodic loop: shutdown signal received");
                                break;
                            }
                        }
                        // Skip if previous tick is still running.
                        if busy
                            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                            .is_err()
                        {
                            debug!(agent = %name, "Periodic loop: skipping tick (busy)");
                            continue;
                        }
                        // SECURITY: Acquire global LLM concurrency permit
                        let permit = match semaphore.clone().acquire_owned().await {
                            Ok(p) => p,
                            Err(_) => {
                                busy.store(false, Ordering::SeqCst);
                                break; // Semaphore closed
                            }
                        };
                        let prompt = format!(
                            "[SCHEDULED TICK] You are running on a periodic schedule ({cron_owned}). \
                             Perform your routine duties. Agent: {name}"
                        );
                        debug!(agent = %name, "Periodic loop: sending scheduled prompt");
                        let busy_clone = busy.clone();
                        let jh = (send_message)(agent_id, prompt);
                        // Watcher clears the busy flag and releases the permit.
                        tokio::spawn(async move {
                            let _ = jh.await;
                            drop(permit);
                            busy_clone.store(false, Ordering::SeqCst);
                        });
                    }
                });
                self.tasks.insert(agent_id, handle);
            }
            ScheduleMode::Proactive { .. } => {
                // Proactive agents rely on triggers, not a dedicated loop.
                // Triggers are registered by the kernel during spawn_agent / start_background_agents.
                debug!(agent = %agent_name, "Proactive agent — triggers handle activation");
            }
        }
    }
    /// Stop the background loop for an agent, if one is running.
    ///
    /// Aborts the loop task; a tick already dispatched via `send_message`
    /// runs as a separate task and is not aborted here.
    pub fn stop_agent(&self, agent_id: AgentId) {
        if let Some((_, handle)) = self.tasks.remove(&agent_id) {
            handle.abort();
            info!(id = %agent_id, "Background loop stopped");
        }
    }
    /// Number of actively running background loops.
    pub fn active_count(&self) -> usize {
        self.tasks.len()
    }
}
/// Parse a proactive condition string into a `TriggerPattern`.
///
/// Supported formats:
/// - `"event:agent_spawned"` → `TriggerPattern::AgentSpawned { name_pattern: "*" }`
/// - `"event:agent_terminated"` → `TriggerPattern::AgentTerminated`
/// - `"event:lifecycle"` → `TriggerPattern::Lifecycle`
/// - `"event:system"` → `TriggerPattern::System`
/// - `"event:memory_update"` → `TriggerPattern::MemoryUpdate`
/// - `"memory:some_key"` → `TriggerPattern::MemoryKeyPattern { key_pattern: "some_key" }`
/// - `"all"` → `TriggerPattern::All`
///
/// Returns `None` (after logging a warning) for anything unrecognized.
pub fn parse_condition(condition: &str) -> Option<TriggerPattern> {
    let condition = condition.trim();
    if condition.eq_ignore_ascii_case("all") {
        return Some(TriggerPattern::All);
    }
    if let Some(event_kind) = condition.strip_prefix("event:") {
        // Event kinds are matched case-insensitively.
        let kind = event_kind.trim().to_lowercase();
        return match kind.as_str() {
            "agent_spawned" => Some(TriggerPattern::AgentSpawned {
                name_pattern: "*".to_string(),
            }),
            "agent_terminated" => Some(TriggerPattern::AgentTerminated),
            "lifecycle" => Some(TriggerPattern::Lifecycle),
            "system" => Some(TriggerPattern::System),
            "memory_update" => Some(TriggerPattern::MemoryUpdate),
            other => {
                warn!(condition = %condition, "Unknown event condition: {other}");
                None
            }
        };
    }
    if let Some(key) = condition.strip_prefix("memory:") {
        return Some(TriggerPattern::MemoryKeyPattern {
            key_pattern: key.trim().to_string(),
        });
    }
    warn!(condition = %condition, "Unrecognized proactive condition format");
    None
}
/// Parse a simplified cron expression into seconds.
///
/// Supported formats:
/// - `"every 30s"` → 30
/// - `"every 5m"` → 300
/// - `"every 1h"` → 3600
/// - `"every 2d"` → 172800
///
/// Falls back to 300 seconds (5 minutes) for unparseable expressions.
pub fn parse_cron_to_secs(cron: &str) -> u64 {
    // Suffix → seconds-multiplier table, checked in order.
    const UNITS: [(char, u64); 4] = [('s', 1), ('m', 60), ('h', 3600), ('d', 86400)];
    let cron = cron.trim().to_lowercase();
    // Only the "every <N><unit>" form is understood.
    if let Some(spec) = cron.strip_prefix("every ") {
        let spec = spec.trim();
        for (suffix, multiplier) in UNITS {
            if let Some(digits) = spec.strip_suffix(suffix) {
                if let Ok(count) = digits.trim().parse::<u64>() {
                    return count * multiplier;
                }
            }
        }
    }
    warn!(cron = %cron, "Unparseable cron expression, defaulting to 300s");
    300
}
// Unit tests for the condition/cron parsers and the background executor loops.
//
// NOTE(review): the executor tests rely on real wall-clock sleeps with
// 1-second tick intervals, so they are sensitive to scheduler jitter on a
// heavily loaded CI machine — confirm margins are generous enough.
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_cron_to_secs ---

    #[test]
    fn test_parse_cron_seconds() {
        assert_eq!(parse_cron_to_secs("every 30s"), 30);
        assert_eq!(parse_cron_to_secs("every 1s"), 1);
    }
    #[test]
    fn test_parse_cron_minutes() {
        assert_eq!(parse_cron_to_secs("every 5m"), 300);
        assert_eq!(parse_cron_to_secs("every 1m"), 60);
    }
    #[test]
    fn test_parse_cron_hours() {
        assert_eq!(parse_cron_to_secs("every 1h"), 3600);
        assert_eq!(parse_cron_to_secs("every 2h"), 7200);
    }
    #[test]
    fn test_parse_cron_days() {
        assert_eq!(parse_cron_to_secs("every 1d"), 86400);
    }
    #[test]
    fn test_parse_cron_fallback() {
        // Unparseable → 300. Real 5-field cron syntax is intentionally
        // unsupported and takes the default too.
        assert_eq!(parse_cron_to_secs("*/5 * * * *"), 300);
        assert_eq!(parse_cron_to_secs("gibberish"), 300);
    }

    // --- parse_condition ---

    #[test]
    fn test_parse_condition_events() {
        assert!(matches!(
            parse_condition("event:agent_spawned"),
            Some(TriggerPattern::AgentSpawned { .. })
        ));
        assert!(matches!(
            parse_condition("event:agent_terminated"),
            Some(TriggerPattern::AgentTerminated)
        ));
        assert!(matches!(
            parse_condition("event:lifecycle"),
            Some(TriggerPattern::Lifecycle)
        ));
        assert!(matches!(
            parse_condition("event:system"),
            Some(TriggerPattern::System)
        ));
        assert!(matches!(
            parse_condition("event:memory_update"),
            Some(TriggerPattern::MemoryUpdate)
        ));
    }
    #[test]
    fn test_parse_condition_memory() {
        // The key pattern is passed through verbatim, wildcards included.
        match parse_condition("memory:agent.*.status") {
            Some(TriggerPattern::MemoryKeyPattern { key_pattern }) => {
                assert_eq!(key_pattern, "agent.*.status");
            }
            other => panic!("Expected MemoryKeyPattern, got {:?}", other),
        }
    }
    #[test]
    fn test_parse_condition_all() {
        assert!(matches!(parse_condition("all"), Some(TriggerPattern::All)));
    }
    #[test]
    fn test_parse_condition_unknown() {
        assert!(parse_condition("event:unknown_thing").is_none());
        assert!(parse_condition("badprefix:foo").is_none());
    }

    // --- BackgroundExecutor behavior ---

    // A continuous loop must stop ticking once the shutdown watch flips.
    #[tokio::test]
    async fn test_continuous_shutdown() {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(shutdown_rx);
        let agent_id = AgentId::new();
        let tick_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let tick_clone = tick_count.clone();
        let schedule = ScheduleMode::Continuous {
            check_interval_secs: 1, // 1 second for fast test
        };
        executor.start_agent(agent_id, "test-agent", &schedule, move |_id, _msg| {
            let tc = tick_clone.clone();
            tokio::spawn(async move {
                tc.fetch_add(1, Ordering::SeqCst);
            })
        });
        assert_eq!(executor.active_count(), 1);
        // Wait for at least 1 tick
        tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
        assert!(tick_count.load(Ordering::SeqCst) >= 1);
        // Shutdown
        let _ = shutdown_tx.send(true);
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        // The loop should have exited (handle finished)
        // Active count still shows the entry until stop_agent is called
        executor.stop_agent(agent_id);
        assert_eq!(executor.active_count(), 0);
    }

    // While a tick's message handler is still running, subsequent ticks must
    // be skipped rather than queued (skip-if-busy semantics).
    #[tokio::test]
    async fn test_skip_if_busy() {
        let (_shutdown_tx, shutdown_rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(shutdown_rx);
        let agent_id = AgentId::new();
        let tick_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let tick_clone = tick_count.clone();
        let schedule = ScheduleMode::Continuous {
            check_interval_secs: 1,
        };
        // Each tick takes 3 seconds — should cause subsequent ticks to be skipped
        executor.start_agent(agent_id, "slow-agent", &schedule, move |_id, _msg| {
            let tc = tick_clone.clone();
            tokio::spawn(async move {
                tc.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_secs(3)).await;
            })
        });
        // Wait 2.5 seconds: 1 tick should fire at t=1s, second at t=2s should be skipped (busy)
        tokio::time::sleep(std::time::Duration::from_millis(2500)).await;
        let ticks = tick_count.load(Ordering::SeqCst);
        // Should be exactly 1 because the first tick is still "busy" when the second arrives
        assert_eq!(ticks, 1, "Expected 1 tick (skip-if-busy), got {ticks}");
        executor.stop_agent(agent_id);
    }

    // Reactive and Proactive schedules must not spawn dedicated loop tasks.
    #[test]
    fn test_executor_active_count() {
        let (_tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::new(rx);
        assert_eq!(executor.active_count(), 0);
        // Reactive mode → no background task
        let id = AgentId::new();
        executor.start_agent(id, "reactive", &ScheduleMode::Reactive, |_id, _msg| {
            tokio::spawn(async {})
        });
        assert_eq!(executor.active_count(), 0);
        // Proactive mode → no dedicated task
        let id2 = AgentId::new();
        executor.start_agent(
            id2,
            "proactive",
            &ScheduleMode::Proactive {
                conditions: vec!["event:agent_spawned".to_string()],
            },
            |_id, _msg| tokio::spawn(async {}),
        );
        assert_eq!(executor.active_count(), 0);
    }
}

View File

@@ -0,0 +1,95 @@
//! Capability manager — enforces capability-based security.
use dashmap::DashMap;
use openfang_types::agent::AgentId;
use openfang_types::capability::{capability_matches, Capability, CapabilityCheck};
use tracing::debug;
/// Manages capability grants for all agents.
///
/// Backed by a concurrent map (`DashMap`), so `grant`/`check`/`list`/
/// `revoke_all` can be called from multiple threads through a shared `&self`.
pub struct CapabilityManager {
    /// Granted capabilities per agent. The whole `Vec` for an agent is
    /// replaced on each `grant` call (it uses `insert`), not appended to.
    grants: DashMap<AgentId, Vec<Capability>>,
}
impl CapabilityManager {
    /// Create a new capability manager with no grants.
    pub fn new() -> Self {
        Self {
            grants: DashMap::new(),
        }
    }
    /// Grant capabilities to an agent.
    ///
    /// Note: this stores the given set via `insert`, replacing any set
    /// previously registered for the agent.
    pub fn grant(&self, agent_id: AgentId, capabilities: Vec<Capability>) {
        self.grants.insert(agent_id, capabilities);
    }
    /// Check whether an agent has a specific capability.
    ///
    /// Returns `Granted` when any granted capability matches `required`
    /// (per `capability_matches`); otherwise a `Denied` with a reason.
    pub fn check(&self, agent_id: AgentId, required: &Capability) -> CapabilityCheck {
        let Some(grants) = self.grants.get(&agent_id) else {
            return CapabilityCheck::Denied(format!(
                "No capabilities registered for agent {agent_id}"
            ));
        };
        let matched = grants
            .value()
            .iter()
            .any(|granted| capability_matches(granted, required));
        if matched {
            debug!(agent = %agent_id, ?required, "Capability granted");
            CapabilityCheck::Granted
        } else {
            CapabilityCheck::Denied(format!(
                "Agent {agent_id} does not have capability: {required:?}"
            ))
        }
    }
    /// List all capabilities for an agent (empty when none are registered).
    pub fn list(&self, agent_id: AgentId) -> Vec<Capability> {
        match self.grants.get(&agent_id) {
            Some(entry) => entry.value().clone(),
            None => Vec::new(),
        }
    }
    /// Remove all capabilities for an agent.
    pub fn revoke_all(&self, agent_id: AgentId) {
        self.grants.remove(&agent_id);
    }
}
/// `Default` mirrors [`CapabilityManager::new`]: an empty grant table.
impl Default for CapabilityManager {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for capability grant/check semantics.
#[cfg(test)]
mod tests {
    use super::*;

    // A granted tool capability passes `check`; an ungranted one is denied.
    #[test]
    fn test_grant_and_check() {
        let mgr = CapabilityManager::new();
        let id = AgentId::new();
        mgr.grant(id, vec![Capability::ToolInvoke("file_read".to_string())]);
        assert!(mgr
            .check(id, &Capability::ToolInvoke("file_read".to_string()))
            .is_granted());
        assert!(!mgr
            .check(id, &Capability::ToolInvoke("shell_exec".to_string()))
            .is_granted());
    }

    // An agent with no registered grants is denied everything.
    #[test]
    fn test_no_grants() {
        let mgr = CapabilityManager::new();
        let id = AgentId::new();
        assert!(!mgr
            .check(id, &Capability::ToolInvoke("anything".to_string()))
            .is_granted());
    }
}

View File

@@ -0,0 +1,434 @@
//! Configuration loading from `~/.openfang/config.toml` with defaults.
//!
//! Supports config includes: the `include` field specifies additional TOML files
//! to load and deep-merge before the root config (root overrides includes).
use openfang_types::config::KernelConfig;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use tracing::info;
/// Maximum include nesting depth.
const MAX_INCLUDE_DEPTH: u32 = 10;
/// Load kernel configuration from a TOML file, with defaults.
///
/// If the config contains an `include` field, included files are loaded
/// and deep-merged first, then the root config overrides them.
pub fn load_config(path: Option<&Path>) -> KernelConfig {
let config_path = path
.map(|p| p.to_path_buf())
.unwrap_or_else(default_config_path);
if config_path.exists() {
match std::fs::read_to_string(&config_path) {
Ok(contents) => match toml::from_str::<toml::Value>(&contents) {
Ok(mut root_value) => {
// Process includes before deserializing
let config_dir = config_path
.parent()
.unwrap_or_else(|| Path::new("."))
.to_path_buf();
let mut visited = HashSet::new();
if let Ok(canonical) = std::fs::canonicalize(&config_path) {
visited.insert(canonical);
} else {
visited.insert(config_path.clone());
}
if let Err(e) =
resolve_config_includes(&mut root_value, &config_dir, &mut visited, 0)
{
tracing::warn!(
error = %e,
"Config include resolution failed, using root config only"
);
}
// Remove the `include` field before deserializing to avoid confusion
if let toml::Value::Table(ref mut tbl) = root_value {
tbl.remove("include");
}
match root_value.try_into::<KernelConfig>() {
Ok(config) => {
info!(path = %config_path.display(), "Loaded configuration");
return config;
}
Err(e) => {
tracing::warn!(
error = %e,
path = %config_path.display(),
"Failed to deserialize merged config, using defaults"
);
}
}
}
Err(e) => {
tracing::warn!(
error = %e,
path = %config_path.display(),
"Failed to parse config, using defaults"
);
}
},
Err(e) => {
tracing::warn!(
error = %e,
path = %config_path.display(),
"Failed to read config file, using defaults"
);
}
}
} else {
info!(
path = %config_path.display(),
"Config file not found, using defaults"
);
}
KernelConfig::default()
}
/// Resolve config includes by deep-merging included files into the root value.
///
/// Included files are loaded first and the root config overrides them.
/// Security: rejects absolute paths, `..` components, and circular references.
///
/// Notes:
/// - `visited` is shared across the entire include tree and entries are never
///   removed, so each file may be included at most once globally ("include
///   once"). A diamond-shaped include graph (two siblings including the same
///   file) is therefore rejected with the "Circular config include" error
///   even though it is not a true cycle.
/// - On error the caller's `root_value` is left untouched: the merge into it
///   only happens after every include in the list resolved successfully.
///
/// # Errors
/// Returns a human-readable message when the nesting depth exceeds
/// [`MAX_INCLUDE_DEPTH`], a path is absolute, contains `..`, escapes the
/// config directory, was already included, or cannot be read or parsed.
fn resolve_config_includes(
    root_value: &mut toml::Value,
    config_dir: &Path,
    visited: &mut HashSet<PathBuf>,
    depth: u32,
) -> Result<(), String> {
    if depth > MAX_INCLUDE_DEPTH {
        return Err(format!(
            "Config include depth exceeded maximum of {MAX_INCLUDE_DEPTH}"
        ));
    }
    // Extract include list from the current value. Only a table with an
    // `include` array triggers any work; everything else is a no-op.
    let includes = match root_value {
        toml::Value::Table(tbl) => {
            if let Some(toml::Value::Array(arr)) = tbl.get("include") {
                // Non-string array entries are silently ignored.
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect::<Vec<_>>()
            } else {
                return Ok(());
            }
        }
        _ => return Ok(()),
    };
    if includes.is_empty() {
        return Ok(());
    }
    // Merge each include (earlier includes are overridden by later ones,
    // and the root config overrides everything).
    let mut merged_base = toml::Value::Table(toml::map::Map::new());
    for include_path_str in &includes {
        // SECURITY: reject absolute paths
        let include_path = Path::new(include_path_str);
        if include_path.is_absolute() {
            return Err(format!(
                "Config include rejects absolute path: {include_path_str}"
            ));
        }
        // SECURITY: reject `..` components
        for component in include_path.components() {
            if let std::path::Component::ParentDir = component {
                return Err(format!(
                    "Config include rejects path traversal: {include_path_str}"
                ));
            }
        }
        let resolved = config_dir.join(include_path);
        // SECURITY: verify resolved path stays within config dir.
        // Canonicalization also resolves symlinks, so a symlink pointing
        // outside the config dir is rejected by the starts_with check below.
        let canonical = std::fs::canonicalize(&resolved).map_err(|e| {
            format!(
                "Config include '{}' cannot be resolved: {e}",
                include_path_str
            )
        })?;
        let canonical_dir = std::fs::canonicalize(config_dir)
            .map_err(|e| format!("Config dir cannot be canonicalized: {e}"))?;
        if !canonical.starts_with(&canonical_dir) {
            return Err(format!(
                "Config include '{}' escapes config directory",
                include_path_str
            ));
        }
        // SECURITY: circular detection (insert returns false when the file
        // was already seen anywhere in the include tree — include-once).
        if !visited.insert(canonical.clone()) {
            return Err(format!(
                "Circular config include detected: {include_path_str}"
            ));
        }
        info!(include = %include_path_str, "Loading config include");
        let contents = std::fs::read_to_string(&canonical)
            .map_err(|e| format!("Failed to read config include '{}': {e}", include_path_str))?;
        let mut include_value: toml::Value = toml::from_str(&contents)
            .map_err(|e| format!("Failed to parse config include '{}': {e}", include_path_str))?;
        // Recursively resolve includes in the included file, relative to
        // that file's own directory.
        let include_dir = canonical.parent().unwrap_or(config_dir).to_path_buf();
        resolve_config_includes(&mut include_value, &include_dir, visited, depth + 1)?;
        // Remove include field from the included file
        if let toml::Value::Table(ref mut tbl) = include_value {
            tbl.remove("include");
        }
        // Deep merge: include overrides the base built so far
        deep_merge_toml(&mut merged_base, &include_value);
    }
    // Now deep merge: root overrides the merged includes
    // Save root's current values (minus include), then merge root on top
    let root_without_include = {
        let mut v = root_value.clone();
        if let toml::Value::Table(ref mut tbl) = v {
            tbl.remove("include");
        }
        v
    };
    deep_merge_toml(&mut merged_base, &root_without_include);
    *root_value = merged_base;
    Ok(())
}
/// Deep-merge two TOML values. `overlay` values override `base` values.
/// For tables, recursively merge. For everything else, overlay wins.
pub fn deep_merge_toml(base: &mut toml::Value, overlay: &toml::Value) {
    use toml::Value::Table;
    match (base, overlay) {
        (Table(dst), Table(src)) => {
            // Recurse into keys present on both sides; copy overlay-only keys.
            for (key, overlay_val) in src {
                match dst.get_mut(key) {
                    Some(existing) => deep_merge_toml(existing, overlay_val),
                    None => {
                        dst.insert(key.clone(), overlay_val.clone());
                    }
                }
            }
        }
        // Non-table on either side: the overlay replaces the base wholesale.
        (slot, replacement) => *slot = replacement.clone(),
    }
}
/// Get the default config file path.
pub fn default_config_path() -> PathBuf {
dirs::home_dir()
.unwrap_or_else(std::env::temp_dir)
.join(".openfang")
.join("config.toml")
}
/// Get the default OpenFang home directory (`~/.openfang`).
///
/// Falls back to the system temp directory when no home directory can be
/// determined.
pub fn openfang_home() -> PathBuf {
    let base = dirs::home_dir().unwrap_or_else(std::env::temp_dir);
    base.join(".openfang")
}
// Unit tests for config loading, deep-merge, and the include mechanism.
//
// NOTE(review): `test_load_config_defaults` calls `load_config(None)`, which
// reads the real `~/.openfang/config.toml` when present — on a developer
// machine with a customized config the `log_level == "info"` assertion could
// fail; confirm CI never has that file.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    #[test]
    fn test_load_config_defaults() {
        let config = load_config(None);
        assert_eq!(config.log_level, "info");
    }
    // A nonexistent path degrades to defaults without panicking.
    #[test]
    fn test_load_config_missing_file() {
        let config = load_config(Some(Path::new("/nonexistent/config.toml")));
        assert_eq!(config.log_level, "info");
    }
    // Overlay keys win; base-only keys survive.
    #[test]
    fn test_deep_merge_simple() {
        let mut base: toml::Value = toml::from_str(
            r#"
            log_level = "debug"
            api_listen = "0.0.0.0:4200"
            "#,
        )
        .unwrap();
        let overlay: toml::Value = toml::from_str(
            r#"
            log_level = "info"
            network_enabled = true
            "#,
        )
        .unwrap();
        deep_merge_toml(&mut base, &overlay);
        assert_eq!(base["log_level"].as_str(), Some("info"));
        assert_eq!(base["api_listen"].as_str(), Some("0.0.0.0:4200"));
        assert_eq!(base["network_enabled"].as_bool(), Some(true));
    }
    // Nested tables merge key-by-key instead of being replaced wholesale.
    #[test]
    fn test_deep_merge_nested_tables() {
        let mut base: toml::Value = toml::from_str(
            r#"
            [memory]
            decay_rate = 0.1
            consolidation_threshold = 10000
            "#,
        )
        .unwrap();
        let overlay: toml::Value = toml::from_str(
            r#"
            [memory]
            decay_rate = 0.5
            "#,
        )
        .unwrap();
        deep_merge_toml(&mut base, &overlay);
        let mem = base["memory"].as_table().unwrap();
        assert_eq!(mem["decay_rate"].as_float(), Some(0.5));
        assert_eq!(mem["consolidation_threshold"].as_integer(), Some(10000));
    }
    // Root config values override values from an included file.
    #[test]
    fn test_basic_include() {
        let dir = tempfile::tempdir().unwrap();
        let base_path = dir.path().join("base.toml");
        let root_path = dir.path().join("config.toml");
        // Base config
        let mut f = std::fs::File::create(&base_path).unwrap();
        writeln!(f, "log_level = \"debug\"").unwrap();
        writeln!(f, "api_listen = \"0.0.0.0:9999\"").unwrap();
        drop(f);
        // Root config (includes base, overrides log_level)
        let mut f = std::fs::File::create(&root_path).unwrap();
        writeln!(f, "include = [\"base.toml\"]").unwrap();
        writeln!(f, "log_level = \"warn\"").unwrap();
        drop(f);
        let config = load_config(Some(&root_path));
        assert_eq!(config.log_level, "warn"); // root overrides
        assert_eq!(config.api_listen, "0.0.0.0:9999"); // from base
    }
    // Includes resolve recursively; the outermost (root) value wins.
    #[test]
    fn test_nested_include() {
        let dir = tempfile::tempdir().unwrap();
        let grandchild = dir.path().join("grandchild.toml");
        let child = dir.path().join("child.toml");
        let root = dir.path().join("config.toml");
        let mut f = std::fs::File::create(&grandchild).unwrap();
        writeln!(f, "log_level = \"trace\"").unwrap();
        drop(f);
        let mut f = std::fs::File::create(&child).unwrap();
        writeln!(f, "include = [\"grandchild.toml\"]").unwrap();
        writeln!(f, "log_level = \"debug\"").unwrap();
        drop(f);
        let mut f = std::fs::File::create(&root).unwrap();
        writeln!(f, "include = [\"child.toml\"]").unwrap();
        writeln!(f, "log_level = \"info\"").unwrap();
        drop(f);
        let config = load_config(Some(&root));
        assert_eq!(config.log_level, "info"); // root wins
    }
    #[test]
    fn test_circular_include_detected() {
        let dir = tempfile::tempdir().unwrap();
        let a_path = dir.path().join("a.toml");
        let b_path = dir.path().join("b.toml");
        let mut f = std::fs::File::create(&a_path).unwrap();
        writeln!(f, "include = [\"b.toml\"]").unwrap();
        writeln!(f, "log_level = \"info\"").unwrap();
        drop(f);
        let mut f = std::fs::File::create(&b_path).unwrap();
        writeln!(f, "include = [\"a.toml\"]").unwrap();
        drop(f);
        // Should not panic — circular detection triggers, falls back gracefully
        let config = load_config(Some(&a_path));
        // Falls back to defaults due to the circular error
        assert!(!config.log_level.is_empty());
    }
    #[test]
    fn test_path_traversal_blocked() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().join("config.toml");
        let mut f = std::fs::File::create(&root).unwrap();
        writeln!(f, "include = [\"../etc/passwd\"]").unwrap();
        drop(f);
        // Should not panic — path traversal triggers error, falls back
        let config = load_config(Some(&root));
        assert_eq!(config.log_level, "info"); // defaults
    }
    #[test]
    fn test_max_depth_exceeded() {
        let dir = tempfile::tempdir().unwrap();
        // Create a chain of 12 files (exceeds MAX_INCLUDE_DEPTH=10)
        for i in (0..12).rev() {
            let name = format!("level{i}.toml");
            let path = dir.path().join(&name);
            let mut f = std::fs::File::create(&path).unwrap();
            if i < 11 {
                let next = format!("level{}.toml", i + 1);
                writeln!(f, "include = [\"{next}\"]").unwrap();
            }
            writeln!(f, "log_level = \"level{i}\"").unwrap();
            drop(f);
        }
        let root = dir.path().join("level0.toml");
        let config = load_config(Some(&root));
        // Falls back due to depth limit — but should not panic
        assert!(!config.log_level.is_empty());
    }
    #[test]
    fn test_absolute_path_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().join("config.toml");
        let mut f = std::fs::File::create(&root).unwrap();
        writeln!(f, "include = [\"/etc/shadow\"]").unwrap();
        drop(f);
        let config = load_config(Some(&root));
        assert_eq!(config.log_level, "info"); // defaults
    }
    // A config without an `include` key loads exactly as written.
    #[test]
    fn test_no_includes_works() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().join("config.toml");
        let mut f = std::fs::File::create(&root).unwrap();
        writeln!(f, "log_level = \"trace\"").unwrap();
        drop(f);
        let config = load_config(Some(&root));
        assert_eq!(config.log_level, "trace");
    }
}

View File

@@ -0,0 +1,674 @@
//! Config hot-reload — diffs two `KernelConfig` instances and produces a `ReloadPlan`.
//!
//! **Hot-reload safe**: channels, usage footer, web config, browser,
//! approval policy, cron settings, webhook triggers, extensions, MCP servers,
//! A2A, fallback providers, provider URL overrides.
//!
//! **No-op** (informational only): log_level, language, mode.
//!
//! **Restart required**: api_listen, api_key, network, memory, default_model,
//! home_dir, data_dir, vault.
use openfang_types::config::{KernelConfig, ReloadMode};
use tracing::{info, warn};
// ---------------------------------------------------------------------------
// HotAction — what can be changed at runtime without restart
// ---------------------------------------------------------------------------
/// An individual action that can be applied at runtime (hot-reload).
///
/// Produced by [`build_reload_plan`]; each variant names the subsystem the
/// caller must refresh in place.
///
/// NOTE(review): `ReloadSkills` is declared here but `build_reload_plan` in
/// this module never emits it — confirm whether skill-config diffing is
/// handled elsewhere or still needs wiring up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HotAction {
    /// Channel configuration changed — reload channel bridges.
    ReloadChannels,
    /// Skill configuration changed — reload skill registry.
    ReloadSkills,
    /// Usage footer mode changed.
    UpdateUsageFooter,
    /// Web config changed — rebuild web tools context.
    ReloadWebConfig,
    /// Browser config changed.
    ReloadBrowserConfig,
    /// Approval policy changed.
    UpdateApprovalPolicy,
    /// Cron max jobs changed.
    UpdateCronConfig,
    /// Webhook trigger config changed.
    UpdateWebhookConfig,
    /// Extension config changed.
    ReloadExtensions,
    /// MCP server list changed — reconnect MCP clients.
    ReloadMcpServers,
    /// A2A config changed.
    ReloadA2aConfig,
    /// Fallback provider chain changed.
    ReloadFallbackProviders,
    /// Provider base URL overrides changed.
    ReloadProviderUrls,
}
// ---------------------------------------------------------------------------
// ReloadPlan — the output of diffing two configs
// ---------------------------------------------------------------------------
/// A categorized plan for applying config changes.
///
/// After building a plan via [`build_reload_plan`], callers inspect
/// `restart_required` to decide whether a full restart is needed or
/// the `hot_actions` can be applied in-place.
///
/// A plan with all fields empty/false means the two configs were identical
/// (see [`ReloadPlan::has_changes`]).
#[derive(Debug, Clone)]
pub struct ReloadPlan {
    /// Whether a full restart is needed.
    pub restart_required: bool,
    /// Human-readable reasons why restart is required.
    pub restart_reasons: Vec<String>,
    /// Actions that can be hot-reloaded without restart.
    pub hot_actions: Vec<HotAction>,
    /// Fields that changed but are no-ops (informational only).
    pub noop_changes: Vec<String>,
}
impl ReloadPlan {
    /// Whether any changes were detected at all.
    pub fn has_changes(&self) -> bool {
        if self.restart_required {
            return true;
        }
        !(self.hot_actions.is_empty() && self.noop_changes.is_empty())
    }
    /// Whether the plan can be applied without restart.
    pub fn is_hot_reloadable(&self) -> bool {
        !self.restart_required
    }
    /// Log a human-readable summary of the plan: one line when nothing
    /// changed, otherwise a warning for restart reasons plus one info line
    /// per hot action and per no-op change.
    pub fn log_summary(&self) {
        if !self.has_changes() {
            info!("config reload: no changes detected");
            return;
        }
        if self.restart_required {
            let reasons = self.restart_reasons.join("; ");
            warn!("config reload: restart required — {}", reasons);
        }
        self.hot_actions
            .iter()
            .for_each(|action| info!("config reload: hot-reload action queued — {action:?}"));
        self.noop_changes
            .iter()
            .for_each(|noop| info!("config reload: no-op change — {noop}"));
    }
}
// ---------------------------------------------------------------------------
// build_reload_plan
// ---------------------------------------------------------------------------
/// Compare JSON-serialized forms of a field. Returns `true` when the
/// serialized representations differ, and conservatively reports a change
/// whenever either side fails to serialize.
///
/// The previous implementation compared `Option<String>`s, so a field that
/// failed to serialize on *both* sides compared `None == None` and was
/// reported as unchanged — silently masking a possible real change. Treating
/// any serialization failure as "changed" fails safe for reload planning
/// (worst case: an unnecessary reload/restart, never a missed one).
fn field_changed<T: serde::Serialize>(old: &T, new: &T) -> bool {
    match (serde_json::to_string(old), serde_json::to_string(new)) {
        (Ok(old_json), Ok(new_json)) => old_json != new_json,
        // Serialization failure on either side: assume changed.
        _ => true,
    }
}
/// Diff two configurations and produce a reload plan.
///
/// The plan categorizes every detected change into one of three buckets:
///
/// 1. **restart_required** — the change touches something that cannot be
///    patched at runtime (e.g. the listen address or database path).
/// 2. **hot_actions** — the change can be applied without restarting.
/// 3. **noop_changes** — the change is informational; no action needed.
///
/// Fields not listed below are not diffed at all. NOTE(review):
/// `HotAction::ReloadSkills` is never produced here — confirm whether skill
/// config diffing is intentionally handled elsewhere.
pub fn build_reload_plan(old: &KernelConfig, new: &KernelConfig) -> ReloadPlan {
    let mut plan = ReloadPlan {
        restart_required: false,
        restart_reasons: Vec::new(),
        hot_actions: Vec::new(),
        noop_changes: Vec::new(),
    };
    // ----- Restart-required fields -----
    // Listen address and API key require rebinding/re-authing the server.
    if old.api_listen != new.api_listen {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "api_listen changed: {} -> {}",
            old.api_listen, new.api_listen
        ));
    }
    if old.api_key != new.api_key {
        plan.restart_required = true;
        plan.restart_reasons.push("api_key changed".to_string());
    }
    if old.network_enabled != new.network_enabled {
        plan.restart_required = true;
        plan.restart_reasons
            .push("network_enabled changed".to_string());
    }
    // Network config (shared_secret, listen_addresses, etc.)
    if field_changed(&old.network, &new.network) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("network config changed".to_string());
    }
    // Memory config (requires restarting SQLite connections)
    if field_changed(&old.memory, &new.memory) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("memory config changed".to_string());
    }
    // Default model (driver needs recreation)
    if field_changed(&old.default_model, &new.default_model) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("default_model changed".to_string());
    }
    // Home/data directory changes
    if old.home_dir != new.home_dir {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "home_dir changed: {:?} -> {:?}",
            old.home_dir, new.home_dir
        ));
    }
    if old.data_dir != new.data_dir {
        plan.restart_required = true;
        plan.restart_reasons.push(format!(
            "data_dir changed: {:?} -> {:?}",
            old.data_dir, new.data_dir
        ));
    }
    // Vault config (encryption key derivation)
    if field_changed(&old.vault, &new.vault) {
        plan.restart_required = true;
        plan.restart_reasons
            .push("vault config changed".to_string());
    }
    // ----- Hot-reloadable fields -----
    // Each detected diff queues exactly one corresponding HotAction.
    if field_changed(&old.channels, &new.channels) {
        plan.hot_actions.push(HotAction::ReloadChannels);
    }
    if old.usage_footer != new.usage_footer {
        plan.hot_actions.push(HotAction::UpdateUsageFooter);
    }
    if field_changed(&old.web, &new.web) {
        plan.hot_actions.push(HotAction::ReloadWebConfig);
    }
    if field_changed(&old.browser, &new.browser) {
        plan.hot_actions.push(HotAction::ReloadBrowserConfig);
    }
    if field_changed(&old.approval, &new.approval) {
        plan.hot_actions.push(HotAction::UpdateApprovalPolicy);
    }
    if old.max_cron_jobs != new.max_cron_jobs {
        plan.hot_actions.push(HotAction::UpdateCronConfig);
    }
    if field_changed(&old.webhook_triggers, &new.webhook_triggers) {
        plan.hot_actions.push(HotAction::UpdateWebhookConfig);
    }
    if field_changed(&old.extensions, &new.extensions) {
        plan.hot_actions.push(HotAction::ReloadExtensions);
    }
    if field_changed(&old.mcp_servers, &new.mcp_servers) {
        plan.hot_actions.push(HotAction::ReloadMcpServers);
    }
    if field_changed(&old.a2a, &new.a2a) {
        plan.hot_actions.push(HotAction::ReloadA2aConfig);
    }
    if field_changed(&old.fallback_providers, &new.fallback_providers) {
        plan.hot_actions.push(HotAction::ReloadFallbackProviders);
    }
    if field_changed(&old.provider_urls, &new.provider_urls) {
        plan.hot_actions.push(HotAction::ReloadProviderUrls);
    }
    // ----- No-op fields -----
    // Recorded for the operator's benefit; no runtime action taken.
    if old.log_level != new.log_level {
        plan.noop_changes
            .push(format!("log_level: {} -> {}", old.log_level, new.log_level));
    }
    if old.language != new.language {
        plan.noop_changes
            .push(format!("language: {} -> {}", old.language, new.language));
    }
    if old.mode != new.mode {
        plan.noop_changes
            .push(format!("mode: {:?} -> {:?}", old.mode, new.mode));
    }
    plan
}
// ---------------------------------------------------------------------------
// validate_config_for_reload
// ---------------------------------------------------------------------------
/// Validate a new config before applying it.
///
/// Returns `Ok(())` if the config passes basic sanity checks, or `Err` with
/// a list of human-readable error messages (all violations are collected,
/// not just the first).
pub fn validate_config_for_reload(config: &KernelConfig) -> Result<(), Vec<String>> {
    let mut errors: Vec<String> = Vec::new();
    if config.api_listen.is_empty() {
        errors.push("api_listen cannot be empty".to_string());
    }
    if config.max_cron_jobs > 10_000 {
        errors.push("max_cron_jobs exceeds reasonable limit (10000)".to_string());
    }
    // The approval policy validates itself; surface its error text verbatim.
    if let Err(e) = config.approval.validate() {
        errors.push(format!("approval policy: {e}"));
    }
    // A shared secret is mandatory once networking is switched on.
    let secret_missing = config.network_enabled && config.network.shared_secret.is_empty();
    if secret_missing {
        errors.push("network_enabled is true but network.shared_secret is empty".to_string());
    }
    match errors.is_empty() {
        true => Ok(()),
        false => Err(errors),
    }
}
// ---------------------------------------------------------------------------
// should_reload — convenience helper for the reload mode
// ---------------------------------------------------------------------------
/// Given the configured [`ReloadMode`] and a [`ReloadPlan`], decide whether
/// the caller should apply hot actions.
///
/// Returns `true` if hot-reload actions should be applied.
pub fn should_apply_hot(mode: ReloadMode, plan: &ReloadPlan) -> bool {
    match mode {
        // `Off` disables reload handling entirely; `Restart` means the
        // caller must perform a full restart instead of patching in place.
        ReloadMode::Off | ReloadMode::Restart => false,
        // `Hot` and `Hybrid` both apply in-place actions when any exist.
        ReloadMode::Hot | ReloadMode::Hybrid => !plan.hot_actions.is_empty(),
    }
}
// ===========================================================================
// Tests
// ===========================================================================
#[cfg(test)]
mod tests {
use super::*;
use openfang_types::config::KernelConfig;
/// Helper: create a default config for diffing.
fn default_cfg() -> KernelConfig {
KernelConfig::default()
}
// -----------------------------------------------------------------------
// Plan detection tests
// -----------------------------------------------------------------------
#[test]
fn test_no_changes_detected() {
let a = default_cfg();
let b = default_cfg();
let plan = build_reload_plan(&a, &b);
assert!(!plan.has_changes());
assert!(!plan.restart_required);
assert!(plan.hot_actions.is_empty());
assert!(plan.noop_changes.is_empty());
}
#[test]
fn test_api_listen_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.api_listen = "0.0.0.0:8080".to_string();
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan
.restart_reasons
.iter()
.any(|r| r.contains("api_listen")));
}
#[test]
fn test_api_key_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.api_key = "super-secret-key".to_string();
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan.restart_reasons.iter().any(|r| r.contains("api_key")));
}
#[test]
fn test_network_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.network_enabled = true;
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan
.restart_reasons
.iter()
.any(|r| r.contains("network_enabled")));
}
#[test]
fn test_network_config_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.network.shared_secret = "new-secret".to_string();
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan
.restart_reasons
.iter()
.any(|r| r.contains("network config")));
}
#[test]
fn test_memory_config_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.memory.consolidation_threshold = 99_999;
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan
.restart_reasons
.iter()
.any(|r| r.contains("memory config")));
}
#[test]
fn test_default_model_requires_restart() {
let a = default_cfg();
let mut b = default_cfg();
b.default_model.model = "gpt-4".to_string();
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan
.restart_reasons
.iter()
.any(|r| r.contains("default_model")));
}
// -----------------------------------------------------------------------
// Hot-reload tests
// -----------------------------------------------------------------------
#[test]
fn test_channels_hot_reload() {
let a = default_cfg();
let mut b = default_cfg();
// Change the channels config by adding a Telegram config
b.channels.telegram = Some(openfang_types::config::TelegramConfig {
bot_token_env: "TG_TOKEN".to_string(),
..Default::default()
});
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.contains(&HotAction::ReloadChannels));
}
#[test]
fn test_usage_footer_hot_reload() {
use openfang_types::config::UsageFooterMode;
let a = default_cfg();
let mut b = default_cfg();
b.usage_footer = UsageFooterMode::Off;
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.contains(&HotAction::UpdateUsageFooter));
}
#[test]
fn test_max_cron_jobs_hot_reload() {
let a = default_cfg();
let mut b = default_cfg();
b.max_cron_jobs = 1000;
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.contains(&HotAction::UpdateCronConfig));
}
#[test]
fn test_extensions_hot_reload() {
let a = default_cfg();
let mut b = default_cfg();
b.extensions.reconnect_max_attempts = 20;
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.contains(&HotAction::ReloadExtensions));
}
#[test]
fn test_provider_urls_hot_reload() {
let a = default_cfg();
let mut b = default_cfg();
b.provider_urls
.insert("ollama".to_string(), "http://10.0.0.5:11434/v1".to_string());
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.contains(&HotAction::ReloadProviderUrls));
}
// -----------------------------------------------------------------------
// Mixed changes
// -----------------------------------------------------------------------
#[test]
fn test_mixed_changes() {
use openfang_types::config::UsageFooterMode;
let a = default_cfg();
let mut b = default_cfg();
// Restart-required
b.api_listen = "0.0.0.0:9999".to_string();
// Hot-reloadable
b.usage_footer = UsageFooterMode::Tokens;
b.max_cron_jobs = 100;
// No-op
b.log_level = "debug".to_string();
let plan = build_reload_plan(&a, &b);
assert!(plan.restart_required);
assert!(plan.has_changes());
// Hot actions are still collected even if restart is required,
// so the caller knows what will need re-initialization after restart.
assert!(plan.hot_actions.contains(&HotAction::UpdateUsageFooter));
assert!(plan.hot_actions.contains(&HotAction::UpdateCronConfig));
assert!(plan.noop_changes.iter().any(|c| c.contains("log_level")));
}
// -----------------------------------------------------------------------
// No-op changes
// -----------------------------------------------------------------------
#[test]
fn test_noop_changes() {
use openfang_types::config::KernelMode;
let a = default_cfg();
let mut b = default_cfg();
b.log_level = "debug".to_string();
b.language = "de".to_string();
b.mode = KernelMode::Dev;
let plan = build_reload_plan(&a, &b);
assert!(!plan.restart_required);
assert!(plan.hot_actions.is_empty());
assert_eq!(plan.noop_changes.len(), 3);
assert!(plan.noop_changes.iter().any(|c| c.contains("log_level")));
assert!(plan.noop_changes.iter().any(|c| c.contains("language")));
assert!(plan.noop_changes.iter().any(|c| c.contains("mode")));
}
// -----------------------------------------------------------------------
// has_changes / is_hot_reloadable helpers
// -----------------------------------------------------------------------
#[test]
fn test_has_changes() {
// No changes
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![],
noop_changes: vec![],
};
assert!(!plan.has_changes());
// Only noop
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![],
noop_changes: vec!["log_level: info -> debug".to_string()],
};
assert!(plan.has_changes());
// Only hot
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![HotAction::UpdateCronConfig],
noop_changes: vec![],
};
assert!(plan.has_changes());
// Only restart
let plan = ReloadPlan {
restart_required: true,
restart_reasons: vec!["api_listen changed".to_string()],
hot_actions: vec![],
noop_changes: vec![],
};
assert!(plan.has_changes());
}
#[test]
fn test_is_hot_reloadable() {
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![HotAction::ReloadChannels],
noop_changes: vec![],
};
assert!(plan.is_hot_reloadable());
let plan = ReloadPlan {
restart_required: true,
restart_reasons: vec!["api_listen changed".to_string()],
hot_actions: vec![HotAction::ReloadChannels],
noop_changes: vec![],
};
assert!(!plan.is_hot_reloadable());
}
// -----------------------------------------------------------------------
// Validation tests
// -----------------------------------------------------------------------
#[test]
fn test_validate_config_for_reload_valid() {
let config = default_cfg();
assert!(validate_config_for_reload(&config).is_ok());
}
#[test]
fn test_validate_config_for_reload_invalid() {
// Empty api_listen
let mut config = default_cfg();
config.api_listen = String::new();
let err = validate_config_for_reload(&config).unwrap_err();
assert!(err.iter().any(|e| e.contains("api_listen")));
// Excessive max_cron_jobs
let mut config = default_cfg();
config.max_cron_jobs = 100_000;
let err = validate_config_for_reload(&config).unwrap_err();
assert!(err.iter().any(|e| e.contains("max_cron_jobs")));
}
#[test]
fn test_validate_network_enabled_no_secret() {
let mut config = default_cfg();
config.network_enabled = true;
config.network.shared_secret = String::new();
let err = validate_config_for_reload(&config).unwrap_err();
assert!(err.iter().any(|e| e.contains("shared_secret")));
}
// -----------------------------------------------------------------------
// should_apply_hot
// -----------------------------------------------------------------------
#[test]
fn test_should_apply_hot_off() {
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![HotAction::ReloadChannels],
noop_changes: vec![],
};
assert!(!should_apply_hot(ReloadMode::Off, &plan));
}
#[test]
fn test_should_apply_hot_restart_mode() {
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![HotAction::ReloadChannels],
noop_changes: vec![],
};
assert!(!should_apply_hot(ReloadMode::Restart, &plan));
}
#[test]
fn test_should_apply_hot_hybrid() {
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![HotAction::ReloadChannels],
noop_changes: vec![],
};
assert!(should_apply_hot(ReloadMode::Hybrid, &plan));
assert!(should_apply_hot(ReloadMode::Hot, &plan));
}
#[test]
fn test_should_apply_hot_empty() {
let plan = ReloadPlan {
restart_required: false,
restart_reasons: vec![],
hot_actions: vec![],
noop_changes: vec![],
};
assert!(!should_apply_hot(ReloadMode::Hybrid, &plan));
}
}

View File

@@ -0,0 +1,790 @@
//! Cron job scheduler engine for the OpenFang kernel.
//!
//! Manages scheduled jobs (recurring and one-shot) across all agents.
//! This is separate from `scheduler.rs` which handles agent resource tracking.
//!
//! The scheduler stores jobs in a `DashMap` for concurrent access, persists
//! them to a JSON file on disk, and exposes methods for the kernel tick loop
//! to query due jobs and record outcomes.
use chrono::{Duration, Utc};
use dashmap::DashMap;
use openfang_types::agent::AgentId;
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::scheduler::{CronJob, CronJobId, CronSchedule};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use tracing::{debug, info, warn};
/// Maximum consecutive errors before a job is auto-disabled.
const MAX_CONSECUTIVE_ERRORS: u32 = 5;
// ---------------------------------------------------------------------------
// JobMeta — extra runtime state not stored in CronJob itself
// ---------------------------------------------------------------------------
/// Runtime metadata for a cron job that extends the base `CronJob` type.
///
/// The `CronJob` struct in `openfang-types` is intentionally lean (no
/// `one_shot`, `last_status`, or error tracking). The scheduler tracks
/// these operational details separately.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobMeta {
    /// The underlying job definition.
    pub job: CronJob,
    /// Whether this job should be removed after a single successful execution.
    pub one_shot: bool,
    /// Human-readable status of the last execution (e.g. `"ok"` or `"error: ..."`).
    ///
    /// `None` until the job has executed at least once.
    pub last_status: Option<String>,
    /// Number of consecutive failed executions.
    ///
    /// Reset to 0 on success and on re-enable; when it reaches
    /// `MAX_CONSECUTIVE_ERRORS` the scheduler auto-disables the job.
    pub consecutive_errors: u32,
}
impl JobMeta {
/// Wrap a `CronJob` with default metadata.
pub fn new(job: CronJob, one_shot: bool) -> Self {
Self {
job,
one_shot,
last_status: None,
consecutive_errors: 0,
}
}
}
// ---------------------------------------------------------------------------
// CronScheduler
// ---------------------------------------------------------------------------
/// Cron job scheduler — manages scheduled jobs for all agents.
///
/// Thread-safe via `DashMap`. The kernel should call [`due_jobs`] on a
/// regular interval (e.g. every 10-30 seconds) to discover jobs that need
/// to fire, then call [`record_success`] or [`record_failure`] after
/// execution completes.
///
/// Persistence is explicit: `add_job`/`remove_job` do not write to disk
/// themselves — callers must invoke `persist` after mutating operations.
pub struct CronScheduler {
    /// All tracked jobs, keyed by their unique ID.
    jobs: DashMap<CronJobId, JobMeta>,
    /// Path to the persistence file (`<home>/cron_jobs.json`).
    persist_path: PathBuf,
    /// Global cap on total jobs across all agents (atomic for hot-reload).
    ///
    /// Read with `Ordering::Relaxed` in `add_job`; updated via
    /// `set_max_total_jobs`.
    max_total_jobs: AtomicUsize,
}
impl CronScheduler {
    /// Create a new scheduler.
    ///
    /// `home_dir` is the OpenFang data directory; jobs are persisted to
    /// `<home_dir>/cron_jobs.json`. `max_total_jobs` caps the total number
    /// of jobs across all agents.
    ///
    /// No jobs are read from disk here — call `load` explicitly.
    pub fn new(home_dir: &Path, max_total_jobs: usize) -> Self {
        Self {
            jobs: DashMap::new(),
            persist_path: home_dir.join("cron_jobs.json"),
            max_total_jobs: AtomicUsize::new(max_total_jobs),
        }
    }
    /// Update the max total jobs limit (for hot-reload).
    ///
    /// `Relaxed` ordering is sufficient: the cap is an advisory value read
    /// once per `add_job` call, not a synchronization point.
    pub fn set_max_total_jobs(&self, new_max: usize) {
        self.max_total_jobs.store(new_max, Ordering::Relaxed);
    }
    // -- Persistence --------------------------------------------------------
    /// Load persisted jobs from disk.
    ///
    /// Returns the number of jobs loaded. If the persistence file does not
    /// exist, returns `Ok(0)` without error.
    ///
    /// Entries are merged into the in-memory map by job ID: existing entries
    /// with the same ID are overwritten, others are left untouched.
    pub fn load(&self) -> OpenFangResult<usize> {
        if !self.persist_path.exists() {
            return Ok(0);
        }
        let data = std::fs::read_to_string(&self.persist_path)
            .map_err(|e| OpenFangError::Internal(format!("Failed to read cron jobs: {e}")))?;
        let metas: Vec<JobMeta> = serde_json::from_str(&data)
            .map_err(|e| OpenFangError::Internal(format!("Failed to parse cron jobs: {e}")))?;
        let count = metas.len();
        for meta in metas {
            self.jobs.insert(meta.job.id, meta);
        }
        info!(count, "Loaded cron jobs from disk");
        Ok(count)
    }
    /// Persist all jobs to disk via atomic write (write to `.tmp`, then rename).
    ///
    /// NOTE(review): the temp file is not fsync'd before the rename, so a
    /// crash at the wrong moment could lose the newest snapshot (the rename
    /// itself keeps the previous file intact) — confirm this durability
    /// level is acceptable.
    pub fn persist(&self) -> OpenFangResult<()> {
        let metas: Vec<JobMeta> = self.jobs.iter().map(|r| r.value().clone()).collect();
        let data = serde_json::to_string_pretty(&metas)
            .map_err(|e| OpenFangError::Internal(format!("Failed to serialize cron jobs: {e}")))?;
        // "cron_jobs.json" -> "cron_jobs.json.tmp"
        let tmp_path = self.persist_path.with_extension("json.tmp");
        std::fs::write(&tmp_path, data.as_bytes()).map_err(|e| {
            OpenFangError::Internal(format!("Failed to write cron jobs temp file: {e}"))
        })?;
        std::fs::rename(&tmp_path, &self.persist_path).map_err(|e| {
            OpenFangError::Internal(format!("Failed to rename cron jobs file: {e}"))
        })?;
        debug!(count = metas.len(), "Persisted cron jobs");
        Ok(())
    }
    // -- CRUD ---------------------------------------------------------------
    /// Add a new job. Validates fields, computes the initial `next_run`,
    /// and inserts it into the scheduler.
    ///
    /// `one_shot` controls whether the job is removed after a single
    /// successful execution.
    ///
    /// Errors if the global cap is reached, or if `CronJob::validate`
    /// rejects the job (which enforces the per-agent limit, among others).
    ///
    /// NOTE(review): the cap checks and the insert are not one atomic step,
    /// so two racing `add_job` calls could briefly exceed a cap by one —
    /// confirm callers serialize job creation if that matters.
    pub fn add_job(&self, mut job: CronJob, one_shot: bool) -> OpenFangResult<CronJobId> {
        // Global limit
        let max_jobs = self.max_total_jobs.load(Ordering::Relaxed);
        if self.jobs.len() >= max_jobs {
            return Err(OpenFangError::Internal(format!(
                "Global cron job limit reached ({})",
                max_jobs
            )));
        }
        // Per-agent count
        let agent_count = self
            .jobs
            .iter()
            .filter(|r| r.value().job.agent_id == job.agent_id)
            .count();
        // CronJob.validate returns Result<(), String>
        job.validate(agent_count)
            .map_err(OpenFangError::InvalidInput)?;
        // Compute initial next_run
        job.next_run = Some(compute_next_run(&job.schedule));
        let id = job.id;
        self.jobs.insert(id, JobMeta::new(job, one_shot));
        Ok(id)
    }
    /// Remove a job by ID. Returns the removed `CronJob`.
    ///
    /// Errors if no job with `id` exists.
    pub fn remove_job(&self, id: CronJobId) -> OpenFangResult<CronJob> {
        self.jobs
            .remove(&id)
            .map(|(_, meta)| meta.job)
            .ok_or_else(|| OpenFangError::Internal(format!("Cron job {id} not found")))
    }
    /// Enable or disable a job. Re-enabling resets errors and recomputes
    /// `next_run`.
    ///
    /// Errors if no job with `id` exists.
    pub fn set_enabled(&self, id: CronJobId, enabled: bool) -> OpenFangResult<()> {
        match self.jobs.get_mut(&id) {
            Some(mut meta) => {
                meta.job.enabled = enabled;
                if enabled {
                    // Fresh start: clear the failure streak and reschedule
                    // from "now" rather than the stale next_run.
                    meta.consecutive_errors = 0;
                    meta.job.next_run = Some(compute_next_run(&meta.job.schedule));
                }
                Ok(())
            }
            None => Err(OpenFangError::Internal(format!("Cron job {id} not found"))),
        }
    }
    // -- Queries ------------------------------------------------------------
    /// Get a single job by ID.
    pub fn get_job(&self, id: CronJobId) -> Option<CronJob> {
        self.jobs.get(&id).map(|r| r.value().job.clone())
    }
    /// Get the full metadata for a job (includes `one_shot`, `last_status`,
    /// `consecutive_errors`).
    pub fn get_meta(&self, id: CronJobId) -> Option<JobMeta> {
        self.jobs.get(&id).map(|r| r.value().clone())
    }
    /// List all jobs for a specific agent.
    pub fn list_jobs(&self, agent_id: AgentId) -> Vec<CronJob> {
        self.jobs
            .iter()
            .filter(|r| r.value().job.agent_id == agent_id)
            .map(|r| r.value().job.clone())
            .collect()
    }
    /// List all jobs across all agents.
    pub fn list_all_jobs(&self) -> Vec<CronJob> {
        self.jobs.iter().map(|r| r.value().job.clone()).collect()
    }
    /// Total number of tracked jobs.
    pub fn total_jobs(&self) -> usize {
        self.jobs.len()
    }
    /// Return jobs whose `next_run` is at or before `now` and are enabled.
    ///
    /// **Important**: This also pre-advances each due job's `next_run` to the
    /// next scheduled time. This prevents the same job from being returned as
    /// "due" on subsequent tick iterations while it's still executing.
    ///
    /// NOTE(review): for `At` schedules `compute_next_run_after` returns the
    /// same fixed timestamp, so a past-due `At` job that is *not* one-shot
    /// would be reported as due on every tick — confirm `At` jobs are always
    /// registered with `one_shot = true`.
    pub fn due_jobs(&self) -> Vec<CronJob> {
        let now = Utc::now();
        let mut due = Vec::new();
        for mut entry in self.jobs.iter_mut() {
            let meta = entry.value_mut();
            if meta.job.enabled && meta.job.next_run.map(|t| t <= now).unwrap_or(false) {
                due.push(meta.job.clone());
                // Pre-advance next_run so the job won't fire again on the next
                // tick while it's still executing. Use `now` as the base so the
                // next fire time is computed strictly after the current moment.
                meta.job.next_run = Some(compute_next_run_after(&meta.job.schedule, now));
            }
        }
        due
    }
    // -- Outcome recording --------------------------------------------------
    /// Record a successful execution for a job.
    ///
    /// Updates `last_run`, resets errors, and either removes the job (if
    /// one-shot) or advances `next_run`.
    ///
    /// Unknown IDs are silently ignored (the job may have been removed
    /// concurrently).
    pub fn record_success(&self, id: CronJobId) {
        // We need to check one_shot first, then potentially remove.
        // The inner block also guarantees the DashMap guard from `get_mut`
        // is dropped before `remove` touches the same key — removing while
        // the guard is alive would deadlock on the shard lock.
        let should_remove = {
            if let Some(mut meta) = self.jobs.get_mut(&id) {
                meta.job.last_run = Some(Utc::now());
                meta.last_status = Some("ok".to_string());
                meta.consecutive_errors = 0;
                // one_shot jobs get removed; recurring jobs keep the next_run
                // already pre-advanced by due_jobs() — no recompute needed.
                meta.one_shot
            } else {
                return;
            }
        };
        if should_remove {
            self.jobs.remove(&id);
        }
    }
    /// Record a failed execution for a job.
    ///
    /// Increments the consecutive error counter. If it reaches
    /// [`MAX_CONSECUTIVE_ERRORS`], the job is automatically disabled;
    /// otherwise `next_run` is advanced so the job retries on schedule.
    /// The stored `last_status` message is truncated to 256 characters.
    pub fn record_failure(&self, id: CronJobId, error_msg: &str) {
        if let Some(mut meta) = self.jobs.get_mut(&id) {
            meta.job.last_run = Some(Utc::now());
            meta.last_status = Some(format!(
                "error: {}",
                openfang_types::truncate_str(error_msg, 256)
            ));
            meta.consecutive_errors += 1;
            if meta.consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
                warn!(
                    job_id = %id,
                    errors = meta.consecutive_errors,
                    "Auto-disabling cron job after repeated failures"
                );
                meta.job.enabled = false;
            } else {
                meta.job.next_run =
                    Some(compute_next_run_after(&meta.job.schedule, Utc::now()));
            }
        }
    }
}
// ---------------------------------------------------------------------------
// compute_next_run
// ---------------------------------------------------------------------------
/// Compute the next fire time for a schedule, based on the current time.
///
/// Thin wrapper over [`compute_next_run_after`] anchored at `Utc::now()`:
/// - `At { at }` — returns `at` directly.
/// - `Every { every_secs }` — returns `now + every_secs`.
/// - `Cron { expr, tz }` — parses the cron expression and computes the next
///   matching time (5/6-field expressions are normalized to the 7-field
///   format required by the `cron` crate).
pub fn compute_next_run(schedule: &CronSchedule) -> chrono::DateTime<Utc> {
    let now = Utc::now();
    compute_next_run_after(schedule, now)
}
/// Compute the next fire time for a schedule, strictly after `after`.
///
/// For cron expressions, `after + 1 second` is used as the base time so the
/// `cron` crate's inclusive `.after()` always yields a strictly future time.
/// Without that offset, recomputing right after a job fires could return the
/// same minute (or even the same second) and re-fire the job immediately.
pub fn compute_next_run_after(
    schedule: &CronSchedule,
    after: chrono::DateTime<Utc>,
) -> chrono::DateTime<Utc> {
    // Fixed-time and interval schedules need no parsing.
    let expr = match schedule {
        CronSchedule::At { at } => return *at,
        CronSchedule::Every { every_secs } => {
            return after + Duration::seconds(*every_secs as i64);
        }
        CronSchedule::Cron { expr, tz: _ } => expr,
    };
    // Normalize standard 5-field (min hour dom month dow) and 6-field
    // (sec min hour dom month dow) expressions to the 7-field form
    // (sec min hour dom month dow year) the `cron` crate expects.
    let trimmed = expr.trim();
    let seven_field = match trimmed.split_whitespace().count() {
        5 => format!("0 {trimmed} *"),
        6 => format!("{trimmed} *"),
        _ => expr.clone(),
    };
    // `.after()` is inclusive — skip ahead one second so we never return
    // the moment that just fired.
    let base = after + Duration::seconds(1);
    match seven_field.parse::<cron::Schedule>() {
        Ok(sched) => sched
            .after(&base)
            .next()
            // No future occurrence (e.g. a year-bounded expression):
            // fall back to one hour out.
            .unwrap_or_else(|| after + Duration::hours(1)),
        Err(e) => {
            warn!("Failed to parse cron expression '{}': {}", expr, e);
            after + Duration::hours(1)
        }
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use openfang_types::scheduler::{CronAction, CronDelivery};
    /// Build a minimal valid `CronJob` with an `Every` schedule.
    fn make_job(agent_id: AgentId) -> CronJob {
        CronJob {
            id: CronJobId::new(),
            agent_id,
            name: "test-job".into(),
            enabled: true,
            schedule: CronSchedule::Every { every_secs: 3600 },
            action: CronAction::SystemEvent {
                text: "ping".into(),
            },
            delivery: CronDelivery::None,
            created_at: Utc::now(),
            last_run: None,
            next_run: None,
        }
    }
    /// Create a scheduler backed by a temp directory.
    ///
    /// The `TempDir` guard is returned so the directory outlives the
    /// scheduler for the duration of the test.
    fn make_scheduler(max_total: usize) -> (CronScheduler, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let sched = CronScheduler::new(tmp.path(), max_total);
        (sched, tmp)
    }
    // -- test_add_job_and_list ----------------------------------------------
    #[test]
    fn test_add_job_and_list() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();
        // Should appear in agent list
        let jobs = sched.list_jobs(agent);
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].id, id);
        assert_eq!(jobs[0].name, "test-job");
        // Should appear in global list
        let all = sched.list_all_jobs();
        assert_eq!(all.len(), 1);
        // get_job should return it
        let fetched = sched.get_job(id).unwrap();
        assert_eq!(fetched.agent_id, agent);
        // next_run should have been computed
        assert!(fetched.next_run.is_some());
        assert_eq!(sched.total_jobs(), 1);
    }
    // -- test_remove_job ----------------------------------------------------
    #[test]
    fn test_remove_job() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();
        let removed = sched.remove_job(id).unwrap();
        assert_eq!(removed.name, "test-job");
        assert_eq!(sched.total_jobs(), 0);
        // Removing again should fail
        assert!(sched.remove_job(id).is_err());
    }
    // -- test_add_job_global_limit ------------------------------------------
    #[test]
    fn test_add_job_global_limit() {
        let (sched, _tmp) = make_scheduler(2);
        let agent = AgentId::new();
        let j1 = make_job(agent);
        let j2 = make_job(agent);
        let j3 = make_job(agent);
        sched.add_job(j1, false).unwrap();
        sched.add_job(j2, false).unwrap();
        // Third should hit global limit
        let err = sched.add_job(j3, false).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("limit"),
            "Expected global limit error, got: {msg}"
        );
    }
    // -- test_add_job_per_agent_limit ---------------------------------------
    #[test]
    fn test_add_job_per_agent_limit() {
        // MAX_JOBS_PER_AGENT = 50 in openfang-types
        let (sched, _tmp) = make_scheduler(1000);
        let agent = AgentId::new();
        for i in 0..50 {
            let mut job = make_job(agent);
            job.name = format!("job-{i}");
            sched.add_job(job, false).unwrap();
        }
        // 51st should be rejected by validate()
        let mut overflow = make_job(agent);
        overflow.name = "overflow".into();
        let err = sched.add_job(overflow, false).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("50"),
            "Expected per-agent limit error, got: {msg}"
        );
    }
    // -- test_record_success_removes_one_shot --------------------------------
    #[test]
    fn test_record_success_removes_one_shot() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, true).unwrap(); // one_shot = true
        assert_eq!(sched.total_jobs(), 1);
        sched.record_success(id);
        // One-shot job should have been removed
        assert_eq!(sched.total_jobs(), 0);
        assert!(sched.get_job(id).is_none());
    }
    #[test]
    fn test_record_success_keeps_recurring() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap(); // one_shot = false
        sched.record_success(id);
        // Recurring job should still be there
        assert_eq!(sched.total_jobs(), 1);
        let meta = sched.get_meta(id).unwrap();
        assert_eq!(meta.last_status.as_deref(), Some("ok"));
        assert_eq!(meta.consecutive_errors, 0);
        assert!(meta.job.last_run.is_some());
    }
    // -- test_record_failure_auto_disable -----------------------------------
    #[test]
    fn test_record_failure_auto_disable() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();
        // Fail MAX_CONSECUTIVE_ERRORS - 1 times: should still be enabled
        for i in 0..(MAX_CONSECUTIVE_ERRORS - 1) {
            sched.record_failure(id, &format!("error {i}"));
            let meta = sched.get_meta(id).unwrap();
            assert!(
                meta.job.enabled,
                "Job should still be enabled after {} failures",
                i + 1
            );
            assert_eq!(meta.consecutive_errors, i + 1);
        }
        // One more failure should auto-disable
        sched.record_failure(id, "final error");
        let meta = sched.get_meta(id).unwrap();
        assert!(
            !meta.job.enabled,
            "Job should be auto-disabled after {MAX_CONSECUTIVE_ERRORS} failures"
        );
        assert_eq!(meta.consecutive_errors, MAX_CONSECUTIVE_ERRORS);
        assert!(
            meta.last_status.as_ref().unwrap().starts_with("error:"),
            "last_status should record the error"
        );
    }
    // -- test_due_jobs_only_enabled -----------------------------------------
    #[test]
    fn test_due_jobs_only_enabled() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        // Job 1: enabled, next_run in the past
        let mut j1 = make_job(agent);
        j1.name = "enabled-due".into();
        let id1 = sched.add_job(j1, false).unwrap();
        // Job 2: disabled
        let mut j2 = make_job(agent);
        j2.name = "disabled-job".into();
        let id2 = sched.add_job(j2, false).unwrap();
        sched.set_enabled(id2, false).unwrap();
        // Force job 1's next_run to the past (direct map access is fine
        // here — tests live in the same module as the private field)
        if let Some(mut meta) = sched.jobs.get_mut(&id1) {
            meta.job.next_run = Some(Utc::now() - Duration::seconds(10));
        }
        // Force job 2's next_run to the past too (but it's disabled)
        if let Some(mut meta) = sched.jobs.get_mut(&id2) {
            meta.job.next_run = Some(Utc::now() - Duration::seconds(10));
        }
        let due = sched.due_jobs();
        assert_eq!(due.len(), 1);
        assert_eq!(due[0].name, "enabled-due");
    }
    #[test]
    fn test_due_jobs_future_not_included() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        sched.add_job(job, false).unwrap();
        // The job was just added with next_run = now + 3600s, so it should
        // not be due yet.
        let due = sched.due_jobs();
        assert!(due.is_empty());
    }
    // -- test_set_enabled ---------------------------------------------------
    #[test]
    fn test_set_enabled() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();
        // Disable
        sched.set_enabled(id, false).unwrap();
        let meta = sched.get_meta(id).unwrap();
        assert!(!meta.job.enabled);
        // record_failure still updates metadata even while the job is
        // disabled; bump the error counter so the re-enable below has
        // something to reset.
        sched.record_failure(id, "ignored because disabled");
        // Re-enabling must clear the error count and recompute next_run.
        sched.set_enabled(id, true).unwrap();
        let meta = sched.get_meta(id).unwrap();
        assert!(meta.job.enabled);
        assert_eq!(meta.consecutive_errors, 0);
        assert!(meta.job.next_run.is_some());
        // Non-existent ID should fail
        let fake_id = CronJobId::new();
        assert!(sched.set_enabled(fake_id, true).is_err());
    }
    // -- test_persist_and_load ----------------------------------------------
    #[test]
    fn test_persist_and_load() {
        let tmp = tempfile::tempdir().unwrap();
        let agent = AgentId::new();
        // Create scheduler, add jobs, persist
        {
            let sched = CronScheduler::new(tmp.path(), 100);
            let mut j1 = make_job(agent);
            j1.name = "persist-a".into();
            let mut j2 = make_job(agent);
            j2.name = "persist-b".into();
            sched.add_job(j1, false).unwrap();
            sched.add_job(j2, true).unwrap(); // one_shot
            sched.persist().unwrap();
        }
        // Create a new scheduler and load from disk
        {
            let sched = CronScheduler::new(tmp.path(), 100);
            let count = sched.load().unwrap();
            assert_eq!(count, 2);
            assert_eq!(sched.total_jobs(), 2);
            let jobs = sched.list_jobs(agent);
            assert_eq!(jobs.len(), 2);
            let names: Vec<&str> = jobs.iter().map(|j| j.name.as_str()).collect();
            assert!(names.contains(&"persist-a"));
            assert!(names.contains(&"persist-b"));
            // Verify one_shot flag was preserved
            let b_id = jobs.iter().find(|j| j.name == "persist-b").unwrap().id;
            let meta = sched.get_meta(b_id).unwrap();
            assert!(meta.one_shot);
        }
    }
    #[test]
    fn test_load_no_file_returns_zero() {
        let tmp = tempfile::tempdir().unwrap();
        let sched = CronScheduler::new(tmp.path(), 100);
        assert_eq!(sched.load().unwrap(), 0);
    }
    // -- compute_next_run ---------------------------------------------------
    #[test]
    fn test_compute_next_run_at() {
        let target = Utc::now() + Duration::hours(2);
        let schedule = CronSchedule::At { at: target };
        let next = compute_next_run(&schedule);
        assert_eq!(next, target);
    }
    #[test]
    fn test_compute_next_run_every() {
        let before = Utc::now();
        let schedule = CronSchedule::Every { every_secs: 300 };
        let next = compute_next_run(&schedule);
        let after = Utc::now();
        // Should be roughly now + 300s
        assert!(next >= before + Duration::seconds(300));
        assert!(next <= after + Duration::seconds(300));
    }
    #[test]
    fn test_compute_next_run_cron_daily() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "0 9 * * *".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);
        // Should be within the next 24 hours (next 09:00 UTC)
        assert!(next > now);
        assert!(next <= now + Duration::hours(24));
        assert_eq!(next.format("%M").to_string(), "00");
        assert_eq!(next.format("%H").to_string(), "09");
    }
    #[test]
    fn test_compute_next_run_cron_with_dow() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "30 14 * * 1-5".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);
        // Should be within the next 7 days and at 14:30
        assert!(next > now);
        assert!(next <= now + Duration::days(7));
        assert_eq!(next.format("%H:%M").to_string(), "14:30");
    }
    #[test]
    fn test_compute_next_run_cron_invalid_expr() {
        let now = Utc::now();
        let schedule = CronSchedule::Cron {
            expr: "not a cron".into(),
            tz: None,
        };
        let next = compute_next_run(&schedule);
        // Invalid expression falls back to 1 hour from now
        assert!(next > now + Duration::minutes(59));
        assert!(next <= now + Duration::minutes(61));
    }
    // -- compute_next_run_after regression (#55) ----------------------------
    #[test]
    fn test_compute_next_run_after_skips_current_second() {
        // A "every 4 hours" cron: next_run should be >= 4 hours from now,
        // not in the same minute (the bug from #55).
        let schedule = CronSchedule::Cron {
            expr: "0 */4 * * *".into(),
            tz: None,
        };
        let now = Utc::now();
        let next = compute_next_run_after(&schedule, now);
        // Must be strictly after `now` and at least ~1 hour away
        // (the closest 4-hourly boundary is at least minutes away).
        assert!(next > now, "next_run should be strictly after now");
        let diff = next - now;
        assert!(
            diff.num_minutes() >= 1,
            "Expected next_run at least 1 min away, got {} seconds",
            diff.num_seconds()
        );
    }
    // -- error message truncation in record_failure -------------------------
    #[test]
    fn test_record_failure_truncates_long_error() {
        let (sched, _tmp) = make_scheduler(100);
        let agent = AgentId::new();
        let job = make_job(agent);
        let id = sched.add_job(job, false).unwrap();
        let long_error = "x".repeat(1000);
        sched.record_failure(id, &long_error);
        let meta = sched.get_meta(id).unwrap();
        let status = meta.last_status.unwrap();
        // "error: " is 7 chars + 256 chars of truncated message = 263 max
        assert!(
            status.len() <= 263,
            "Status should be truncated, got {} chars",
            status.len()
        );
    }
}

View File

@@ -0,0 +1,19 @@
//! Kernel-specific error types.
use openfang_types::error::OpenFangError;
use thiserror::Error;
/// Kernel error type wrapping OpenFangError with kernel-specific context.
#[derive(Error, Debug)]
pub enum KernelError {
    /// A wrapped OpenFangError.
    ///
    /// `#[error(transparent)]` forwards `Display`/`source()` to the inner
    /// error, and `#[from]` lets `?` convert `OpenFangError` automatically.
    #[error(transparent)]
    OpenFang(#[from] OpenFangError),
    /// The kernel failed to boot.
    ///
    /// Carries a human-readable reason for the failure.
    #[error("Boot failed: {0}")]
    BootFailed(String),
}
/// Alias for kernel results.
///
/// Convenience shorthand for `Result<T, KernelError>`.
pub type KernelResult<T> = Result<T, KernelError>;

View File

@@ -0,0 +1,149 @@
//! Event bus — pub/sub with pattern matching and history ring buffer.
use dashmap::DashMap;
use openfang_types::agent::AgentId;
use openfang_types::event::{Event, EventTarget};
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::debug;
/// Maximum events retained in the history ring buffer.
const HISTORY_SIZE: usize = 1000;
/// The central event bus for inter-agent and system communication.
///
/// Combines a global broadcast channel, lazily-created per-agent channels,
/// and a bounded in-memory history buffer.
pub struct EventBus {
    /// Broadcast channel for all events.
    ///
    /// Carries `Broadcast`, `Pattern`, and `System` targeted events;
    /// receivers come from `subscribe_all`.
    sender: broadcast::Sender<Event>,
    /// Per-agent event channels.
    ///
    /// Created on first `subscribe_agent` call, removed by
    /// `unsubscribe_agent`.
    agent_channels: DashMap<AgentId, broadcast::Sender<Event>>,
    /// Event history ring buffer.
    ///
    /// Capped at `HISTORY_SIZE` entries; the oldest event is evicted first.
    history: Arc<RwLock<VecDeque<Event>>>,
}
impl EventBus {
    /// Create a new event bus.
    ///
    /// The global broadcast channel buffers up to 1024 events; receivers
    /// that fall further behind observe `Lagged` errors (standard tokio
    /// broadcast semantics).
    pub fn new() -> Self {
        let (sender, _) = broadcast::channel(1024);
        Self {
            sender,
            agent_channels: DashMap::new(),
            history: Arc::new(RwLock::new(VecDeque::with_capacity(HISTORY_SIZE))),
        }
    }
    /// Publish an event to the bus.
    ///
    /// The event is first appended to the bounded history buffer (evicting
    /// the oldest entry when full), then routed by target. All sends are
    /// best-effort: a channel with no active receivers simply drops the
    /// event (`send` errors are ignored).
    ///
    /// NOTE(review): `Agent`-targeted events are delivered only to that
    /// agent's channel and are not mirrored onto the global channel, so
    /// `subscribe_all` receivers never see them — confirm this is intended.
    pub async fn publish(&self, event: Event) {
        debug!(
            event_id = %event.id,
            source = %event.source,
            "Publishing event"
        );
        // Store in history
        {
            let mut history = self.history.write().await;
            if history.len() >= HISTORY_SIZE {
                history.pop_front();
            }
            history.push_back(event.clone());
        }
        // Route to target
        match &event.target {
            EventTarget::Agent(agent_id) => {
                if let Some(sender) = self.agent_channels.get(agent_id) {
                    let _ = sender.send(event.clone());
                }
            }
            EventTarget::Broadcast => {
                // Fan out to the global channel AND every agent channel; an
                // agent subscribed via both receives the event twice.
                let _ = self.sender.send(event.clone());
                for entry in self.agent_channels.iter() {
                    let _ = entry.value().send(event.clone());
                }
            }
            EventTarget::Pattern(_pattern) => {
                // Phase 1: broadcast to all for pattern matching
                let _ = self.sender.send(event.clone());
            }
            EventTarget::System => {
                let _ = self.sender.send(event.clone());
            }
        }
    }
    /// Subscribe to events for a specific agent.
    ///
    /// Creates the agent's channel on first use (capacity 256). The channel
    /// persists even after all receivers are dropped, until
    /// `unsubscribe_agent` removes it.
    pub fn subscribe_agent(&self, agent_id: AgentId) -> broadcast::Receiver<Event> {
        let entry = self.agent_channels.entry(agent_id).or_insert_with(|| {
            let (tx, _) = broadcast::channel(256);
            tx
        });
        entry.subscribe()
    }
    /// Subscribe to all broadcast/system events.
    pub fn subscribe_all(&self) -> broadcast::Receiver<Event> {
        self.sender.subscribe()
    }
    /// Get recent event history.
    ///
    /// Returns up to `limit` events, newest first.
    pub async fn history(&self, limit: usize) -> Vec<Event> {
        let history = self.history.read().await;
        history.iter().rev().take(limit).cloned().collect()
    }
    /// Remove an agent's channel when it's terminated.
    ///
    /// Existing receivers keep their handles but get no further events
    /// through this bus after removal.
    pub fn unsubscribe_agent(&self, agent_id: AgentId) {
        self.agent_channels.remove(&agent_id);
    }
}
impl Default for EventBus {
    /// Equivalent to [`EventBus::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::event::{EventPayload, SystemEvent};
    // A published event should be retained in the history ring buffer.
    #[tokio::test]
    async fn test_publish_and_history() {
        let bus = EventBus::new();
        let agent_id = AgentId::new();
        let event = Event::new(
            agent_id,
            EventTarget::System,
            EventPayload::System(SystemEvent::KernelStarted),
        );
        bus.publish(event).await;
        let history = bus.history(10).await;
        assert_eq!(history.len(), 1);
    }
    // An event targeted at a specific agent should arrive on that agent's
    // subscribed channel with its payload intact.
    #[tokio::test]
    async fn test_agent_subscribe() {
        let bus = EventBus::new();
        let agent_id = AgentId::new();
        let mut rx = bus.subscribe_agent(agent_id);
        let event = Event::new(
            AgentId::new(),
            EventTarget::Agent(agent_id),
            EventPayload::System(SystemEvent::HealthCheck {
                status: "ok".to_string(),
            }),
        );
        bus.publish(event).await;
        let received = rx.recv().await.unwrap();
        match received.payload {
            EventPayload::System(SystemEvent::HealthCheck { status }) => {
                assert_eq!(status, "ok");
            }
            _ => panic!("Wrong payload"),
        }
    }
}

View File

@@ -0,0 +1,245 @@
//! Heartbeat monitor — detects unresponsive agents for 24/7 autonomous operation.
//!
//! The heartbeat monitor runs as a background tokio task, periodically checking
//! each running agent's `last_active` timestamp. If an agent hasn't been active
//! for longer than 2x its heartbeat interval, a `HealthCheckFailed` event is
//! published to the event bus.
use crate::registry::AgentRegistry;
use chrono::Utc;
use openfang_types::agent::{AgentId, AgentState};
use tracing::{debug, warn};
/// Default heartbeat check interval (seconds).
const DEFAULT_CHECK_INTERVAL_SECS: u64 = 30;
/// Multiplier: agent is considered unresponsive if inactive for this many
/// multiples of its heartbeat interval.
const UNRESPONSIVE_MULTIPLIER: u64 = 2;
/// Result of a heartbeat check.
#[derive(Debug, Clone)]
pub struct HeartbeatStatus {
    /// Agent ID.
    pub agent_id: AgentId,
    /// Agent name.
    pub name: String,
    /// Seconds since last activity (from the registry's `last_active`).
    pub inactive_secs: i64,
    /// Whether the agent is considered unresponsive
    /// (inactive longer than its timeout).
    pub unresponsive: bool,
}
/// Heartbeat monitor configuration.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// How often to run the heartbeat check (seconds).
    pub check_interval_secs: u64,
    /// Default threshold for unresponsiveness (seconds).
    /// Overridden per-agent by AutonomousConfig.heartbeat_interval_secs.
    pub default_timeout_secs: u64,
}
impl Default for HeartbeatConfig {
    /// Defaults: check every 30s; unresponsive after 60s (2x the interval).
    fn default() -> Self {
        Self {
            check_interval_secs: DEFAULT_CHECK_INTERVAL_SECS,
            default_timeout_secs: DEFAULT_CHECK_INTERVAL_SECS * UNRESPONSIVE_MULTIPLIER,
        }
    }
}
/// Check all running agents and return their heartbeat status.
///
/// This is a pure function — it doesn't start a background task.
/// The caller (kernel) can run this periodically or in a background task.
pub fn check_agents(registry: &AgentRegistry, config: &HeartbeatConfig) -> Vec<HeartbeatStatus> {
    let now = Utc::now();
    registry
        .list()
        .into_iter()
        // Only running agents participate in heartbeat checks.
        .filter(|entry| entry.state == AgentState::Running)
        .map(|entry| {
            let idle = (now - entry.last_active).num_seconds();
            // Per-agent timeout from its autonomous config, else the
            // configured default.
            let limit = entry
                .manifest
                .autonomous
                .as_ref()
                .map(|a| a.heartbeat_interval_secs * UNRESPONSIVE_MULTIPLIER)
                .unwrap_or(config.default_timeout_secs) as i64;
            let unresponsive = idle > limit;
            if unresponsive {
                warn!(
                    agent = %entry.name,
                    inactive_secs = idle,
                    timeout_secs = limit,
                    "Agent is unresponsive"
                );
            } else {
                debug!(
                    agent = %entry.name,
                    inactive_secs = idle,
                    "Agent heartbeat OK"
                );
            }
            HeartbeatStatus {
                agent_id: entry.id,
                name: entry.name.clone(),
                inactive_secs: idle,
                unresponsive,
            }
        })
        .collect()
}
/// Check if an agent is currently within its quiet hours.
///
/// Quiet hours format: "HH:MM-HH:MM" (24-hour format, UTC).
/// Returns true if the current time falls within the quiet period;
/// malformed input is treated as "not quiet" (returns false).
pub fn is_quiet_hours(quiet_hours: &str) -> bool {
    // "HH:MM" -> minutes since midnight; None on any malformed piece.
    fn to_minutes(s: &str) -> Option<u32> {
        let (h, m) = s.trim().split_once(':')?;
        let h = h.parse::<u32>().ok()?;
        let m = m.parse::<u32>().ok()?;
        (h <= 23 && m <= 59).then_some(h * 60 + m)
    }
    let Some((start_s, end_s)) = quiet_hours.split_once('-') else {
        return false;
    };
    let (Some(start), Some(end)) = (to_minutes(start_s), to_minutes(end_s)) else {
        return false;
    };
    let now = Utc::now();
    let current = now.format("%H").to_string().parse::<u32>().unwrap_or(0) * 60
        + now.format("%M").to_string().parse::<u32>().unwrap_or(0);
    if start <= end {
        // Same-day window: start <= current < end.
        (start..end).contains(&current)
    } else {
        // Cross-midnight window, e.g. 22:00-06:00.
        current >= start || current < end
    }
}
/// Aggregate heartbeat summary.
///
/// Invariant maintained by [`summarize`]: `responsive + unresponsive ==
/// total_checked`.
#[derive(Debug, Clone, Default)]
pub struct HeartbeatSummary {
    /// Total agents checked.
    pub total_checked: usize,
    /// Number of responsive agents.
    pub responsive: usize,
    /// Number of unresponsive agents.
    pub unresponsive: usize,
    /// Details of unresponsive agents.
    pub unresponsive_agents: Vec<HeartbeatStatus>,
}
/// Produce a summary from heartbeat statuses.
pub fn summarize(statuses: &[HeartbeatStatus]) -> HeartbeatSummary {
    // Single pass: count both buckets and collect unresponsive details.
    let mut summary = HeartbeatSummary {
        total_checked: statuses.len(),
        responsive: 0,
        unresponsive: 0,
        unresponsive_agents: Vec::new(),
    };
    for status in statuses {
        if status.unresponsive {
            summary.unresponsive += 1;
            summary.unresponsive_agents.push(status.clone());
        } else {
            summary.responsive += 1;
        }
    }
    summary
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_quiet_hours_parsing() {
        // We can't easily test time-dependent logic, but we can test format parsing
        assert!(!is_quiet_hours("invalid"));
        assert!(!is_quiet_hours(""));
        assert!(!is_quiet_hours("25:00-06:00")); // Invalid hours handled gracefully
    }
    #[test]
    fn test_quiet_hours_format_valid() {
        // The function returns true/false based on current time
        // We just verify it doesn't panic on valid input
        let _ = is_quiet_hours("22:00-06:00");
        let _ = is_quiet_hours("00:00-23:59");
        let _ = is_quiet_hours("09:00-17:00");
    }
    #[test]
    fn test_heartbeat_config_default() {
        // Defaults: 30s check interval, 2x multiplier => 60s timeout.
        let config = HeartbeatConfig::default();
        assert_eq!(config.check_interval_secs, 30);
        assert_eq!(config.default_timeout_secs, 60);
    }
    #[test]
    fn test_summarize_empty() {
        let summary = summarize(&[]);
        assert_eq!(summary.total_checked, 0);
        assert_eq!(summary.responsive, 0);
        assert_eq!(summary.unresponsive, 0);
    }
    #[test]
    fn test_summarize_mixed() {
        // Two responsive agents and one unresponsive one.
        let statuses = vec![
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-1".to_string(),
                inactive_secs: 10,
                unresponsive: false,
            },
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-2".to_string(),
                inactive_secs: 120,
                unresponsive: true,
            },
            HeartbeatStatus {
                agent_id: AgentId::new(),
                name: "agent-3".to_string(),
                inactive_secs: 5,
                unresponsive: false,
            },
        ];
        let summary = summarize(&statuses);
        assert_eq!(summary.total_checked, 3);
        assert_eq!(summary.responsive, 2);
        assert_eq!(summary.unresponsive, 1);
        assert_eq!(summary.unresponsive_agents.len(), 1);
        assert_eq!(summary.unresponsive_agents[0].name, "agent-2");
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,29 @@
//! Core kernel for the OpenFang Agent Operating System.
//!
//! The kernel manages agent lifecycles, memory, permissions, scheduling,
//! and inter-agent communication.
pub mod approval;
pub mod auth;
pub mod auto_reply;
pub mod background;
pub mod capabilities;
pub mod config;
pub mod config_reload;
pub mod cron;
pub mod error;
pub mod event_bus;
pub mod heartbeat;
pub mod kernel;
pub mod metering;
pub mod pairing;
pub mod registry;
pub mod scheduler;
pub mod supervisor;
pub mod triggers;
pub mod whatsapp_gateway;
pub mod wizard;
pub mod workflow;
// Re-export the primary kernel entry points at the crate root.
pub use kernel::DeliveryTracker;
pub use kernel::OpenFangKernel;

View File

@@ -0,0 +1,753 @@
//! Metering engine — tracks LLM cost and enforces spending quotas.
use openfang_memory::usage::{ModelUsage, UsageRecord, UsageStore, UsageSummary};
use openfang_types::agent::{AgentId, ResourceQuota};
use openfang_types::error::{OpenFangError, OpenFangResult};
use std::sync::Arc;
/// The metering engine tracks usage cost and enforces quota limits.
pub struct MeteringEngine {
    /// Persistent usage store (SQLite-backed).
    store: Arc<UsageStore>,
}
impl MeteringEngine {
    /// Create a new metering engine with the given usage store.
    pub fn new(store: Arc<UsageStore>) -> Self {
        Self { store }
    }
    /// Record a usage event (persists to SQLite).
    pub fn record(&self, record: &UsageRecord) -> OpenFangResult<()> {
        self.store.record(record)
    }
    /// Check if an agent is within its spending quotas (hourly, daily, monthly).
    /// Returns Ok(()) if under all quotas, or QuotaExceeded error if over any.
    ///
    /// A quota limit of 0.0 means "no limit" — that window's check is skipped.
    pub fn check_quota(&self, agent_id: AgentId, quota: &ResourceQuota) -> OpenFangResult<()> {
        // Hourly check
        if quota.max_cost_per_hour_usd > 0.0 {
            let hourly_cost = self.store.query_hourly(agent_id)?;
            if hourly_cost >= quota.max_cost_per_hour_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded hourly cost quota: ${:.4} / ${:.4}",
                    agent_id, hourly_cost, quota.max_cost_per_hour_usd
                )));
            }
        }
        // Daily check
        if quota.max_cost_per_day_usd > 0.0 {
            let daily_cost = self.store.query_daily(agent_id)?;
            if daily_cost >= quota.max_cost_per_day_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded daily cost quota: ${:.4} / ${:.4}",
                    agent_id, daily_cost, quota.max_cost_per_day_usd
                )));
            }
        }
        // Monthly check
        if quota.max_cost_per_month_usd > 0.0 {
            let monthly_cost = self.store.query_monthly(agent_id)?;
            if monthly_cost >= quota.max_cost_per_month_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Agent {} exceeded monthly cost quota: ${:.4} / ${:.4}",
                    agent_id, monthly_cost, quota.max_cost_per_month_usd
                )));
            }
        }
        Ok(())
    }
    /// Check global budget limits (across all agents).
    ///
    /// As with `check_quota`, a limit of 0.0 disables that window's check.
    pub fn check_global_budget(
        &self,
        budget: &openfang_types::config::BudgetConfig,
    ) -> OpenFangResult<()> {
        if budget.max_hourly_usd > 0.0 {
            let cost = self.store.query_global_hourly()?;
            if cost >= budget.max_hourly_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global hourly budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_hourly_usd
                )));
            }
        }
        if budget.max_daily_usd > 0.0 {
            let cost = self.store.query_today_cost()?;
            if cost >= budget.max_daily_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global daily budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_daily_usd
                )));
            }
        }
        if budget.max_monthly_usd > 0.0 {
            let cost = self.store.query_global_monthly()?;
            if cost >= budget.max_monthly_usd {
                return Err(OpenFangError::QuotaExceeded(format!(
                    "Global monthly budget exceeded: ${:.4} / ${:.4}",
                    cost, budget.max_monthly_usd
                )));
            }
        }
        Ok(())
    }
    /// Get budget status — current spend vs limits for all time windows.
    pub fn budget_status(&self, budget: &openfang_types::config::BudgetConfig) -> BudgetStatus {
        // Store read failures degrade to a spend of 0.0 rather than erroring,
        // so status reporting never fails.
        let hourly = self.store.query_global_hourly().unwrap_or(0.0);
        let daily = self.store.query_today_cost().unwrap_or(0.0);
        let monthly = self.store.query_global_monthly().unwrap_or(0.0);
        BudgetStatus {
            hourly_spend: hourly,
            hourly_limit: budget.max_hourly_usd,
            hourly_pct: if budget.max_hourly_usd > 0.0 {
                hourly / budget.max_hourly_usd
            } else {
                0.0
            },
            daily_spend: daily,
            daily_limit: budget.max_daily_usd,
            daily_pct: if budget.max_daily_usd > 0.0 {
                daily / budget.max_daily_usd
            } else {
                0.0
            },
            monthly_spend: monthly,
            monthly_limit: budget.max_monthly_usd,
            monthly_pct: if budget.max_monthly_usd > 0.0 {
                monthly / budget.max_monthly_usd
            } else {
                0.0
            },
            alert_threshold: budget.alert_threshold,
        }
    }
    /// Get a usage summary, optionally filtered by agent.
    pub fn get_summary(&self, agent_id: Option<AgentId>) -> OpenFangResult<UsageSummary> {
        self.store.query_summary(agent_id)
    }
    /// Get usage grouped by model.
    pub fn get_by_model(&self) -> OpenFangResult<Vec<ModelUsage>> {
        self.store.query_by_model()
    }
    /// Estimate the cost of an LLM call based on model and token counts.
    ///
    /// Pricing table (approximate, per million tokens):
    ///
    /// | Model Family | Input $/M | Output $/M |
    /// |-----------------------|-----------|------------|
    /// | claude-haiku | 0.25 | 1.25 |
    /// | claude-sonnet-4-6 | 3.00 | 15.00 |
    /// | claude-opus-4-6 | 5.00 | 25.00 |
    /// | claude-opus (legacy) | 15.00 | 75.00 |
    /// | gpt-5.2(-pro) | 1.75 | 14.00 |
    /// | gpt-5(.1) | 1.25 | 10.00 |
    /// | gpt-5-mini | 0.25 | 2.00 |
    /// | gpt-5-nano | 0.05 | 0.40 |
    /// | gpt-4o | 2.50 | 10.00 |
    /// | gpt-4o-mini | 0.15 | 0.60 |
    /// | gpt-4.1 | 2.00 | 8.00 |
    /// | gpt-4.1-mini | 0.40 | 1.60 |
    /// | gpt-4.1-nano | 0.10 | 0.40 |
    /// | o3-mini | 1.10 | 4.40 |
    /// | gemini-3.1 | 2.50 | 15.00 |
    /// | gemini-3 | 0.50 | 3.00 |
    /// | gemini-2.5-flash-lite | 0.04 | 0.15 |
    /// | gemini-2.5-pro | 1.25 | 10.00 |
    /// | gemini-2.5-flash | 0.15 | 0.60 |
    /// | gemini-2.0-flash | 0.10 | 0.40 |
    /// | deepseek-chat/v3 | 0.27 | 1.10 |
    /// | deepseek-reasoner/r1 | 0.55 | 2.19 |
    /// | llama-4-maverick | 0.50 | 0.77 |
    /// | llama-4-scout | 0.11 | 0.34 |
    /// | llama/mixtral (groq) | 0.05 | 0.10 |
    /// | grok-4.1 | 0.20 | 0.50 |
    /// | grok-4 | 3.00 | 15.00 |
    /// | grok-3 | 3.00 | 15.00 |
    /// | qwen | 0.20 | 0.60 |
    /// | mistral-large | 2.00 | 6.00 |
    /// | mistral-small | 0.10 | 0.30 |
    /// | command-r-plus | 2.50 | 10.00 |
    /// | Default (unknown) | 1.00 | 3.00 |
    pub fn estimate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
        // Rates are matched on the lowercased model name; see estimate_cost_rates.
        let model_lower = model.to_lowercase();
        let (input_per_m, output_per_m) = estimate_cost_rates(&model_lower);
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_per_m;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_per_m;
        input_cost + output_cost
    }
    /// Estimate cost using the model catalog as the pricing source.
    ///
    /// Falls back to the default rate ($1/$3 per million) if the model is not
    /// found in the catalog.
    pub fn estimate_cost_with_catalog(
        catalog: &openfang_runtime::model_catalog::ModelCatalog,
        model: &str,
        input_tokens: u64,
        output_tokens: u64,
    ) -> f64 {
        let (input_per_m, output_per_m) = catalog.pricing(model).unwrap_or((1.0, 3.0));
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_per_m;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_per_m;
        input_cost + output_cost
    }
    /// Clean up old usage records.
    pub fn cleanup(&self, days: u32) -> OpenFangResult<usize> {
        self.store.cleanup_old(days)
    }
}
/// Budget status snapshot — current spend vs limits for all time windows.
#[derive(Debug, Clone, serde::Serialize)]
pub struct BudgetStatus {
    /// USD spent in the current hourly window (global, all agents).
    pub hourly_spend: f64,
    /// Configured hourly limit in USD (0.0 = unlimited).
    pub hourly_limit: f64,
    /// hourly_spend / hourly_limit; 0.0 when no limit is configured.
    pub hourly_pct: f64,
    /// USD spent today (global, all agents).
    pub daily_spend: f64,
    /// Configured daily limit in USD (0.0 = unlimited).
    pub daily_limit: f64,
    /// daily_spend / daily_limit; 0.0 when no limit is configured.
    pub daily_pct: f64,
    /// USD spent in the current monthly window (global, all agents).
    pub monthly_spend: f64,
    /// Configured monthly limit in USD (0.0 = unlimited).
    pub monthly_limit: f64,
    /// monthly_spend / monthly_limit; 0.0 when no limit is configured.
    pub monthly_pct: f64,
    /// Alert threshold copied from the budget configuration.
    pub alert_threshold: f64,
}
/// Returns (input_per_million, output_per_million) pricing for a model.
///
/// The caller passes an already-lowercased model name (see
/// `MeteringEngine::estimate_cost`). Matching is a first-hit substring scan
/// over an ordered table, so more specific patterns must come before generic
/// ones (e.g. "gpt-4o-mini" before "gpt-4o", "cerebras" before "llama").
///
/// Rows whose rates were identical to a more generic row in the previous
/// if-chain (claude-sonnet-4-6/sonnet, gpt-5.2-pro/gpt-5.2,
/// mistral-small/mistral) have been collapsed; the returned rates are
/// unchanged for every input.
fn estimate_cost_rates(model: &str) -> (f64, f64) {
    // Ordered pricing table: a row matches if ANY of its substrings occurs
    // in the model name. First matching row wins.
    const RATES: &[(&[&str], (f64, f64))] = &[
        // ── Anthropic ──────────────────────────────────────────────
        (&["haiku"], (0.25, 1.25)),
        (&["opus-4-6"], (5.0, 25.0)),
        (&["opus"], (15.0, 75.0)),
        (&["sonnet"], (3.0, 15.0)),
        // ── OpenAI ─────────────────────────────────────────────────
        (&["gpt-5.2"], (1.75, 14.0)),
        (&["gpt-5.1"], (1.25, 10.0)),
        (&["gpt-5-nano"], (0.05, 0.40)),
        (&["gpt-5-mini"], (0.25, 2.0)),
        (&["gpt-5"], (1.25, 10.0)),
        (&["gpt-4o-mini"], (0.15, 0.60)),
        (&["gpt-4o"], (2.50, 10.0)),
        (&["gpt-4.1-nano"], (0.10, 0.40)),
        (&["gpt-4.1-mini"], (0.40, 1.60)),
        (&["gpt-4.1"], (2.00, 8.00)),
        (&["o4-mini", "o3-mini"], (1.10, 4.40)),
        (&["o3"], (2.00, 8.00)),
        // Generic gpt-4 fallback
        (&["gpt-4"], (2.50, 10.0)),
        // ── Google Gemini ──────────────────────────────────────────
        (&["gemini-3.1"], (2.50, 15.0)),
        (&["gemini-3"], (0.50, 3.0)),
        (&["gemini-2.5-flash-lite"], (0.04, 0.15)),
        (&["gemini-2.5-pro"], (1.25, 10.0)),
        (&["gemini-2.5-flash"], (0.15, 0.60)),
        (&["gemini-2.0-flash", "gemini-flash"], (0.10, 0.40)),
        // Generic gemini fallback
        (&["gemini"], (0.15, 0.60)),
        // ── DeepSeek ───────────────────────────────────────────────
        (&["deepseek-reasoner", "deepseek-r1"], (0.55, 2.19)),
        (&["deepseek"], (0.27, 1.10)),
        // ── Hosted platforms — must come before "llama" ────────────
        (&["cerebras"], (0.06, 0.06)),
        (&["sambanova"], (0.06, 0.06)),
        (&["replicate"], (0.40, 0.40)),
        // ── Open-source (Groq, Together, etc.) ─────────────────────
        (&["llama-4-maverick"], (0.50, 0.77)),
        (&["llama-4-scout"], (0.11, 0.34)),
        (&["llama", "mixtral"], (0.05, 0.10)),
        // ── Qwen (Alibaba) ─────────────────────────────────────────
        (&["qwen-max"], (4.00, 12.00)),
        (&["qwen-vl"], (1.50, 4.50)),
        (&["qwen-plus"], (0.80, 2.00)),
        (&["qwen-turbo"], (0.30, 0.60)),
        (&["qwen"], (0.20, 0.60)),
        // ── MiniMax ────────────────────────────────────────────────
        (&["minimax"], (1.00, 3.00)),
        // ── Zhipu / GLM ────────────────────────────────────────────
        (&["glm-4-flash"], (0.10, 0.10)),
        (&["glm"], (1.50, 5.00)),
        (&["codegeex"], (0.10, 0.10)),
        // ── Moonshot / Kimi ────────────────────────────────────────
        (&["moonshot", "kimi"], (0.80, 0.80)),
        // ── Baidu ERNIE ────────────────────────────────────────────
        (&["ernie"], (2.00, 6.00)),
        // ── AWS Bedrock ────────────────────────────────────────────
        (&["nova-pro"], (0.80, 3.20)),
        (&["nova-lite"], (0.06, 0.24)),
        // ── Mistral ────────────────────────────────────────────────
        (&["mistral-large"], (2.00, 6.00)),
        (&["mistral"], (0.10, 0.30)),
        // ── Cohere ─────────────────────────────────────────────────
        (&["command-r-plus"], (2.50, 10.0)),
        (&["command-r"], (0.15, 0.60)),
        // ── Perplexity ─────────────────────────────────────────────
        (&["sonar-pro"], (3.0, 15.0)),
        (&["sonar"], (1.0, 5.0)),
        // ── xAI / Grok ─────────────────────────────────────────────
        (&["grok-4.1"], (0.20, 0.50)),
        (&["grok-4"], (3.0, 15.0)),
        (&["grok-3-mini", "grok-2-mini", "grok-mini"], (0.30, 0.50)),
        (&["grok-3"], (3.0, 15.0)),
        (&["grok"], (2.0, 10.0)),
        // ── AI21 / Jamba ───────────────────────────────────────────
        (&["jamba"], (2.0, 8.0)),
    ];
    RATES
        .iter()
        .find(|(patterns, _)| patterns.iter().any(|p| model.contains(p)))
        .map(|(_, rates)| *rates)
        // ── Default (conservative) ─────────────────────────────────
        .unwrap_or((1.0, 3.0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_memory::MemorySubstrate;
    // Build a metering engine backed by an in-memory SQLite usage store.
    fn setup() -> MeteringEngine {
        let substrate = MemorySubstrate::open_in_memory(0.1).unwrap();
        let store = Arc::new(UsageStore::new(substrate.usage_conn()));
        MeteringEngine::new(store)
    }
    #[test]
    fn test_record_and_check_quota_under() {
        let engine = setup();
        let agent_id = AgentId::new();
        let quota = ResourceQuota {
            max_cost_per_hour_usd: 1.0,
            ..Default::default()
        };
        engine
            .record(&UsageRecord {
                agent_id,
                model: "claude-haiku".to_string(),
                input_tokens: 100,
                output_tokens: 50,
                cost_usd: 0.001,
                tool_calls: 0,
            })
            .unwrap();
        assert!(engine.check_quota(agent_id, &quota).is_ok());
    }
    #[test]
    fn test_check_quota_exceeded() {
        let engine = setup();
        let agent_id = AgentId::new();
        let quota = ResourceQuota {
            max_cost_per_hour_usd: 0.01,
            ..Default::default()
        };
        engine
            .record(&UsageRecord {
                agent_id,
                model: "claude-sonnet".to_string(),
                input_tokens: 10000,
                output_tokens: 5000,
                cost_usd: 0.05,
                tool_calls: 0,
            })
            .unwrap();
        let result = engine.check_quota(agent_id, &quota);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exceeded hourly cost quota"));
    }
    #[test]
    fn test_check_quota_zero_limit_skipped() {
        let engine = setup();
        let agent_id = AgentId::new();
        let quota = ResourceQuota {
            max_cost_per_hour_usd: 0.0,
            ..Default::default()
        };
        // Even with high usage, a zero limit means no enforcement
        engine
            .record(&UsageRecord {
                agent_id,
                model: "claude-opus".to_string(),
                input_tokens: 100000,
                output_tokens: 50000,
                cost_usd: 100.0,
                tool_calls: 0,
            })
            .unwrap();
        assert!(engine.check_quota(agent_id, &quota).is_ok());
    }
    // The remaining tests pin the static pricing table, 1M input + 1M
    // output tokens at a time, so any rate change shows up as a failure.
    #[test]
    fn test_estimate_cost_haiku() {
        let cost = MeteringEngine::estimate_cost("claude-haiku-4-5-20251001", 1_000_000, 1_000_000);
        assert!((cost - 1.50).abs() < 0.01); // $0.25 + $1.25
    }
    #[test]
    fn test_estimate_cost_sonnet() {
        let cost = MeteringEngine::estimate_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000);
        assert!((cost - 18.0).abs() < 0.01); // $3.00 + $15.00
    }
    #[test]
    fn test_estimate_cost_opus() {
        let cost = MeteringEngine::estimate_cost("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!((cost - 90.0).abs() < 0.01); // $15.00 + $75.00
    }
    #[test]
    fn test_estimate_cost_gpt4o() {
        let cost = MeteringEngine::estimate_cost("gpt-4o-2024-11-20", 1_000_000, 1_000_000);
        assert!((cost - 12.50).abs() < 0.01); // $2.50 + $10.00
    }
    #[test]
    fn test_estimate_cost_gpt4o_mini() {
        let cost = MeteringEngine::estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000);
        assert!((cost - 0.75).abs() < 0.01); // $0.15 + $0.60
    }
    #[test]
    fn test_estimate_cost_gpt41() {
        let cost = MeteringEngine::estimate_cost("gpt-4.1", 1_000_000, 1_000_000);
        assert!((cost - 10.0).abs() < 0.01); // $2.00 + $8.00
    }
    #[test]
    fn test_estimate_cost_gpt41_mini() {
        let cost = MeteringEngine::estimate_cost("gpt-4.1-mini", 1_000_000, 1_000_000);
        assert!((cost - 2.0).abs() < 0.01); // $0.40 + $1.60
    }
    #[test]
    fn test_estimate_cost_gpt41_nano() {
        let cost = MeteringEngine::estimate_cost("gpt-4.1-nano", 1_000_000, 1_000_000);
        assert!((cost - 0.50).abs() < 0.01); // $0.10 + $0.40
    }
    #[test]
    fn test_estimate_cost_o3_mini() {
        let cost = MeteringEngine::estimate_cost("o3-mini", 1_000_000, 1_000_000);
        assert!((cost - 5.50).abs() < 0.01); // $1.10 + $4.40
    }
    #[test]
    fn test_estimate_cost_gemini_20_flash() {
        let cost = MeteringEngine::estimate_cost("gemini-2.0-flash", 1_000_000, 1_000_000);
        assert!((cost - 0.50).abs() < 0.01); // $0.10 + $0.40
    }
    #[test]
    fn test_estimate_cost_gemini_25_pro() {
        let cost = MeteringEngine::estimate_cost("gemini-2.5-pro", 1_000_000, 1_000_000);
        assert!((cost - 11.25).abs() < 0.01); // $1.25 + $10.00
    }
    #[test]
    fn test_estimate_cost_gemini_25_flash() {
        let cost = MeteringEngine::estimate_cost("gemini-2.5-flash", 1_000_000, 1_000_000);
        assert!((cost - 0.75).abs() < 0.01); // $0.15 + $0.60
    }
    #[test]
    fn test_estimate_cost_deepseek_chat() {
        let cost = MeteringEngine::estimate_cost("deepseek-chat", 1_000_000, 1_000_000);
        assert!((cost - 1.37).abs() < 0.01); // $0.27 + $1.10
    }
    #[test]
    fn test_estimate_cost_deepseek_reasoner() {
        let cost = MeteringEngine::estimate_cost("deepseek-reasoner", 1_000_000, 1_000_000);
        assert!((cost - 2.74).abs() < 0.01); // $0.55 + $2.19
    }
    #[test]
    fn test_estimate_cost_llama() {
        let cost = MeteringEngine::estimate_cost("llama-3.3-70b-versatile", 1_000_000, 1_000_000);
        assert!((cost - 0.15).abs() < 0.01); // $0.05 + $0.10
    }
    #[test]
    fn test_estimate_cost_mixtral() {
        let cost = MeteringEngine::estimate_cost("mixtral-8x7b", 1_000_000, 1_000_000);
        assert!((cost - 0.15).abs() < 0.01); // $0.05 + $0.10
    }
    #[test]
    fn test_estimate_cost_qwen() {
        let cost = MeteringEngine::estimate_cost("qwen-2.5-72b", 1_000_000, 1_000_000);
        assert!((cost - 0.80).abs() < 0.01); // $0.20 + $0.60
    }
    #[test]
    fn test_estimate_cost_mistral_large() {
        let cost = MeteringEngine::estimate_cost("mistral-large-latest", 1_000_000, 1_000_000);
        assert!((cost - 8.0).abs() < 0.01); // $2.00 + $6.00
    }
    #[test]
    fn test_estimate_cost_mistral_small() {
        let cost = MeteringEngine::estimate_cost("mistral-small-latest", 1_000_000, 1_000_000);
        assert!((cost - 0.40).abs() < 0.01); // $0.10 + $0.30
    }
    #[test]
    fn test_estimate_cost_command_r_plus() {
        let cost = MeteringEngine::estimate_cost("command-r-plus", 1_000_000, 1_000_000);
        assert!((cost - 12.50).abs() < 0.01); // $2.50 + $10.00
    }
    #[test]
    fn test_estimate_cost_unknown() {
        let cost = MeteringEngine::estimate_cost("my-custom-model", 1_000_000, 1_000_000);
        assert!((cost - 4.0).abs() < 0.01); // $1.00 + $3.00
    }
    #[test]
    fn test_estimate_cost_grok() {
        let cost = MeteringEngine::estimate_cost("grok-2", 1_000_000, 1_000_000);
        assert!((cost - 12.0).abs() < 0.01); // $2.00 + $10.00
    }
    #[test]
    fn test_estimate_cost_grok_mini() {
        let cost = MeteringEngine::estimate_cost("grok-2-mini", 1_000_000, 1_000_000);
        assert!((cost - 0.80).abs() < 0.01); // $0.30 + $0.50
    }
    #[test]
    fn test_estimate_cost_sonar_pro() {
        let cost = MeteringEngine::estimate_cost("sonar-pro", 1_000_000, 1_000_000);
        assert!((cost - 18.0).abs() < 0.01); // $3.00 + $15.00
    }
    #[test]
    fn test_estimate_cost_jamba() {
        let cost = MeteringEngine::estimate_cost("jamba-1.5-large", 1_000_000, 1_000_000);
        assert!((cost - 10.0).abs() < 0.01); // $2.00 + $8.00
    }
    #[test]
    fn test_estimate_cost_cerebras() {
        let cost = MeteringEngine::estimate_cost("cerebras/llama3.3-70b", 1_000_000, 1_000_000);
        assert!((cost - 0.12).abs() < 0.01); // $0.06 + $0.06
    }
    #[test]
    fn test_estimate_cost_with_catalog() {
        let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
        // Sonnet: $3/M input, $15/M output
        let cost = MeteringEngine::estimate_cost_with_catalog(
            &catalog,
            "claude-sonnet-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!((cost - 18.0).abs() < 0.01);
    }
    #[test]
    fn test_estimate_cost_with_catalog_alias() {
        let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
        // "sonnet" alias should resolve to same pricing
        let cost =
            MeteringEngine::estimate_cost_with_catalog(&catalog, "sonnet", 1_000_000, 1_000_000);
        assert!((cost - 18.0).abs() < 0.01);
    }
    #[test]
    fn test_estimate_cost_with_catalog_unknown_uses_default() {
        let catalog = openfang_runtime::model_catalog::ModelCatalog::new();
        // Unknown model falls back to $1/$3
        let cost = MeteringEngine::estimate_cost_with_catalog(
            &catalog,
            "totally-unknown-model",
            1_000_000,
            1_000_000,
        );
        assert!((cost - 4.0).abs() < 0.01);
    }
    #[test]
    fn test_get_summary() {
        let engine = setup();
        let agent_id = AgentId::new();
        engine
            .record(&UsageRecord {
                agent_id,
                model: "haiku".to_string(),
                input_tokens: 500,
                output_tokens: 200,
                cost_usd: 0.005,
                tool_calls: 3,
            })
            .unwrap();
        let summary = engine.get_summary(Some(agent_id)).unwrap();
        assert_eq!(summary.call_count, 1);
        assert_eq!(summary.total_input_tokens, 500);
    }
}

View File

@@ -0,0 +1,510 @@
//! Device pairing — QR-code flow for mobile/desktop clients.
//!
//! Supports pairing via short-lived tokens, device management, and
//! push notifications via ntfy.sh or gotify.
use dashmap::DashMap;
use openfang_types::config::PairingConfig;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
/// Maximum concurrent pairing requests (prevent token flooding).
const MAX_PENDING_REQUESTS: usize = 5;
/// A paired device record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PairedDevice {
    /// Unique identifier submitted by the device when pairing completes.
    pub device_id: String,
    /// Human-readable device name.
    pub display_name: String,
    /// Client platform (free-form string).
    pub platform: String,
    /// When the device completed pairing.
    pub paired_at: chrono::DateTime<chrono::Utc>,
    /// Last time the device was seen.
    pub last_seen: chrono::DateTime<chrono::Utc>,
    /// Push-notification token. Never written out when serializing
    /// (e.g. in device listings), but accepted on deserialization.
    #[serde(skip_serializing)]
    pub push_token: Option<String>,
}
/// Pairing request (short-lived, for QR code flow).
#[derive(Debug, Clone)]
pub struct PairingRequest {
    /// Random 64-hex-char token encoded into the QR code.
    pub token: String,
    /// When the request was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Hard expiry; completing the pairing after this instant is rejected.
    pub expires_at: chrono::DateTime<chrono::Utc>,
}
/// Persistence callback — kernel injects this so PairingManager can save without
/// taking a direct dependency on openfang-memory.
pub type PersistFn = Box<dyn Fn(&PairedDevice, PersistOp) + Send + Sync>;
/// Persistence operation kind.
#[derive(Debug, Clone, Copy)]
pub enum PersistOp {
    /// Save the device record.
    Save,
    /// Remove the device record.
    Remove,
}
/// Device pairing manager.
pub struct PairingManager {
    /// Pairing feature configuration (enabled flag, token expiry, limits).
    config: PairingConfig,
    /// Outstanding pairing requests, keyed by token.
    pending: DashMap<String, PairingRequest>,
    /// Paired devices, keyed by device_id.
    devices: DashMap<String, PairedDevice>,
    /// Optional persistence hook; see [`PersistFn`].
    persist: Option<PersistFn>,
}
impl PairingManager {
    /// Create a manager with no paired devices and no persistence hook.
    pub fn new(config: PairingConfig) -> Self {
        Self {
            config,
            pending: DashMap::new(),
            devices: DashMap::new(),
            persist: None,
        }
    }
    /// Attach a persistence callback (called after pair/unpair operations).
    pub fn set_persist(&mut self, f: PersistFn) {
        self.persist = Some(f);
    }
/// Bulk-load devices from persistence (call once at boot).
pub fn load_devices(&self, devices: Vec<PairedDevice>) {
for d in devices {
self.devices.insert(d.device_id.clone(), d);
}
debug!(
count = self.devices.len(),
"Loaded paired devices from database"
);
}
    /// Generate a new pairing request. Returns token for QR encoding.
    ///
    /// Fails when pairing is disabled, or when the pending-request limit is
    /// still reached after expired requests are purged.
    pub fn create_pairing_request(&self) -> Result<PairingRequest, String> {
        if !self.config.enabled {
            return Err("Device pairing is disabled".into());
        }
        // Enforce max pending limit
        if self.pending.len() >= MAX_PENDING_REQUESTS {
            // Clean expired first
            self.clean_expired();
            if self.pending.len() >= MAX_PENDING_REQUESTS {
                return Err("Too many pending pairing requests. Try again later.".into());
            }
        }
        // Generate secure random token (32 bytes = 64 hex chars)
        let mut token_bytes = [0u8; 32];
        use rand::RngCore;
        rand::thread_rng().fill_bytes(&mut token_bytes);
        let token = hex::encode(token_bytes);
        let now = chrono::Utc::now();
        // Expiry window comes from configuration (token_expiry_secs).
        let expires_at = now + chrono::Duration::seconds(self.config.token_expiry_secs as i64);
        let request = PairingRequest {
            token: token.clone(),
            created_at: now,
            expires_at,
        };
        self.pending.insert(token, request.clone());
        Ok(request)
    }
    /// Complete pairing — device submits token + device info.
    ///
    /// On success the token is consumed, the device is stored (overwriting any
    /// existing device with the same ID), and the persistence hook (if any) is
    /// notified. Errors on unknown/expired tokens or when the device cap is hit.
    pub fn complete_pairing(
        &self,
        token: &str,
        device_info: PairedDevice,
    ) -> Result<PairingRequest, String> {
        // SECURITY: Constant-time token comparison
        let found = self.pending.iter().find(|entry| {
            use subtle::ConstantTimeEq;
            let stored = entry.value().token.as_bytes();
            let provided = token.as_bytes();
            // Length check leaks only the length, not the token bytes.
            if stored.len() != provided.len() {
                return false;
            }
            stored.ct_eq(provided).into()
        });
        let entry = found.ok_or("Invalid or expired pairing token")?;
        let request = entry.value().clone();
        let key = entry.key().clone();
        // Release the DashMap shard guard before mutating `pending` below —
        // removing while holding an iterator ref could deadlock.
        drop(entry);
        // Check expiry
        if chrono::Utc::now() > request.expires_at {
            self.pending.remove(&key);
            return Err("Pairing token has expired".into());
        }
        // Check max devices (note: the token stays pending in this case, so the
        // caller may retry after freeing a slot)
        if self.devices.len() >= self.config.max_devices {
            return Err(format!(
                "Maximum paired devices ({}) reached. Remove a device first.",
                self.config.max_devices
            ));
        }
        // Remove the used token
        self.pending.remove(&key);
        // Store the device
        let device_id = device_info.device_id.clone();
        self.devices.insert(device_id.clone(), device_info.clone());
        // Persist to database
        if let Some(ref persist) = self.persist {
            persist(&device_info, PersistOp::Save);
        }
        debug!(device_id = %device_id, "Device paired successfully");
        Ok(device_info)
    }
/// List paired devices.
pub fn list_devices(&self) -> Vec<PairedDevice> {
self.devices.iter().map(|e| e.value().clone()).collect()
}
/// Remove a paired device.
pub fn remove_device(&self, device_id: &str) -> Result<(), String> {
let removed = self
.devices
.remove(device_id)
.ok_or_else(|| format!("Device '{device_id}' not found"))?;
// Persist removal to database
if let Some(ref persist) = self.persist {
persist(&removed.1, PersistOp::Remove);
}
Ok(())
}
    /// Send push notification to all paired devices.
    ///
    /// Dispatches on `config.push_provider` ("ntfy", "gotify", "none"/"").
    /// Returns one `(id, result)` pair per device on success; on configuration
    /// or transport failure a single provider-level error entry is returned.
    /// Each HTTP call uses a fresh client with a 10-second timeout.
    pub async fn notify_devices(
        &self,
        title: &str,
        body: &str,
    ) -> Vec<(String, Result<(), String>)> {
        let mut results = Vec::new();
        match self.config.push_provider.as_str() {
            "ntfy" => {
                // ntfy publishes by POSTing the message body to <server>/<topic>.
                let url = self.config.ntfy_url.as_deref().unwrap_or("https://ntfy.sh");
                let topic = match &self.config.ntfy_topic {
                    Some(t) => t.clone(),
                    None => {
                        results.push(("ntfy".to_string(), Err("ntfy_topic not configured".into())));
                        return results;
                    }
                };
                let full_url = format!("{}/{}", url.trim_end_matches('/'), topic);
                let client = reqwest::Client::new();
                match client
                    .post(&full_url)
                    .header("Title", title)
                    .body(body.to_string())
                    .timeout(std::time::Duration::from_secs(10))
                    .send()
                    .await
                {
                    // One publish reaches every subscribed device, so mark all Ok.
                    Ok(resp) if resp.status().is_success() => {
                        for device in self.devices.iter() {
                            results.push((device.device_id.clone(), Ok(())));
                        }
                    }
                    Ok(resp) => {
                        let status = resp.status();
                        results.push((
                            "ntfy".to_string(),
                            Err(format!("ntfy returned HTTP {status}")),
                        ));
                    }
                    Err(e) => {
                        results
                            .push(("ntfy".to_string(), Err(format!("ntfy request failed: {e}"))));
                    }
                }
            }
            "gotify" => {
                // Gotify requires an app token
                let app_token = match std::env::var("GOTIFY_APP_TOKEN") {
                    Ok(t) => t,
                    Err(_) => {
                        results
                            .push(("gotify".to_string(), Err("GOTIFY_APP_TOKEN not set".into())));
                        return results;
                    }
                };
                let server_url = match std::env::var("GOTIFY_SERVER_URL") {
                    Ok(u) => u,
                    Err(_) => {
                        results.push((
                            "gotify".to_string(),
                            Err("GOTIFY_SERVER_URL not set".into()),
                        ));
                        return results;
                    }
                };
                let url = format!("{}/message", server_url.trim_end_matches('/'));
                let body_json = serde_json::json!({
                    "title": title,
                    "message": body,
                    "priority": 5,
                });
                let client = reqwest::Client::new();
                match client
                    .post(&url)
                    .header("X-Gotify-Key", &app_token)
                    .json(&body_json)
                    .timeout(std::time::Duration::from_secs(10))
                    .send()
                    .await
                {
                    // One publish reaches every device, so mark all Ok.
                    Ok(resp) if resp.status().is_success() => {
                        for device in self.devices.iter() {
                            results.push((device.device_id.clone(), Ok(())));
                        }
                    }
                    Ok(resp) => {
                        let status = resp.status();
                        results.push((
                            "gotify".to_string(),
                            Err(format!("gotify returned HTTP {status}")),
                        ));
                    }
                    Err(e) => {
                        results.push((
                            "gotify".to_string(),
                            Err(format!("gotify request failed: {e}")),
                        ));
                    }
                }
            }
            "none" | "" => {
                // No push provider configured — silent
            }
            other => {
                warn!(provider = other, "Unknown push notification provider");
            }
        }
        results
    }
/// Clean expired pairing requests.
pub fn clean_expired(&self) {
let now = chrono::Utc::now();
self.pending.retain(|_, req| req.expires_at > now);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Pairing is disabled by default; tests that exercise the happy path opt
    // in via `enabled_config()`.
    fn default_config() -> PairingConfig {
        PairingConfig::default()
    }
    fn enabled_config() -> PairingConfig {
        PairingConfig {
            enabled: true,
            ..Default::default()
        }
    }
    #[test]
    fn test_manager_creation() {
        let mgr = PairingManager::new(default_config());
        assert!(mgr.devices.is_empty());
        assert!(mgr.pending.is_empty());
    }
    #[test]
    fn test_create_request_disabled() {
        let mgr = PairingManager::new(default_config());
        let result = mgr.create_pairing_request();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("disabled"));
    }
    #[test]
    fn test_create_request_success() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        assert_eq!(req.token.len(), 64); // 32 bytes = 64 hex chars
        assert!(req.expires_at > req.created_at);
    }
    #[test]
    fn test_max_pending_requests() {
        let mgr = PairingManager::new(enabled_config());
        // Fill the pending table to its cap; none expire, so the next create fails.
        for _ in 0..MAX_PENDING_REQUESTS {
            mgr.create_pairing_request().unwrap();
        }
        let result = mgr.create_pairing_request();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Too many"));
    }
    #[test]
    fn test_complete_pairing_invalid_token() {
        let mgr = PairingManager::new(enabled_config());
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        let result = mgr.complete_pairing("invalid-token", device);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid"));
    }
    #[test]
    fn test_complete_pairing_success() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        let result = mgr.complete_pairing(&req.token, device);
        assert!(result.is_ok());
        assert_eq!(mgr.devices.len(), 1);
        assert!(mgr.pending.is_empty()); // Token consumed
    }
    #[test]
    fn test_max_devices_enforced() {
        let config = PairingConfig {
            enabled: true,
            max_devices: 1,
            ..Default::default()
        };
        let mgr = PairingManager::new(config);
        // Pair first device
        let req1 = mgr.create_pairing_request().unwrap();
        let d1 = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "Phone 1".to_string(),
            platform: "ios".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req1.token, d1).unwrap();
        // Try second device
        let req2 = mgr.create_pairing_request().unwrap();
        let d2 = PairedDevice {
            device_id: "dev-2".to_string(),
            display_name: "Phone 2".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        let result = mgr.complete_pairing(&req2.token, d2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Maximum"));
    }
    #[test]
    fn test_list_devices() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req.token, device).unwrap();
        let devices = mgr.list_devices();
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].display_name, "My Phone");
    }
    #[test]
    fn test_remove_device() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        let device = PairedDevice {
            device_id: "dev-1".to_string(),
            display_name: "My Phone".to_string(),
            platform: "android".to_string(),
            paired_at: chrono::Utc::now(),
            last_seen: chrono::Utc::now(),
            push_token: None,
        };
        mgr.complete_pairing(&req.token, device).unwrap();
        assert!(mgr.remove_device("dev-1").is_ok());
        assert!(mgr.devices.is_empty());
    }
    #[test]
    fn test_remove_nonexistent_device() {
        let mgr = PairingManager::new(enabled_config());
        assert!(mgr.remove_device("nonexistent").is_err());
    }
    #[test]
    fn test_clean_expired() {
        let config = PairingConfig {
            enabled: true,
            token_expiry_secs: 0, // Expire immediately
            ..Default::default()
        };
        let mgr = PairingManager::new(config);
        mgr.create_pairing_request().unwrap();
        assert_eq!(mgr.pending.len(), 1);
        // Wait a tiny bit for expiry
        std::thread::sleep(std::time::Duration::from_millis(10));
        mgr.clean_expired();
        assert!(mgr.pending.is_empty());
    }
    #[test]
    fn test_token_length() {
        let mgr = PairingManager::new(enabled_config());
        let req = mgr.create_pairing_request().unwrap();
        // 32 random bytes = 64 hex chars
        assert_eq!(req.token.len(), 64);
    }
    #[test]
    fn test_config_defaults() {
        let config = PairingConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_devices, 10);
        assert_eq!(config.token_expiry_secs, 300);
        assert_eq!(config.push_provider, "none");
        assert!(config.ntfy_url.is_none());
        assert!(config.ntfy_topic.is_none());
    }
}

View File

@@ -0,0 +1,464 @@
//! Agent registry — tracks all agents, their state, and indexes.
use dashmap::DashMap;
use openfang_types::agent::{AgentEntry, AgentId, AgentMode, AgentState};
use openfang_types::error::{OpenFangError, OpenFangResult};
/// Registry of all agents in the kernel.
///
/// Invariant: every ID held in `name_index` / `tag_index` refers to an entry
/// present in `agents`; the mutating methods keep the indexes in sync.
pub struct AgentRegistry {
    /// Primary index: agent ID → entry.
    agents: DashMap<AgentId, AgentEntry>,
    /// Name index: human-readable name → agent ID.
    name_index: DashMap<String, AgentId>,
    /// Tag index: tag → list of agent IDs.
    tag_index: DashMap<String, Vec<AgentId>>,
}
impl AgentRegistry {
/// Create a new empty registry.
pub fn new() -> Self {
Self {
agents: DashMap::new(),
name_index: DashMap::new(),
tag_index: DashMap::new(),
}
}
/// Register a new agent.
pub fn register(&self, entry: AgentEntry) -> OpenFangResult<()> {
if self.name_index.contains_key(&entry.name) {
return Err(OpenFangError::AgentAlreadyExists(entry.name.clone()));
}
let id = entry.id;
self.name_index.insert(entry.name.clone(), id);
for tag in &entry.tags {
self.tag_index.entry(tag.clone()).or_default().push(id);
}
self.agents.insert(id, entry);
Ok(())
}
/// Get an agent entry by ID.
pub fn get(&self, id: AgentId) -> Option<AgentEntry> {
self.agents.get(&id).map(|e| e.value().clone())
}
/// Find an agent by name.
pub fn find_by_name(&self, name: &str) -> Option<AgentEntry> {
self.name_index
.get(name)
.and_then(|id| self.agents.get(id.value()).map(|e| e.value().clone()))
}
/// Update agent state.
pub fn set_state(&self, id: AgentId, state: AgentState) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.state = state;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update agent operational mode.
pub fn set_mode(&self, id: AgentId, mode: AgentMode) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.mode = mode;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Remove an agent from the registry.
pub fn remove(&self, id: AgentId) -> OpenFangResult<AgentEntry> {
let (_, entry) = self
.agents
.remove(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
self.name_index.remove(&entry.name);
for tag in &entry.tags {
if let Some(mut ids) = self.tag_index.get_mut(tag) {
ids.retain(|&agent_id| agent_id != id);
}
}
Ok(entry)
}
/// List all agents.
pub fn list(&self) -> Vec<AgentEntry> {
self.agents.iter().map(|e| e.value().clone()).collect()
}
/// Add a child agent ID to a parent's children list.
pub fn add_child(&self, parent_id: AgentId, child_id: AgentId) {
if let Some(mut entry) = self.agents.get_mut(&parent_id) {
entry.children.push(child_id);
}
}
/// Count of registered agents.
pub fn count(&self) -> usize {
self.agents.len()
}
/// Update an agent's session ID (for session reset).
pub fn update_session_id(
&self,
id: AgentId,
new_session_id: openfang_types::agent::SessionId,
) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.session_id = new_session_id;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's workspace path.
pub fn update_workspace(
&self,
id: AgentId,
workspace: Option<std::path::PathBuf>,
) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.workspace = workspace;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's visual identity (emoji, avatar, color).
pub fn update_identity(
&self,
id: AgentId,
identity: openfang_types::agent::AgentIdentity,
) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.identity = identity;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's model configuration.
pub fn update_model(&self, id: AgentId, new_model: String) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.model.model = new_model;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's model AND provider together.
pub fn update_model_and_provider(
&self,
id: AgentId,
new_model: String,
new_provider: String,
) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.model.model = new_model;
entry.manifest.model.provider = new_provider;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's skill allowlist.
pub fn update_skills(&self, id: AgentId, skills: Vec<String>) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.skills = skills;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's MCP server allowlist.
pub fn update_mcp_servers(&self, id: AgentId, servers: Vec<String>) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.mcp_servers = servers;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's system prompt (hot-swap, takes effect on next message).
pub fn update_system_prompt(&self, id: AgentId, new_prompt: String) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.model.system_prompt = new_prompt;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Update an agent's name (also updates the name index).
pub fn update_name(&self, id: AgentId, new_name: String) -> OpenFangResult<()> {
if self.name_index.contains_key(&new_name) {
return Err(OpenFangError::AgentAlreadyExists(new_name));
}
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
let old_name = entry.name.clone();
entry.name = new_name.clone();
entry.manifest.name = new_name.clone();
entry.last_active = chrono::Utc::now();
// Update name index
drop(entry);
self.name_index.remove(&old_name);
self.name_index.insert(new_name, id);
Ok(())
}
/// Update an agent's description.
pub fn update_description(&self, id: AgentId, new_desc: String) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.manifest.description = new_desc;
entry.last_active = chrono::Utc::now();
Ok(())
}
/// Mark an agent's onboarding as complete.
pub fn mark_onboarding_complete(&self, id: AgentId) -> OpenFangResult<()> {
let mut entry = self
.agents
.get_mut(&id)
.ok_or_else(|| OpenFangError::AgentNotFound(id.to_string()))?;
entry.onboarding_completed = true;
entry.onboarding_completed_at = Some(chrono::Utc::now());
entry.last_active = chrono::Utc::now();
Ok(())
}
}
impl Default for AgentRegistry {
    // Delegates to `new()`: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use openfang_types::agent::*;
    use std::collections::HashMap;
    // Build a minimal, valid AgentEntry so each test can register agents
    // without constructing real manifests.
    fn test_entry(name: &str) -> AgentEntry {
        AgentEntry {
            id: AgentId::new(),
            name: name.to_string(),
            manifest: AgentManifest {
                name: name.to_string(),
                version: "0.1.0".to_string(),
                description: "test".to_string(),
                author: "test".to_string(),
                module: "test".to_string(),
                schedule: ScheduleMode::default(),
                model: ModelConfig::default(),
                fallback_models: vec![],
                resources: ResourceQuota::default(),
                priority: Priority::default(),
                capabilities: ManifestCapabilities::default(),
                profile: None,
                tools: HashMap::new(),
                skills: vec![],
                mcp_servers: vec![],
                metadata: HashMap::new(),
                tags: vec![],
                routing: None,
                autonomous: None,
                pinned_model: None,
                workspace: None,
                generate_identity_files: true,
                exec_policy: None,
            },
            state: AgentState::Created,
            mode: AgentMode::default(),
            created_at: Utc::now(),
            last_active: Utc::now(),
            parent: None,
            children: vec![],
            session_id: SessionId::new(),
            tags: vec![],
            identity: Default::default(),
            onboarding_completed: false,
            onboarding_completed_at: None,
        }
    }
    #[test]
    fn test_register_and_get() {
        let registry = AgentRegistry::new();
        let entry = test_entry("test-agent");
        let id = entry.id;
        registry.register(entry).unwrap();
        assert!(registry.get(id).is_some());
    }
    #[test]
    fn test_find_by_name() {
        let registry = AgentRegistry::new();
        let entry = test_entry("my-agent");
        registry.register(entry).unwrap();
        assert!(registry.find_by_name("my-agent").is_some());
    }
    #[test]
    fn test_duplicate_name() {
        // Second registration under the same name must be rejected.
        let registry = AgentRegistry::new();
        registry.register(test_entry("dup")).unwrap();
        assert!(registry.register(test_entry("dup")).is_err());
    }
    #[test]
    fn test_remove() {
        let registry = AgentRegistry::new();
        let entry = test_entry("removable");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.remove(id).unwrap();
        assert!(registry.get(id).is_none());
    }
    #[test]
    fn test_set_state() {
        let registry = AgentRegistry::new();
        let entry = test_entry("state-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.set_state(id, AgentState::Running).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.state, AgentState::Running);
    }
    #[test]
    fn test_set_mode() {
        let registry = AgentRegistry::new();
        let entry = test_entry("mode-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.set_mode(id, AgentMode::Autonomous).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.mode, AgentMode::Autonomous);
    }
    #[test]
    fn test_update_identity() {
        let registry = AgentRegistry::new();
        let entry = test_entry("identity-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        let identity = AgentIdentity {
            emoji: Some("🤖".to_string()),
            avatar_url: None,
            color: Some("#ff0000".to_string()),
            archetype: None,
            vibe: None,
            greeting_style: None,
        };
        registry.update_identity(id, identity).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.identity.emoji, Some("🤖".to_string()));
        assert_eq!(updated.identity.color, Some("#ff0000".to_string()));
    }
    #[test]
    fn test_update_name() {
        let registry = AgentRegistry::new();
        let entry = test_entry("old-name");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.update_name(id, "new-name".to_string()).unwrap();
        // Should be findable by new name
        assert!(registry.find_by_name("new-name").is_some());
        // Should not be findable by old name
        assert!(registry.find_by_name("old-name").is_none());
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.name, "new-name");
    }
    #[test]
    fn test_update_description() {
        let registry = AgentRegistry::new();
        let entry = test_entry("desc-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.update_description(id, "New description".to_string()).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.description, "New description");
    }
    #[test]
    fn test_update_model() {
        let registry = AgentRegistry::new();
        let entry = test_entry("model-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        registry.update_model(id, "gpt-4".to_string()).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.model.model, "gpt-4");
    }
    #[test]
    fn test_update_skills() {
        let registry = AgentRegistry::new();
        let entry = test_entry("skills-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        let skills = vec!["web-search".to_string(), "code-execution".to_string()];
        registry.update_skills(id, skills.clone()).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.skills, skills);
    }
    #[test]
    fn test_update_mcp_servers() {
        let registry = AgentRegistry::new();
        let entry = test_entry("mcp-test");
        let id = entry.id;
        registry.register(entry).unwrap();
        let servers = vec!["filesystem".to_string(), "database".to_string()];
        registry.update_mcp_servers(id, servers.clone()).unwrap();
        let updated = registry.get(id).unwrap();
        assert_eq!(updated.manifest.mcp_servers, servers);
    }
}

View File

@@ -0,0 +1,170 @@
//! Agent scheduler — manages agent execution and resource tracking.
use dashmap::DashMap;
use openfang_types::agent::{AgentId, ResourceQuota};
use openfang_types::error::{OpenFangError, OpenFangResult};
use openfang_types::message::TokenUsage;
use std::time::Instant;
use tokio::task::JoinHandle;
use tracing::debug;
/// Tracks resource usage for an agent with a rolling hourly window.
///
/// Counters are reset lazily by `reset_if_expired` whenever they are read or
/// written after the window has elapsed.
#[derive(Debug)]
pub struct UsageTracker {
    /// Total tokens consumed within the current window.
    pub total_tokens: u64,
    /// Total tool calls made within the current window.
    pub tool_calls: u64,
    /// Start of the current usage window.
    pub window_start: Instant,
}
impl Default for UsageTracker {
fn default() -> Self {
Self {
total_tokens: 0,
tool_calls: 0,
window_start: Instant::now(),
}
}
}
impl UsageTracker {
/// Reset counters if the current window has expired (1 hour).
fn reset_if_expired(&mut self) {
if self.window_start.elapsed() >= std::time::Duration::from_secs(3600) {
self.total_tokens = 0;
self.tool_calls = 0;
self.window_start = Instant::now();
}
}
}
/// The agent scheduler manages execution ordering and resource quotas.
///
/// All maps are keyed by [`AgentId`]; entries are created in `register`
/// and torn down in `unregister`.
pub struct AgentScheduler {
    /// Resource quotas per agent.
    quotas: DashMap<AgentId, ResourceQuota>,
    /// Usage tracking per agent.
    usage: DashMap<AgentId, UsageTracker>,
    /// Active task handles per agent.
    tasks: DashMap<AgentId, JoinHandle<()>>,
}
impl AgentScheduler {
    /// Create a new scheduler.
    pub fn new() -> Self {
        Self {
            quotas: DashMap::new(),
            usage: DashMap::new(),
            tasks: DashMap::new(),
        }
    }
    /// Register an agent with its resource quota.
    ///
    /// Re-registering an agent replaces its quota and resets its usage window.
    pub fn register(&self, agent_id: AgentId, quota: ResourceQuota) {
        self.quotas.insert(agent_id, quota);
        self.usage.insert(agent_id, UsageTracker::default());
    }
    /// Record token usage for an agent (no-op for unregistered agents).
    pub fn record_usage(&self, agent_id: AgentId, usage: &TokenUsage) {
        if let Some(mut tracker) = self.usage.get_mut(&agent_id) {
            tracker.reset_if_expired();
            tracker.total_tokens += usage.total();
        }
    }
    /// Check if an agent has exceeded its quota.
    ///
    /// # Errors
    /// Returns `QuotaExceeded` when the hourly token limit is passed.
    /// Agents with no registered quota (or no tracker) always pass.
    pub fn check_quota(&self, agent_id: AgentId) -> OpenFangResult<()> {
        // Copy out just the limit instead of cloning the whole ResourceQuota;
        // the `quotas` guard is dropped before `usage` is locked.
        let max_tokens = match self.quotas.get(&agent_id) {
            Some(q) => q.max_llm_tokens_per_hour,
            None => return Ok(()), // No quota = no limit
        };
        let mut tracker = match self.usage.get_mut(&agent_id) {
            Some(t) => t,
            None => return Ok(()),
        };
        // Reset the window if an hour has passed
        tracker.reset_if_expired();
        // A limit of 0 means "unlimited".
        if max_tokens > 0 && tracker.total_tokens > max_tokens {
            return Err(OpenFangError::QuotaExceeded(format!(
                "Token limit exceeded: {} / {}",
                tracker.total_tokens, max_tokens
            )));
        }
        Ok(())
    }
    /// Abort an agent's active task, if any, and drop its handle.
    pub fn abort_task(&self, agent_id: AgentId) {
        if let Some((_, handle)) = self.tasks.remove(&agent_id) {
            handle.abort();
            debug!(agent = %agent_id, "Aborted agent task");
        }
    }
    /// Remove an agent from the scheduler (aborts its task first).
    pub fn unregister(&self, agent_id: AgentId) {
        self.abort_task(agent_id);
        self.quotas.remove(&agent_id);
        self.usage.remove(&agent_id);
    }
    /// Get usage stats for an agent as `(total_tokens, tool_calls)`.
    pub fn get_usage(&self, agent_id: AgentId) -> Option<(u64, u64)> {
        self.usage
            .get(&agent_id)
            .map(|t| (t.total_tokens, t.tool_calls))
    }
}
impl Default for AgentScheduler {
    // Delegates to `new()`: an empty scheduler.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_record_usage() {
        let scheduler = AgentScheduler::new();
        let id = AgentId::new();
        scheduler.register(id, ResourceQuota::default());
        scheduler.record_usage(
            id,
            &TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
            },
        );
        // TokenUsage::total() = input + output = 150.
        let (tokens, _) = scheduler.get_usage(id).unwrap();
        assert_eq!(tokens, 150);
    }
    #[test]
    fn test_quota_check() {
        let scheduler = AgentScheduler::new();
        let id = AgentId::new();
        let quota = ResourceQuota {
            max_llm_tokens_per_hour: 100,
            ..Default::default()
        };
        scheduler.register(id, quota);
        // 60 + 50 = 110 tokens recorded, above the 100/hour cap.
        scheduler.record_usage(
            id,
            &TokenUsage {
                input_tokens: 60,
                output_tokens: 50,
            },
        );
        assert!(scheduler.check_quota(id).is_err());
    }
}

View File

@@ -0,0 +1,227 @@
//! Process supervision — graceful shutdown, signal handling, and health monitoring.
use dashmap::DashMap;
use openfang_types::agent::AgentId;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::watch;
use tracing::{info, warn};
/// Shutdown signal manager with health monitoring.
///
/// One broadcast `watch` channel fans the shutdown flag out to every
/// subscriber; counters are lock-free atomics so they can be bumped from any
/// task.
pub struct Supervisor {
    /// Send side of the shutdown signal.
    shutdown_tx: watch::Sender<bool>,
    /// Receive side of the shutdown signal (clonable).
    shutdown_rx: watch::Receiver<bool>,
    /// Restart count (how many times agents have been restarted).
    restart_count: AtomicU64,
    /// Total panics caught across all agents.
    panic_count: AtomicU64,
    /// Per-agent restart counts for enforcing max_restarts.
    agent_restarts: DashMap<AgentId, u32>,
}
impl Supervisor {
    /// Create a new supervisor with the shutdown flag lowered and all
    /// counters at zero.
    pub fn new() -> Self {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        Self {
            shutdown_tx,
            shutdown_rx,
            restart_count: AtomicU64::new(0),
            panic_count: AtomicU64::new(0),
            agent_restarts: DashMap::new(),
        }
    }
    /// Get a receiver that will be notified on shutdown.
    pub fn subscribe(&self) -> watch::Receiver<bool> {
        self.shutdown_rx.clone()
    }
    /// Trigger a graceful shutdown by broadcasting `true` to all subscribers.
    pub fn shutdown(&self) {
        info!("Supervisor: initiating graceful shutdown");
        let _ = self.shutdown_tx.send(true);
    }
    /// Check if shutdown has been requested.
    pub fn is_shutting_down(&self) -> bool {
        *self.shutdown_rx.borrow()
    }
    /// Record that a panic was caught during agent execution.
    pub fn record_panic(&self) {
        let total = self.panic_count.fetch_add(1, Ordering::Relaxed) + 1;
        warn!(total_panics = total, "Agent panic recorded");
    }
    /// Record that an agent was restarted.
    pub fn record_restart(&self) {
        self.restart_count.fetch_add(1, Ordering::Relaxed);
    }
    /// Get the total number of panics caught.
    pub fn panic_count(&self) -> u64 {
        self.panic_count.load(Ordering::Relaxed)
    }
    /// Get the total number of restarts.
    pub fn restart_count(&self) -> u64 {
        self.restart_count.load(Ordering::Relaxed)
    }
    /// Record a restart for a specific agent and check if limit is exceeded.
    ///
    /// Returns Ok(restart_count) if within limit (max_restarts == 0 means
    /// unlimited), or Err(count) once the limit is exceeded.
    pub fn record_agent_restart(&self, agent_id: AgentId, max_restarts: u32) -> Result<u32, u32> {
        let mut slot = self.agent_restarts.entry(agent_id).or_insert(0);
        *slot += 1;
        let attempts = *slot;
        drop(slot);
        self.record_restart();
        if max_restarts == 0 || attempts <= max_restarts {
            Ok(attempts)
        } else {
            warn!(
                agent = %agent_id,
                restarts = attempts,
                max = max_restarts,
                "Agent exceeded max restart limit"
            );
            Err(attempts)
        }
    }
    /// Get the restart count for a specific agent.
    pub fn agent_restart_count(&self, agent_id: AgentId) -> u32 {
        match self.agent_restarts.get(&agent_id) {
            Some(count) => *count,
            None => 0,
        }
    }
    /// Reset restart counter for an agent (e.g., on manual intervention).
    pub fn reset_agent_restarts(&self, agent_id: AgentId) {
        self.agent_restarts.remove(&agent_id);
    }
    /// Get a health summary of the supervisor's current state.
    pub fn health(&self) -> SupervisorHealth {
        SupervisorHealth {
            is_shutting_down: self.is_shutting_down(),
            panic_count: self.panic_count(),
            restart_count: self.restart_count(),
        }
    }
}
impl Default for Supervisor {
    // Delegates to `new()`: fresh supervisor, no shutdown requested.
    fn default() -> Self {
        Self::new()
    }
}
/// Health report from the supervisor.
///
/// A point-in-time snapshot produced by [`Supervisor::health`].
#[derive(Debug, Clone)]
pub struct SupervisorHealth {
    /// Whether a graceful shutdown has been requested.
    pub is_shutting_down: bool,
    /// Total panics caught across all agents.
    pub panic_count: u64,
    /// Total agent restarts performed.
    pub restart_count: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_shutdown() {
        let supervisor = Supervisor::new();
        assert!(!supervisor.is_shutting_down());
        supervisor.shutdown();
        assert!(supervisor.is_shutting_down());
    }
    #[test]
    fn test_subscribe() {
        let supervisor = Supervisor::new();
        let rx = supervisor.subscribe();
        assert!(!*rx.borrow());
        supervisor.shutdown();
        // Receiver observes the pending change without awaiting.
        assert!(rx.has_changed().unwrap());
    }
    #[test]
    fn test_panic_tracking() {
        let supervisor = Supervisor::new();
        assert_eq!(supervisor.panic_count(), 0);
        supervisor.record_panic();
        supervisor.record_panic();
        assert_eq!(supervisor.panic_count(), 2);
    }
    #[test]
    fn test_restart_tracking() {
        let supervisor = Supervisor::new();
        assert_eq!(supervisor.restart_count(), 0);
        supervisor.record_restart();
        assert_eq!(supervisor.restart_count(), 1);
    }
    #[test]
    fn test_health() {
        let supervisor = Supervisor::new();
        let health = supervisor.health();
        assert!(!health.is_shutting_down);
        assert_eq!(health.panic_count, 0);
        assert_eq!(health.restart_count, 0);
    }
    #[test]
    fn test_agent_restart_within_limit() {
        let supervisor = Supervisor::new();
        let agent_id = AgentId::new();
        // Allow up to 3 restarts
        assert!(supervisor.record_agent_restart(agent_id, 3).is_ok());
        assert_eq!(supervisor.agent_restart_count(agent_id), 1);
        assert!(supervisor.record_agent_restart(agent_id, 3).is_ok());
        assert!(supervisor.record_agent_restart(agent_id, 3).is_ok());
        assert_eq!(supervisor.agent_restart_count(agent_id), 3);
    }
    #[test]
    fn test_agent_restart_exceeds_limit() {
        let supervisor = Supervisor::new();
        let agent_id = AgentId::new();
        assert!(supervisor.record_agent_restart(agent_id, 2).is_ok());
        assert!(supervisor.record_agent_restart(agent_id, 2).is_ok());
        // 3rd restart exceeds max_restarts=2
        let result = supervisor.record_agent_restart(agent_id, 2);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), 3);
    }
    #[test]
    fn test_agent_restart_zero_limit_unlimited() {
        let supervisor = Supervisor::new();
        let agent_id = AgentId::new();
        // max_restarts=0 means unlimited
        for _ in 0..100 {
            assert!(supervisor.record_agent_restart(agent_id, 0).is_ok());
        }
    }
    #[test]
    fn test_reset_agent_restarts() {
        let supervisor = Supervisor::new();
        let agent_id = AgentId::new();
        supervisor.record_agent_restart(agent_id, 10).unwrap();
        supervisor.record_agent_restart(agent_id, 10).unwrap();
        assert_eq!(supervisor.agent_restart_count(agent_id), 2);
        supervisor.reset_agent_restarts(agent_id);
        assert_eq!(supervisor.agent_restart_count(agent_id), 0);
    }
}

View File

@@ -0,0 +1,511 @@
//! Event-driven agent triggers — agents auto-activate when events match patterns.
//!
//! Agents register triggers that describe which events should wake them.
//! When a matching event arrives on the EventBus, the trigger system
//! sends the event content as a message to the subscribing agent.
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use openfang_types::agent::AgentId;
use openfang_types::event::{Event, EventPayload, LifecycleEvent, SystemEvent};
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
use uuid::Uuid;
/// Unique identifier for a trigger.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TriggerId(pub Uuid);
impl TriggerId {
    // Mint a fresh random (v4) trigger ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }
}
impl Default for TriggerId {
    // A "default" trigger ID is a freshly generated random one.
    fn default() -> Self {
        Self::new()
    }
}
impl std::fmt::Display for TriggerId {
    // Renders as the UUID's canonical hyphenated form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
/// What kind of events a trigger matches on.
///
/// Serialized in `snake_case` so patterns read naturally in TOML/JSON configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TriggerPattern {
    /// Match any lifecycle event (agent spawned, started, terminated, etc.).
    Lifecycle,
    /// Match when a specific agent is spawned.
    /// `name_pattern` is a substring match; `"*"` matches any name.
    AgentSpawned { name_pattern: String },
    /// Match when any agent is terminated (including crashes).
    AgentTerminated,
    /// Match any system event.
    System,
    /// Match a specific system event by keyword — a case-insensitive
    /// substring of the event's `Debug` representation.
    SystemKeyword { keyword: String },
    /// Match any memory update event.
    MemoryUpdate,
    /// Match memory updates for a specific key pattern.
    /// `key_pattern` is a substring match; `"*"` matches any key.
    MemoryKeyPattern { key_pattern: String },
    /// Match all events (wildcard).
    All,
    /// Match events by a case-insensitive substring of the human-readable
    /// event description.
    ContentMatch { substring: String },
}
/// A registered trigger definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trigger {
    /// Unique trigger ID.
    pub id: TriggerId,
    /// Which agent owns this trigger (the agent that is messaged on a match).
    pub agent_id: AgentId,
    /// The event pattern to match.
    pub pattern: TriggerPattern,
    /// Prompt template to send when triggered. Use `{{event}}` for event description.
    pub prompt_template: String,
    /// Whether this trigger is currently active.
    /// Auto-set to `false` once `fire_count` reaches a non-zero `max_fires`.
    pub enabled: bool,
    /// When this trigger was created.
    pub created_at: DateTime<Utc>,
    /// How many times this trigger has fired.
    pub fire_count: u64,
    /// Maximum number of times this trigger can fire (0 = unlimited).
    pub max_fires: u64,
}
/// The trigger engine manages event-to-agent routing.
///
/// Both maps are concurrent (`DashMap`), so registration and evaluation can
/// run from multiple threads without an outer lock.
pub struct TriggerEngine {
    /// All registered triggers.
    triggers: DashMap<TriggerId, Trigger>,
    /// Index: agent_id → list of trigger IDs belonging to that agent.
    agent_triggers: DashMap<AgentId, Vec<TriggerId>>,
}
impl TriggerEngine {
    /// Create a new, empty trigger engine.
    pub fn new() -> Self {
        Self {
            triggers: DashMap::new(),
            agent_triggers: DashMap::new(),
        }
    }
    /// Register a new trigger.
    ///
    /// The trigger starts enabled with a zero fire count; `max_fires = 0`
    /// means it never expires. Returns the newly assigned ID.
    pub fn register(
        &self,
        agent_id: AgentId,
        pattern: TriggerPattern,
        prompt_template: String,
        max_fires: u64,
    ) -> TriggerId {
        let trigger = Trigger {
            id: TriggerId::new(),
            agent_id,
            pattern,
            prompt_template,
            enabled: true,
            created_at: Utc::now(),
            fire_count: 0,
            max_fires,
        };
        let id = trigger.id;
        self.triggers.insert(id, trigger);
        // Maintain the per-agent index so agent-scoped list/remove stay cheap.
        self.agent_triggers.entry(agent_id).or_default().push(id);
        info!(trigger_id = %id, agent_id = %agent_id, "Trigger registered");
        id
    }
    /// Remove a trigger. Returns `true` if it existed.
    pub fn remove(&self, trigger_id: TriggerId) -> bool {
        if let Some((_, trigger)) = self.triggers.remove(&trigger_id) {
            // Also drop it from the owning agent's index.
            if let Some(mut list) = self.agent_triggers.get_mut(&trigger.agent_id) {
                list.retain(|id| *id != trigger_id);
            }
            true
        } else {
            false
        }
    }
    /// Remove all triggers for an agent.
    pub fn remove_agent_triggers(&self, agent_id: AgentId) {
        if let Some((_, trigger_ids)) = self.agent_triggers.remove(&agent_id) {
            for id in trigger_ids {
                self.triggers.remove(&id);
            }
        }
    }
    /// Enable or disable a trigger. Returns true if the trigger was found.
    pub fn set_enabled(&self, trigger_id: TriggerId, enabled: bool) -> bool {
        if let Some(mut t) = self.triggers.get_mut(&trigger_id) {
            t.enabled = enabled;
            true
        } else {
            false
        }
    }
    /// List all triggers for an agent (returns clones).
    pub fn list_agent_triggers(&self, agent_id: AgentId) -> Vec<Trigger> {
        self.agent_triggers
            .get(&agent_id)
            .map(|ids| {
                ids.iter()
                    .filter_map(|id| self.triggers.get(id).map(|t| t.clone()))
                    .collect()
            })
            .unwrap_or_default()
    }
    /// List all registered triggers.
    pub fn list_all(&self) -> Vec<Trigger> {
        self.triggers.iter().map(|e| e.value().clone()).collect()
    }
    /// Evaluate an event against all triggers. Returns a list of
    /// (agent_id, message_to_send) pairs for matching triggers.
    ///
    /// Side effects: increments `fire_count` on each match and lazily
    /// disables triggers whose `fire_count` already reached a non-zero
    /// `max_fires` (so a trigger still fires on the call that hits the cap).
    pub fn evaluate(&self, event: &Event) -> Vec<(AgentId, String)> {
        let event_description = describe_event(event);
        let mut matches = Vec::new();
        for mut entry in self.triggers.iter_mut() {
            let trigger = entry.value_mut();
            if !trigger.enabled {
                continue;
            }
            // Check max fires
            if trigger.max_fires > 0 && trigger.fire_count >= trigger.max_fires {
                // Disable rather than remove, so fire history stays inspectable.
                trigger.enabled = false;
                continue;
            }
            if matches_pattern(&trigger.pattern, event, &event_description) {
                // Substitute the event description into the prompt template.
                let message = trigger
                    .prompt_template
                    .replace("{{event}}", &event_description);
                matches.push((trigger.agent_id, message));
                trigger.fire_count += 1;
                debug!(
                    trigger_id = %trigger.id,
                    agent_id = %trigger.agent_id,
                    fire_count = trigger.fire_count,
                    "Trigger fired"
                );
            }
        }
        matches
    }
    /// Get a clone of a trigger by ID.
    pub fn get(&self, trigger_id: TriggerId) -> Option<Trigger> {
        self.triggers.get(&trigger_id).map(|t| t.clone())
    }
}
impl Default for TriggerEngine {
fn default() -> Self {
Self::new()
}
}
/// Decide whether `event` satisfies `pattern`.
///
/// `description` is the pre-computed human-readable form of the event and is
/// what the text-based `ContentMatch` pattern searches (case-insensitively).
fn matches_pattern(pattern: &TriggerPattern, event: &Event, description: &str) -> bool {
    match pattern {
        TriggerPattern::All => true,
        TriggerPattern::Lifecycle => matches!(event.payload, EventPayload::Lifecycle(_)),
        TriggerPattern::AgentSpawned { name_pattern } => match &event.payload {
            EventPayload::Lifecycle(LifecycleEvent::Spawned { name, .. }) => {
                // "*" is the wildcard; otherwise a plain substring match.
                name_pattern == "*" || name.contains(name_pattern.as_str())
            }
            _ => false,
        },
        TriggerPattern::AgentTerminated => matches!(
            event.payload,
            EventPayload::Lifecycle(LifecycleEvent::Terminated { .. })
                | EventPayload::Lifecycle(LifecycleEvent::Crashed { .. })
        ),
        TriggerPattern::System => matches!(event.payload, EventPayload::System(_)),
        TriggerPattern::SystemKeyword { keyword } => match &event.payload {
            EventPayload::System(se) => {
                // Case-insensitive substring search over the Debug rendering.
                format!("{:?}", se)
                    .to_lowercase()
                    .contains(&keyword.to_lowercase())
            }
            _ => false,
        },
        TriggerPattern::MemoryUpdate => matches!(event.payload, EventPayload::MemoryUpdate(_)),
        TriggerPattern::MemoryKeyPattern { key_pattern } => match &event.payload {
            EventPayload::MemoryUpdate(delta) => {
                key_pattern == "*" || delta.key.contains(key_pattern.as_str())
            }
            _ => false,
        },
        TriggerPattern::ContentMatch { substring } => {
            description.to_lowercase().contains(&substring.to_lowercase())
        }
    }
}
/// Create a human-readable description of an event for use in prompts.
///
/// The result is substituted for `{{event}}` in trigger prompt templates and
/// is also the text that `ContentMatch` patterns search.
fn describe_event(event: &Event) -> String {
    match &event.payload {
        EventPayload::Message(msg) => {
            format!("Message from {:?}: {}", msg.role, msg.content)
        }
        EventPayload::ToolResult(tr) => {
            // Preview at most 200 bytes of tool output, backing up to the
            // nearest char boundary so we never panic slicing inside a
            // multi-byte UTF-8 sequence (a bare `&s[..200]` would).
            let mut end = tr.content.len().min(200);
            while !tr.content.is_char_boundary(end) {
                end -= 1;
            }
            format!(
                "Tool '{}' {} ({}ms): {}",
                tr.tool_id,
                if tr.success { "succeeded" } else { "failed" },
                tr.execution_time_ms,
                &tr.content[..end]
            )
        }
        EventPayload::MemoryUpdate(delta) => {
            format!(
                "Memory {:?} on key '{}' for agent {}",
                delta.operation, delta.key, delta.agent_id
            )
        }
        EventPayload::Lifecycle(le) => match le {
            LifecycleEvent::Spawned { agent_id, name } => {
                format!("Agent '{name}' (id: {agent_id}) was spawned")
            }
            LifecycleEvent::Started { agent_id } => {
                format!("Agent {agent_id} started")
            }
            LifecycleEvent::Suspended { agent_id } => {
                format!("Agent {agent_id} suspended")
            }
            LifecycleEvent::Resumed { agent_id } => {
                format!("Agent {agent_id} resumed")
            }
            LifecycleEvent::Terminated { agent_id, reason } => {
                format!("Agent {agent_id} terminated: {reason}")
            }
            LifecycleEvent::Crashed { agent_id, error } => {
                format!("Agent {agent_id} crashed: {error}")
            }
        },
        EventPayload::Network(ne) => {
            format!("Network event: {:?}", ne)
        }
        EventPayload::System(se) => match se {
            SystemEvent::KernelStarted => "Kernel started".to_string(),
            SystemEvent::KernelStopping => "Kernel stopping".to_string(),
            SystemEvent::QuotaWarning {
                agent_id,
                resource,
                usage_percent,
            } => format!("Quota warning: agent {agent_id}, {resource} at {usage_percent:.1}%"),
            SystemEvent::HealthCheck { status } => {
                format!("Health check: {status}")
            }
            SystemEvent::QuotaEnforced {
                agent_id,
                spent,
                limit,
            } => {
                format!("Quota enforced: agent {agent_id}, spent ${spent:.4} / ${limit:.4}")
            }
            SystemEvent::ModelRouted {
                agent_id,
                complexity,
                model,
            } => {
                format!("Model routed: agent {agent_id}, complexity={complexity}, model={model}")
            }
            SystemEvent::UserAction {
                user_id,
                action,
                result,
            } => {
                format!("User action: {user_id} {action} -> {result}")
            }
            SystemEvent::HealthCheckFailed {
                agent_id,
                unresponsive_secs,
            } => {
                format!(
                    "Health check failed: agent {agent_id}, unresponsive for {unresponsive_secs}s"
                )
            }
        },
        EventPayload::Custom(data) => {
            format!("Custom event ({} bytes)", data.len())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::event::*;
    #[test]
    fn test_register_trigger() {
        // A freshly registered trigger is retrievable by its returned ID.
        let engine = TriggerEngine::new();
        let agent_id = AgentId::new();
        let id = engine.register(
            agent_id,
            TriggerPattern::All,
            "Event occurred: {{event}}".to_string(),
            0,
        );
        assert!(engine.get(id).is_some());
    }
    #[test]
    fn test_evaluate_lifecycle() {
        // A Lifecycle pattern matches a Spawned event, and the rendered
        // message contains the event description ({{event}} substitution).
        let engine = TriggerEngine::new();
        let watcher = AgentId::new();
        engine.register(
            watcher,
            TriggerPattern::Lifecycle,
            "Lifecycle: {{event}}".to_string(),
            0,
        );
        let event = Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::Lifecycle(LifecycleEvent::Spawned {
                agent_id: AgentId::new(),
                name: "new-agent".to_string(),
            }),
        );
        let matches = engine.evaluate(&event);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].0, watcher);
        assert!(matches[0].1.contains("new-agent"));
    }
    #[test]
    fn test_evaluate_agent_spawned_pattern() {
        // name_pattern is a substring match on the spawned agent's name.
        let engine = TriggerEngine::new();
        let watcher = AgentId::new();
        engine.register(
            watcher,
            TriggerPattern::AgentSpawned {
                name_pattern: "coder".to_string(),
            },
            "Coder spawned: {{event}}".to_string(),
            0,
        );
        // This should match
        let event = Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::Lifecycle(LifecycleEvent::Spawned {
                agent_id: AgentId::new(),
                name: "coder".to_string(),
            }),
        );
        assert_eq!(engine.evaluate(&event).len(), 1);
        // This should NOT match
        let event2 = Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::Lifecycle(LifecycleEvent::Spawned {
                agent_id: AgentId::new(),
                name: "researcher".to_string(),
            }),
        );
        assert_eq!(engine.evaluate(&event2).len(), 0);
    }
    #[test]
    fn test_max_fires() {
        // After max_fires matches, the trigger auto-disables.
        let engine = TriggerEngine::new();
        let agent_id = AgentId::new();
        engine.register(
            agent_id,
            TriggerPattern::All,
            "Event: {{event}}".to_string(),
            2, // max 2 fires
        );
        let event = Event::new(
            AgentId::new(),
            EventTarget::Broadcast,
            EventPayload::System(SystemEvent::HealthCheck {
                status: "ok".to_string(),
            }),
        );
        // First two should match
        assert_eq!(engine.evaluate(&event).len(), 1);
        assert_eq!(engine.evaluate(&event).len(), 1);
        // Third should not
        assert_eq!(engine.evaluate(&event).len(), 0);
    }
    #[test]
    fn test_remove_trigger() {
        let engine = TriggerEngine::new();
        let agent_id = AgentId::new();
        let id = engine.register(agent_id, TriggerPattern::All, "msg".to_string(), 0);
        assert!(engine.remove(id));
        assert!(engine.get(id).is_none());
    }
    #[test]
    fn test_remove_agent_triggers() {
        // Removing by agent clears both the triggers and the per-agent index.
        let engine = TriggerEngine::new();
        let agent_id = AgentId::new();
        engine.register(agent_id, TriggerPattern::All, "a".to_string(), 0);
        engine.register(agent_id, TriggerPattern::System, "b".to_string(), 0);
        assert_eq!(engine.list_agent_triggers(agent_id).len(), 2);
        engine.remove_agent_triggers(agent_id);
        assert_eq!(engine.list_agent_triggers(agent_id).len(), 0);
    }
    #[test]
    fn test_content_match() {
        // ContentMatch searches the human-readable event description
        // (case-insensitive) — here the "Quota warning: ..." text.
        let engine = TriggerEngine::new();
        let agent_id = AgentId::new();
        engine.register(
            agent_id,
            TriggerPattern::ContentMatch {
                substring: "quota".to_string(),
            },
            "Alert: {{event}}".to_string(),
            0,
        );
        let event = Event::new(
            AgentId::new(),
            EventTarget::System,
            EventPayload::System(SystemEvent::QuotaWarning {
                agent_id: AgentId::new(),
                resource: "tokens".to_string(),
                usage_percent: 85.0,
            }),
        );
        assert_eq!(engine.evaluate(&event).len(), 1);
    }
}

View File

@@ -0,0 +1,345 @@
//! WhatsApp Web gateway — embedded Node.js process management.
//!
//! Embeds the gateway JS at compile time, extracts it to `~/.openfang/whatsapp-gateway/`,
//! runs `npm install` if needed, and spawns `node index.js` as a managed child process
//! that auto-restarts on crash.
use crate::config::openfang_home;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{info, warn};
/// Gateway source files embedded at compile time.
const GATEWAY_INDEX_JS: &str =
include_str!("../../../packages/whatsapp-gateway/index.js");
const GATEWAY_PACKAGE_JSON: &str =
include_str!("../../../packages/whatsapp-gateway/package.json");
/// Default port for the WhatsApp Web gateway.
const DEFAULT_GATEWAY_PORT: u16 = 3009;
/// Maximum restart attempts before giving up.
const MAX_RESTARTS: u32 = 3;
/// Restart backoff delays in seconds: 5s, 10s, 20s.
const RESTART_DELAYS: [u64; 3] = [5, 10, 20];
/// Get the gateway installation directory (`<openfang home>/whatsapp-gateway`).
fn gateway_dir() -> PathBuf {
    let mut dir = openfang_home();
    dir.push("whatsapp-gateway");
    dir
}
/// Compute a simple hash of content for change detection.
///
/// This is 64-bit FNV-1a: fast, dependency-free, and stable across runs.
/// It is NOT cryptographic — it only needs to detect content changes.
/// Returns the hash as a 16-digit lowercase hex string.
fn content_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{hash:016x}")
}
/// Write a file only if its content hash differs from the existing file.
/// Returns `true` if the file was written (content changed).
///
/// The hash of the last-written content is stored in a sibling `.hash` file
/// so change detection never has to re-read (or re-hash) the potentially
/// large target file.
///
/// # Errors
/// Propagates any I/O error from writing the file or its hash sidecar.
fn write_if_changed(path: &std::path::Path, content: &str) -> std::io::Result<bool> {
    let hash_path = path.with_extension("hash");
    let new_hash = content_hash(content);
    // Skip the write only when BOTH the stored hash matches AND the target
    // file actually exists. Previously a stale `.hash` sidecar left behind
    // after the target was deleted would wrongly suppress the rewrite and
    // the file was never restored.
    if path.exists() {
        if let Ok(existing_hash) = std::fs::read_to_string(&hash_path) {
            if existing_hash.trim() == new_hash {
                return Ok(false); // No change
            }
        }
    }
    std::fs::write(path, content)?;
    std::fs::write(&hash_path, &new_hash)?;
    Ok(true)
}
/// Ensure the gateway files are extracted and npm dependencies installed.
///
/// Returns the gateway directory path on success, or an error message.
async fn ensure_gateway_installed() -> Result<PathBuf, String> {
    let dir = gateway_dir();
    std::fs::create_dir_all(&dir).map_err(|e| format!("Failed to create gateway dir: {e}"))?;

    // Extract the embedded sources, touching disk only when content differs
    // (this avoids pointless npm reinstalls on every daemon start).
    let index_changed = write_if_changed(&dir.join("index.js"), GATEWAY_INDEX_JS)
        .map_err(|e| format!("Write index.js: {e}"))?;
    let package_changed = write_if_changed(&dir.join("package.json"), GATEWAY_PACKAGE_JSON)
        .map_err(|e| format!("Write package.json: {e}"))?;

    // Reinstall when node_modules is missing or the dependency manifest changed.
    if !dir.join("node_modules").exists() || package_changed {
        info!("Installing WhatsApp gateway npm dependencies...");
        // npm is `npm.cmd` on Windows, plain `npm` elsewhere.
        let npm_cmd = if cfg!(windows) { "npm.cmd" } else { "npm" };
        let output = tokio::process::Command::new(npm_cmd)
            .arg("install")
            .arg("--production")
            .current_dir(&dir)
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .output()
            .await
            .map_err(|e| format!("npm install failed to start: {e}"))?;
        if !output.status.success() {
            return Err(format!(
                "npm install failed: {}",
                String::from_utf8_lossy(&output.stderr)
            ));
        }
        info!("WhatsApp gateway npm dependencies installed");
    } else if index_changed {
        info!("WhatsApp gateway index.js updated (binary upgrade)");
    }
    Ok(dir)
}
/// Check if Node.js is available on the system.
///
/// Probes by running `node --version` with all output discarded; any spawn
/// failure or non-zero exit counts as "not available".
async fn node_available() -> bool {
    let node_cmd = if cfg!(windows) { "node.exe" } else { "node" };
    let status = tokio::process::Command::new(node_cmd)
        .arg("--version")
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status()
        .await;
    matches!(status, Ok(s) if s.success())
}
/// Start the WhatsApp Web gateway as a managed child process.
///
/// This function:
/// 1. Checks if Node.js is available
/// 2. Extracts and installs the gateway files
/// 3. Spawns `node index.js` with appropriate env vars
/// 4. Sets `WHATSAPP_WEB_GATEWAY_URL` so the daemon finds it
/// 5. Monitors the process and restarts on crash (up to 3 times)
///
/// The PID is stored in the kernel's `whatsapp_gateway_pid` for shutdown cleanup.
/// Returns after spawning the monitor task; the gateway itself runs (and
/// restarts) in the background. A no-op when WhatsApp is not configured.
pub async fn start_whatsapp_gateway(kernel: &Arc<super::kernel::OpenFangKernel>) {
    // Only start if WhatsApp is configured
    let wa_config = match &kernel.config.channels.whatsapp {
        Some(cfg) => cfg.clone(),
        None => return,
    };
    // Check for Node.js
    if !node_available().await {
        warn!(
            "WhatsApp Web gateway requires Node.js >= 18 but `node` was not found. \
            Install Node.js to enable WhatsApp Web integration."
        );
        return;
    }
    // Extract and install
    let gateway_path = match ensure_gateway_installed().await {
        Ok(p) => p,
        Err(e) => {
            warn!("WhatsApp Web gateway setup failed: {e}");
            return;
        }
    };
    let port = DEFAULT_GATEWAY_PORT;
    let api_listen = &kernel.config.api_listen;
    let openfang_url = format!("http://{api_listen}");
    let default_agent = wa_config
        .default_agent
        .as_deref()
        .unwrap_or("assistant")
        .to_string();
    // Auto-set the env var so the rest of the system finds the gateway
    // NOTE(review): std::env::set_var mutates process-global state and is not
    // thread-safe if other threads read the environment concurrently —
    // presumably this runs during startup before workers spin up; confirm.
    std::env::set_var("WHATSAPP_WEB_GATEWAY_URL", format!("http://127.0.0.1:{port}"));
    info!("WHATSAPP_WEB_GATEWAY_URL set to http://127.0.0.1:{port}");
    // Spawn with crash monitoring
    // A weak kernel ref lets the monitor detect daemon shutdown without
    // keeping the kernel alive forever.
    let kernel_weak = Arc::downgrade(kernel);
    let gateway_pid = Arc::clone(&kernel.whatsapp_gateway_pid);
    tokio::spawn(async move {
        let mut restarts = 0u32;
        loop {
            let node_cmd = if cfg!(windows) { "node.exe" } else { "node" };
            info!("Starting WhatsApp Web gateway (attempt {})", restarts + 1);
            let child = tokio::process::Command::new(node_cmd)
                .arg("index.js")
                .current_dir(&gateway_path)
                .env("WHATSAPP_GATEWAY_PORT", port.to_string())
                .env("OPENFANG_URL", &openfang_url)
                .env("OPENFANG_DEFAULT_AGENT", &default_agent)
                .stdout(std::process::Stdio::inherit())
                .stderr(std::process::Stdio::inherit())
                .spawn();
            let mut child = match child {
                Ok(c) => c,
                Err(e) => {
                    warn!("Failed to spawn WhatsApp gateway: {e}");
                    return;
                }
            };
            // Store PID for shutdown cleanup
            if let Some(pid) = child.id() {
                if let Ok(mut guard) = gateway_pid.lock() {
                    *guard = Some(pid);
                }
                info!("WhatsApp Web gateway started (PID {pid})");
            }
            // Wait for process exit
            match child.wait().await {
                Ok(status) => {
                    // Clear stored PID
                    if let Ok(mut guard) = gateway_pid.lock() {
                        *guard = None;
                    }
                    // Check if kernel is still alive (not shutting down)
                    let kernel = match kernel_weak.upgrade() {
                        Some(k) => k,
                        None => {
                            info!("WhatsApp gateway exited (kernel dropped)");
                            return;
                        }
                    };
                    if kernel.supervisor.is_shutting_down() {
                        info!("WhatsApp gateway stopped (daemon shutting down)");
                        return;
                    }
                    if status.success() {
                        // Clean exit is intentional — do not restart.
                        info!("WhatsApp gateway exited cleanly");
                        return;
                    }
                    warn!(
                        "WhatsApp gateway crashed (exit: {status}), restart {}/{MAX_RESTARTS}",
                        restarts + 1
                    );
                }
                Err(e) => {
                    if let Ok(mut guard) = gateway_pid.lock() {
                        *guard = None;
                    }
                    warn!("WhatsApp gateway wait error: {e}");
                }
            }
            restarts += 1;
            if restarts >= MAX_RESTARTS {
                warn!(
                    "WhatsApp gateway exceeded max restarts ({MAX_RESTARTS}), giving up"
                );
                return;
            }
            // Backoff before restart
            let delay = RESTART_DELAYS
                .get(restarts as usize - 1)
                .copied()
                .unwrap_or(20);
            info!("Restarting WhatsApp gateway in {delay}s...");
            tokio::time::sleep(std::time::Duration::from_secs(delay)).await;
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_embedded_files_not_empty() {
        // Sanity-check the compile-time embedded gateway sources.
        assert!(!GATEWAY_INDEX_JS.is_empty());
        assert!(!GATEWAY_PACKAGE_JSON.is_empty());
        assert!(GATEWAY_INDEX_JS.contains("WhatsApp"));
        assert!(GATEWAY_PACKAGE_JSON.contains("@openfang/whatsapp-gateway"));
    }
    #[test]
    fn test_content_hash_deterministic() {
        let h1 = content_hash("hello world");
        let h2 = content_hash("hello world");
        assert_eq!(h1, h2);
    }
    #[test]
    fn test_content_hash_changes_on_different_input() {
        let h1 = content_hash("version 1");
        let h2 = content_hash("version 2");
        assert_ne!(h1, h2);
    }
    #[test]
    fn test_gateway_dir_under_openfang_home() {
        let dir = gateway_dir();
        assert!(dir.ends_with("whatsapp-gateway"));
        assert!(dir
            .parent()
            .unwrap()
            .to_string_lossy()
            .contains(".openfang"));
    }
    #[test]
    fn test_write_if_changed_creates_new_file() {
        // Round-trip: first write reports a change, identical content does
        // not, and different content does again.
        let tmp = std::env::temp_dir().join("openfang_test_gateway");
        let _ = std::fs::create_dir_all(&tmp);
        let path = tmp.join("test_write.js");
        let hash_path = path.with_extension("hash");
        // Clean up any previous runs
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_file(&hash_path);
        // First write should return true (new file)
        let changed = write_if_changed(&path, "console.log('v1')").unwrap();
        assert!(changed);
        assert!(path.exists());
        assert!(hash_path.exists());
        // Same content should return false
        let changed = write_if_changed(&path, "console.log('v1')").unwrap();
        assert!(!changed);
        // Different content should return true
        let changed = write_if_changed(&path, "console.log('v2')").unwrap();
        assert!(changed);
        // Clean up
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_file(&hash_path);
        let _ = std::fs::remove_dir(&tmp);
    }
    #[test]
    fn test_default_gateway_port() {
        assert_eq!(DEFAULT_GATEWAY_PORT, 3009);
    }
    #[test]
    fn test_restart_backoff_delays() {
        assert_eq!(RESTART_DELAYS, [5, 10, 20]);
        assert_eq!(MAX_RESTARTS, 3);
    }
}

View File

@@ -0,0 +1,435 @@
//! NL Auto-Bootstrap Wizard — generates agent configs from natural language.
//!
//! The wizard takes a user's natural language description of what they want
//! an agent to do, extracts structured intent, and generates a complete
//! agent manifest (TOML config) ready to spawn.
use openfang_types::agent::{
AgentManifest, ManifestCapabilities, ModelConfig, Priority, ResourceQuota, ScheduleMode,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The extracted intent from a user's natural language description.
///
/// Typically parsed from LLM output (see [`SetupWizard::parse_intent`]) and
/// then turned into a concrete manifest by [`SetupWizard::build_plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentIntent {
    /// Agent name (slug-style).
    pub name: String,
    /// Short description.
    pub description: String,
    /// What the agent should do (summarized task).
    pub task: String,
    /// What skills/tools it needs.
    pub skills: Vec<String>,
    /// Suggested model tier (simple, medium, complex).
    pub model_tier: String,
    /// Whether it runs on a schedule.
    pub scheduled: bool,
    /// Schedule expression (cron or interval).
    pub schedule: Option<String>,
    /// Suggested capabilities (e.g. "web", "file", "shell", "memory", "browser").
    pub capabilities: Vec<String>,
}
/// A generated setup plan from the wizard.
///
/// Bundles everything needed to create the agent: the parsed intent, the
/// ready-to-serialize manifest, pending skill installs, and a human-readable
/// summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetupPlan {
    /// The extracted intent.
    pub intent: AgentIntent,
    /// Generated agent manifest (ready to write as TOML).
    pub manifest: AgentManifest,
    /// Skills to install (if not already installed).
    pub skills_to_install: Vec<String>,
    /// Human-readable summary of what will be created.
    pub summary: String,
}
/// The setup wizard builds agent configurations from natural language.
/// Stateless: all functionality is exposed as associated functions.
pub struct SetupWizard;
impl SetupWizard {
    /// Append `tool` to `tools` unless it is already present.
    ///
    /// Several capability keywords map to overlapping tool sets (e.g. "web"
    /// and "browser" both imply network tools), so dedupe on insert instead
    /// of repeating the contains-then-push pattern at every call site.
    fn push_unique(tools: &mut Vec<String>, tool: &str) {
        if !tools.iter().any(|t| t == tool) {
            tools.push(tool.to_string());
        }
    }
    /// Build a setup plan from an extracted intent.
    ///
    /// This maps the intent into a concrete agent manifest with appropriate
    /// model configuration, capabilities, and schedule.
    pub fn build_plan(intent: AgentIntent) -> SetupPlan {
        // Map model tier to provider/model
        let (provider, model) = match intent.model_tier.as_str() {
            "simple" => ("groq", "llama-3.3-70b-versatile"),
            "complex" => ("anthropic", "claude-sonnet-4-20250514"),
            _ => ("groq", "llama-3.3-70b-versatile"), // medium default
        };
        // Build capabilities from intent
        let mut caps = ManifestCapabilities::default();
        for cap in &intent.capabilities {
            match cap.as_str() {
                "web" | "network" => caps.network.push("*".to_string()),
                "file_read" => caps.tools.push("file_read".to_string()),
                "file_write" => caps.tools.push("file_write".to_string()),
                "file" | "files" => {
                    for t in ["file_read", "file_write", "file_list"] {
                        Self::push_unique(&mut caps.tools, t);
                    }
                }
                "shell" => caps.shell.push("*".to_string()),
                "memory" => {
                    caps.memory_read.push("*".to_string());
                    caps.memory_write.push("*".to_string());
                    for t in ["memory_store", "memory_recall"] {
                        Self::push_unique(&mut caps.tools, t);
                    }
                }
                "browser" | "browse" => {
                    caps.network.push("*".to_string());
                    for t in [
                        "browser_navigate",
                        "browser_click",
                        "browser_type",
                        "browser_read_page",
                        "browser_screenshot",
                        "browser_close",
                    ] {
                        Self::push_unique(&mut caps.tools, t);
                    }
                }
                // Unknown keywords pass through as raw tool names.
                other => caps.tools.push(other.to_string()),
            }
        }
        // Add web_search + web_fetch if web/network capability is needed
        if caps.network.iter().any(|n| n == "*") {
            for t in ["web_search", "web_fetch"] {
                Self::push_unique(&mut caps.tools, t);
            }
        }
        // Build schedule: a cron expression yields Periodic, otherwise the
        // default (on-demand) mode is used.
        let schedule = if intent.scheduled {
            if let Some(ref cron) = intent.schedule {
                ScheduleMode::Periodic { cron: cron.clone() }
            } else {
                ScheduleMode::default()
            }
        } else {
            ScheduleMode::default()
        };
        // Build system prompt — rich enough to guide the agent on its task.
        // The prompt_builder will wrap this with tool descriptions, memory protocol,
        // safety guidelines, etc. at execution time.
        let tool_hints = Self::tool_hints_for(&caps.tools);
        let system_prompt = format!(
            "You are {name}, an AI agent running inside the OpenFang Agent OS.\n\
            \n\
            YOUR TASK: {task}\n\
            \n\
            APPROACH:\n\
            - Understand the request fully before acting.\n\
            - Use your tools to accomplish the task rather than just describing what to do.\n\
            - If you need information, search for it. If you need to read a file, read it.\n\
            - Be concise in your responses. Lead with results, not process narration.\n\
            {tool_hints}",
            name = intent.name,
            task = intent.task,
            tool_hints = tool_hints,
        );
        let manifest = AgentManifest {
            name: intent.name.clone(),
            version: "0.1.0".to_string(),
            description: intent.description.clone(),
            author: "wizard".to_string(),
            module: "builtin:chat".to_string(),
            schedule,
            model: ModelConfig {
                provider: provider.to_string(),
                model: model.to_string(),
                max_tokens: 4096,
                temperature: 0.7,
                system_prompt,
                api_key_env: None,
                base_url: None,
            },
            resources: ResourceQuota::default(),
            priority: Priority::default(),
            capabilities: caps,
            tools: HashMap::new(),
            skills: intent.skills.clone(),
            mcp_servers: vec![],
            metadata: HashMap::new(),
            tags: vec![],
            routing: None,
            autonomous: None,
            pinned_model: None,
            workspace: None,
            generate_identity_files: true,
            profile: None,
            fallback_models: vec![],
            exec_policy: None,
        };
        // Empty skill names would produce bogus install requests; drop them.
        let skills_to_install: Vec<String> = intent
            .skills
            .iter()
            .filter(|s| !s.is_empty())
            .cloned()
            .collect();
        let summary = format!(
            "Agent '{}': {}\n  Model: {}/{}\n  Skills: {}\n  Schedule: {}",
            intent.name,
            intent.description,
            provider,
            model,
            if skills_to_install.is_empty() {
                "none".to_string()
            } else {
                skills_to_install.join(", ")
            },
            if intent.scheduled {
                intent.schedule.as_deref().unwrap_or("on-demand")
            } else {
                "on-demand"
            }
        );
        SetupPlan {
            intent,
            manifest,
            skills_to_install,
            summary,
        }
    }
    /// Build a short tool usage hint block for the system prompt based on granted tools.
    /// Returns an empty string when none of the known tools are granted.
    fn tool_hints_for(tools: &[String]) -> String {
        let mut hints = Vec::new();
        let has = |name: &str| tools.iter().any(|t| t == name);
        if has("web_search") {
            hints.push("- Use web_search to find current information on any topic.");
        }
        if has("web_fetch") {
            hints.push("- Use web_fetch to read the full content of a specific URL as markdown.");
        }
        if has("browser_navigate") {
            hints.push("- Use browser_navigate/click/type/read_page to interact with websites.");
        }
        if has("file_read") {
            hints.push("- Use file_read to examine files before modifying them.");
        }
        if has("shell_exec") {
            hints.push(
                "- Use shell_exec to run commands. Explain destructive commands before running.",
            );
        }
        if has("memory_store") {
            hints.push(
                "- Use memory_store/memory_recall to persist and retrieve important context.",
            );
        }
        if hints.is_empty() {
            String::new()
        } else {
            format!("\nKEY TOOLS:\n{}", hints.join("\n"))
        }
    }
    /// Generate a TOML string from an agent manifest.
    ///
    /// # Errors
    /// Returns the underlying serializer error if the manifest cannot be
    /// represented as TOML.
    pub fn manifest_to_toml(manifest: &AgentManifest) -> Result<String, toml::ser::Error> {
        toml::to_string_pretty(manifest)
    }
    /// Parse an intent from a JSON string (typically LLM output).
    ///
    /// # Errors
    /// Returns a `serde_json` error when the input is not valid JSON matching
    /// the [`AgentIntent`] shape.
    pub fn parse_intent(json: &str) -> Result<AgentIntent, serde_json::Error> {
        serde_json::from_str(json)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Shared baseline intent: web + memory capable, medium tier, on-demand.
    fn sample_intent() -> AgentIntent {
        AgentIntent {
            name: "research-bot".to_string(),
            description: "Researches topics and provides summaries".to_string(),
            task: "Search the web for information and provide concise summaries".to_string(),
            skills: vec!["web-summarizer".to_string()],
            model_tier: "medium".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["web".to_string(), "memory".to_string()],
        }
    }
    #[test]
    fn test_build_plan_basic() {
        // Medium tier maps to groq; "web" grants network wildcard.
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);
        assert_eq!(plan.manifest.name, "research-bot");
        assert_eq!(plan.manifest.model.provider, "groq");
        assert!(plan
            .manifest
            .capabilities
            .network
            .contains(&"*".to_string()));
        assert!(plan.summary.contains("research-bot"));
    }
    #[test]
    fn test_build_plan_complex_tier() {
        // "complex" tier routes to Anthropic's Sonnet model.
        let mut intent = sample_intent();
        intent.model_tier = "complex".to_string();
        let plan = SetupWizard::build_plan(intent);
        assert_eq!(plan.manifest.model.provider, "anthropic");
        assert!(plan.manifest.model.model.contains("sonnet"));
    }
    #[test]
    fn test_build_plan_scheduled() {
        // scheduled + cron expression yields a Periodic schedule mode.
        let mut intent = sample_intent();
        intent.scheduled = true;
        intent.schedule = Some("0 */6 * * *".to_string());
        let plan = SetupWizard::build_plan(intent);
        match &plan.manifest.schedule {
            ScheduleMode::Periodic { cron } => {
                assert_eq!(cron, "0 */6 * * *");
            }
            _ => panic!("Expected periodic schedule mode"),
        }
    }
    #[test]
    fn test_parse_intent_json() {
        let json = r#"{
            "name": "code-reviewer",
            "description": "Reviews code and suggests improvements",
            "task": "Analyze pull requests and provide feedback",
            "skills": [],
            "model_tier": "complex",
            "scheduled": false,
            "schedule": null,
            "capabilities": ["file_read"]
        }"#;
        let intent = SetupWizard::parse_intent(json).unwrap();
        assert_eq!(intent.name, "code-reviewer");
        assert_eq!(intent.model_tier, "complex");
    }
    #[test]
    fn test_manifest_to_toml() {
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);
        let toml = SetupWizard::manifest_to_toml(&plan.manifest);
        assert!(toml.is_ok());
        let toml_str = toml.unwrap();
        assert!(toml_str.contains("research-bot"));
    }
    #[test]
    fn test_web_tools_auto_added() {
        // Granting "web" should implicitly grant web_search + web_fetch.
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["web".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"web_fetch".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"web_search".to_string()));
    }
    #[test]
    fn test_memory_tools_auto_added() {
        // Granting "memory" should implicitly grant the memory tool pair.
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["memory".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"memory_store".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"memory_recall".to_string()));
    }
    #[test]
    fn test_browser_tools_auto_added() {
        // Granting "browser" should implicitly grant the browser tool suite.
        let intent = AgentIntent {
            name: "test".to_string(),
            description: "test".to_string(),
            task: "test".to_string(),
            skills: vec![],
            model_tier: "simple".to_string(),
            scheduled: false,
            schedule: None,
            capabilities: vec!["browser".to_string()],
        };
        let plan = SetupWizard::build_plan(intent);
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_navigate".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_click".to_string()));
        assert!(plan
            .manifest
            .capabilities
            .tools
            .contains(&"browser_read_page".to_string()));
    }
    #[test]
    fn test_wizard_system_prompt_has_task() {
        // The generated system prompt embeds the intent's task text.
        let intent = sample_intent();
        let plan = SetupWizard::build_plan(intent);
        assert!(plan.manifest.model.system_prompt.contains("YOUR TASK:"));
        assert!(plan.manifest.model.system_prompt.contains("Search the web"));
    }
}

File diff suppressed because it is too large Load Diff