feat(runtime): DeerFlow 模式中间件链 Phase 1-4 全部完成

借鉴 DeerFlow 架构,实现完整中间件链系统:

Phase 1 - Agent 中间件链基础设施
- MiddlewareChain Clone 支持
- LoopRunner 双路径集成 (middleware/legacy)
- Kernel create_middleware_chain() 工厂方法

Phase 2 - 技能按需注入
- SkillIndexMiddleware (priority 200)
- SkillLoadTool 工具
- SkillDetail/SkillIndexEntry 结构体
- KernelSkillExecutor trait 扩展

Phase 3 - Guardrail 安全护栏
- GuardrailMiddleware (priority 400, fail_open)
- ShellExecRule / FileWriteRule / WebFetchRule

Phase 4 - 记忆闭环统一
- MemoryMiddleware (priority 150, 30s 防抖)
- after_completion 双路径调用

中间件注册顺序:
100 Compaction | 150 Memory | 200 SkillIndex
400 Guardrail  | 500 LoopGuard | 700 TokenCalibration

向后兼容:Option<MiddlewareChain> 默认 None 走旧路径
This commit is contained in:
iven
2026-03-29 23:19:41 +08:00
parent 7de294375b
commit 04c366fe8b
15 changed files with 1302 additions and 43 deletions

View File

@@ -27,7 +27,7 @@ pub struct SqliteStorage {
}
/// Database row structure for memory entry
struct MemoryRow {
pub(crate) struct MemoryRow {
uri: String,
memory_type: String,
content: String,

View File

@@ -86,6 +86,32 @@ impl SkillExecutor for KernelSkillExecutor {
let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?;
Ok(result.output)
}
fn get_skill_detail(&self, skill_id: &str) -> Option<zclaw_runtime::tool::SkillDetail> {
let manifests = self.skills.manifests_snapshot();
let manifest = manifests.get(&zclaw_types::SkillId::new(skill_id))?;
Some(zclaw_runtime::tool::SkillDetail {
id: manifest.id.as_str().to_string(),
name: manifest.name.clone(),
description: manifest.description.clone(),
category: manifest.category.clone(),
input_schema: manifest.input_schema.clone(),
triggers: manifest.triggers.clone(),
capabilities: manifest.capabilities.clone(),
})
}
fn list_skill_index(&self) -> Vec<zclaw_runtime::tool::SkillIndexEntry> {
let manifests = self.skills.manifests_snapshot();
manifests.values()
.filter(|m| m.enabled)
.map(|m| zclaw_runtime::tool::SkillIndexEntry {
id: m.id.as_str().to_string(),
description: m.description.clone(),
triggers: m.triggers.clone(),
})
.collect()
}
}
/// The ZCLAW Kernel
@@ -205,6 +231,68 @@ impl Kernel {
tools
}
/// Create the middleware chain for the agent loop.
///
/// When middleware is configured, cross-cutting concerns (compaction, loop guard,
/// token calibration, etc.) are delegated to the chain. When no middleware is
/// registered, the legacy inline path in `AgentLoop` is used instead.
fn create_middleware_chain(&self) -> Option<zclaw_runtime::middleware::MiddlewareChain> {
let mut chain = zclaw_runtime::middleware::MiddlewareChain::new();
// Compaction middleware — only register when threshold > 0
let threshold = self.config.compaction_threshold();
if threshold > 0 {
use std::sync::Arc;
let mw = zclaw_runtime::middleware::compaction::CompactionMiddleware::new(
threshold,
zclaw_runtime::CompactionConfig::default(),
Some(self.driver.clone()),
None, // growth not wired in kernel yet
);
chain.register(Arc::new(mw));
}
// Loop guard middleware
{
use std::sync::Arc;
let mw = zclaw_runtime::middleware::loop_guard::LoopGuardMiddleware::with_defaults();
chain.register(Arc::new(mw));
}
// Token calibration middleware
{
use std::sync::Arc;
let mw = zclaw_runtime::middleware::token_calibration::TokenCalibrationMiddleware::new();
chain.register(Arc::new(mw));
}
// Skill index middleware — inject lightweight index instead of full descriptions
{
use std::sync::Arc;
let entries = self.skill_executor.list_skill_index();
if !entries.is_empty() {
let mw = zclaw_runtime::middleware::skill_index::SkillIndexMiddleware::new(entries);
chain.register(Arc::new(mw));
}
}
// Guardrail middleware — safety rules for tool calls
{
use std::sync::Arc;
let mw = zclaw_runtime::middleware::guardrail::GuardrailMiddleware::new(true)
.with_builtin_rules();
chain.register(Arc::new(mw));
}
// Only return Some if we actually registered middleware
if chain.is_empty() {
None
} else {
tracing::info!("[Kernel] Middleware chain created with {} middlewares", chain.len());
Some(chain)
}
}
/// Build a system prompt with skill information injected
async fn build_system_prompt_with_skills(&self, base_prompt: Option<&String>) -> String {
// Get skill list asynchronously
@@ -417,6 +505,11 @@ impl Kernel {
loop_runner = loop_runner.with_path_validator(path_validator);
}
// Inject middleware chain if available
if let Some(chain) = self.create_middleware_chain() {
loop_runner = loop_runner.with_middleware_chain(chain);
}
// Build system prompt with skill information injected
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
let loop_runner = loop_runner.with_system_prompt(&system_prompt);
@@ -501,6 +594,11 @@ impl Kernel {
loop_runner = loop_runner.with_path_validator(path_validator);
}
// Inject middleware chain if available
if let Some(chain) = self.create_middleware_chain() {
loop_runner = loop_runner.with_middleware_chain(chain);
}
// Use external prompt if provided, otherwise build default
let system_prompt = match system_prompt_override {
Some(prompt) => prompt,

View File

@@ -15,6 +15,7 @@ pub mod loop_guard;
pub mod stream;
pub mod growth;
pub mod compaction;
pub mod middleware;
// Re-export main types
pub use driver::{

View File

@@ -13,6 +13,7 @@ use crate::tool::builtin::PathValidator;
use crate::loop_guard::{LoopGuard, LoopGuardResult};
use crate::growth::GrowthIntegration;
use crate::compaction::{self, CompactionConfig};
use crate::middleware::{self, MiddlewareChain};
use zclaw_memory::MemoryStore;
/// Agent loop runner
@@ -34,6 +35,10 @@ pub struct AgentLoop {
compaction_threshold: usize,
/// Compaction behavior configuration
compaction_config: CompactionConfig,
/// Optional middleware chain — when `Some`, cross-cutting logic is
/// delegated to the chain instead of the inline code below.
/// When `None`, the legacy inline path is used (100% backward compatible).
middleware_chain: Option<MiddlewareChain>,
}
impl AgentLoop {
@@ -58,6 +63,7 @@ impl AgentLoop {
growth: None,
compaction_threshold: 0,
compaction_config: CompactionConfig::default(),
middleware_chain: None,
}
}
@@ -124,6 +130,14 @@ impl AgentLoop {
self
}
/// Inject a middleware chain. When set, cross-cutting concerns (compaction,
/// loop guard, token calibration, etc.) are delegated to the chain instead
/// of the inline logic.
pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
self.middleware_chain = Some(chain);
self
}
/// Get growth integration reference
pub fn growth(&self) -> Option<&GrowthIntegration> {
self.growth.as_ref()
@@ -175,8 +189,10 @@ impl AgentLoop {
// Get all messages for context
let mut messages = self.memory.get_messages(&session_id).await?;
// Apply compaction if threshold is configured
if self.compaction_threshold > 0 {
let use_middleware = self.middleware_chain.is_some();
// Apply compaction — skip inline path when middleware chain handles it
if !use_middleware && self.compaction_threshold > 0 {
let needs_async =
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
if needs_async {
@@ -196,14 +212,44 @@ impl AgentLoop {
}
}
// Enhance system prompt with growth memories
let enhanced_prompt = if let Some(ref growth) = self.growth {
// Enhance system prompt — skip when middleware chain handles it
let mut enhanced_prompt = if use_middleware {
self.system_prompt.clone().unwrap_or_default()
} else if let Some(ref growth) = self.growth {
let base = self.system_prompt.as_deref().unwrap_or("");
growth.enhance_prompt(&self.agent_id, base, &input).await?
} else {
self.system_prompt.clone().unwrap_or_default()
};
// Run middleware before_completion hooks (compaction, memory inject, etc.)
if let Some(ref chain) = self.middleware_chain {
let mut mw_ctx = middleware::MiddlewareContext {
agent_id: self.agent_id.clone(),
session_id: session_id.clone(),
user_input: input.clone(),
system_prompt: enhanced_prompt.clone(),
messages,
response_content: Vec::new(),
input_tokens: 0,
output_tokens: 0,
};
match chain.run_before_completion(&mut mw_ctx).await? {
middleware::MiddlewareDecision::Continue => {
messages = mw_ctx.messages;
enhanced_prompt = mw_ctx.system_prompt;
}
middleware::MiddlewareDecision::Stop(reason) => {
return Ok(AgentLoopResult {
response: reason,
input_tokens: 0,
output_tokens: 0,
iterations: 1,
});
}
}
}
let max_iterations = 10;
let mut iterations = 0;
let mut total_input_tokens = 0u32;
@@ -307,24 +353,56 @@ impl AgentLoop {
let tool_context = self.create_tool_context(session_id.clone());
let mut circuit_breaker_triggered = false;
for (id, name, input) in tool_calls {
// Check loop guard before executing tool
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
match guard_result {
LoopGuardResult::CircuitBreaker => {
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
circuit_breaker_triggered = true;
break;
// Check tool call safety — via middleware chain or inline loop guard
if let Some(ref chain) = self.middleware_chain {
let mw_ctx_ref = middleware::MiddlewareContext {
agent_id: self.agent_id.clone(),
session_id: session_id.clone(),
user_input: input.to_string(),
system_prompt: enhanced_prompt.clone(),
messages: messages.clone(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
middleware::ToolCallDecision::Allow => {}
middleware::ToolCallDecision::Block(msg) => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
let error_output = serde_json::json!({ "error": msg });
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
}
middleware::ToolCallDecision::ReplaceInput(new_input) => {
// Execute with replaced input
let tool_result = match self.execute_tool(&name, new_input, &tool_context).await {
Ok(result) => result,
Err(e) => serde_json::json!({ "error": e.to_string() }),
};
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
continue;
}
}
LoopGuardResult::Blocked => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
} else {
// Legacy inline path
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
match guard_result {
LoopGuardResult::CircuitBreaker => {
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
circuit_breaker_triggered = true;
break;
}
LoopGuardResult::Blocked => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
}
LoopGuardResult::Warn => {
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
}
LoopGuardResult::Allowed => {}
}
LoopGuardResult::Warn => {
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
}
LoopGuardResult::Allowed => {}
}
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
@@ -356,8 +434,23 @@ impl AgentLoop {
}
};
// Process conversation for memory extraction (post-conversation)
if let Some(ref growth) = self.growth {
// Post-completion processing — middleware chain or inline growth
if let Some(ref chain) = self.middleware_chain {
let mw_ctx = middleware::MiddlewareContext {
agent_id: self.agent_id.clone(),
session_id: session_id.clone(),
user_input: input.clone(),
system_prompt: enhanced_prompt.clone(),
messages: self.memory.get_messages(&session_id).await.unwrap_or_default(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e);
}
} else if let Some(ref growth) = self.growth {
// Legacy inline path
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
@@ -384,8 +477,10 @@ impl AgentLoop {
// Get all messages for context
let mut messages = self.memory.get_messages(&session_id).await?;
// Apply compaction if threshold is configured
if self.compaction_threshold > 0 {
let use_middleware = self.middleware_chain.is_some();
// Apply compaction — skip inline path when middleware chain handles it
if !use_middleware && self.compaction_threshold > 0 {
let needs_async =
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
if needs_async {
@@ -405,20 +500,52 @@ impl AgentLoop {
}
}
// Enhance system prompt with growth memories
let enhanced_prompt = if let Some(ref growth) = self.growth {
// Enhance system prompt — skip when middleware chain handles it
let mut enhanced_prompt = if use_middleware {
self.system_prompt.clone().unwrap_or_default()
} else if let Some(ref growth) = self.growth {
let base = self.system_prompt.as_deref().unwrap_or("");
growth.enhance_prompt(&self.agent_id, base, &input).await?
} else {
self.system_prompt.clone().unwrap_or_default()
};
// Run middleware before_completion hooks (compaction, memory inject, etc.)
if let Some(ref chain) = self.middleware_chain {
let mut mw_ctx = middleware::MiddlewareContext {
agent_id: self.agent_id.clone(),
session_id: session_id.clone(),
user_input: input.clone(),
system_prompt: enhanced_prompt.clone(),
messages,
response_content: Vec::new(),
input_tokens: 0,
output_tokens: 0,
};
match chain.run_before_completion(&mut mw_ctx).await? {
middleware::MiddlewareDecision::Continue => {
messages = mw_ctx.messages;
enhanced_prompt = mw_ctx.system_prompt;
}
middleware::MiddlewareDecision::Stop(reason) => {
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
response: reason,
input_tokens: 0,
output_tokens: 0,
iterations: 1,
})).await;
return Ok(rx);
}
}
}
// Clone necessary data for the async task
let session_id_clone = session_id.clone();
let memory = self.memory.clone();
let driver = self.driver.clone();
let tools = self.tools.clone();
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
let middleware_chain = self.middleware_chain.clone();
let skill_executor = self.skill_executor.clone();
let path_validator = self.path_validator.clone();
let agent_id = self.agent_id.clone();
@@ -558,6 +685,24 @@ impl AgentLoop {
output_tokens: total_output_tokens,
iterations: iteration,
})).await;
// Post-completion: middleware after_completion (memory extraction, etc.)
if let Some(ref chain) = middleware_chain {
let mw_ctx = middleware::MiddlewareContext {
agent_id: agent_id.clone(),
session_id: session_id_clone.clone(),
user_input: String::new(),
system_prompt: enhanced_prompt.clone(),
messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e);
}
}
break 'outer;
}
@@ -579,24 +724,92 @@ impl AgentLoop {
for (id, name, input) in pending_tool_calls {
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
// Check loop guard before executing tool
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
match guard_result {
LoopGuardResult::CircuitBreaker => {
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
break 'outer;
// Check tool call safety — via middleware chain or inline loop guard
if let Some(ref chain) = middleware_chain {
let mw_ctx = middleware::MiddlewareContext {
agent_id: agent_id.clone(),
session_id: session_id_clone.clone(),
user_input: input.to_string(),
system_prompt: enhanced_prompt.clone(),
messages: messages.clone(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
match chain.run_before_tool_call(&mw_ctx, &name, &input).await {
Ok(middleware::ToolCallDecision::Allow) => {}
Ok(middleware::ToolCallDecision::Block(msg)) => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
let error_output = serde_json::json!({ "error": msg });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
}
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
// Execute with replaced input (same path_validator logic below)
let pv = path_validator.clone().unwrap_or_else(|| {
let home = std::env::var("USERPROFILE")
.or_else(|_| std::env::var("HOME"))
.unwrap_or_else(|_| ".".to_string());
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
});
let working_dir = pv.workspace_root()
.map(|p| p.to_string_lossy().to_string());
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: working_dir,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
path_validator: Some(pv),
};
let (result, is_error) = if let Some(tool) = tools.get(&name) {
match tool.execute(new_input, &tool_context).await {
Ok(output) => {
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
(output, false)
}
Err(e) => {
let error_output = serde_json::json!({ "error": e.to_string() });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
(error_output, true)
}
}
} else {
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
(error_output, true)
};
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
continue;
}
Err(e) => {
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
let error_output = serde_json::json!({ "error": e.to_string() });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
}
}
LoopGuardResult::Blocked => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
} else {
// Legacy inline loop guard path
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
match guard_result {
LoopGuardResult::CircuitBreaker => {
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
break 'outer;
}
LoopGuardResult::Blocked => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
}
LoopGuardResult::Warn => {
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
}
LoopGuardResult::Allowed => {}
}
LoopGuardResult::Warn => {
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
}
LoopGuardResult::Allowed => {}
}
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
let pv = path_validator.clone().unwrap_or_else(|| {

View File

@@ -0,0 +1,252 @@
//! Agent middleware system — composable hooks for cross-cutting concerns.
//!
//! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain,
//! this module provides a standardised way to inject behaviour before/after LLM completions
//! and tool calls without modifying the core `AgentLoop` logic.
//!
//! # Priority convention
//!
//! | Range | Category | Example |
//! |---------|----------------|-----------------------------|
//! | 100-199 | Context shaping| Compaction, MemoryInject |
//! | 200-399 | Capability | SkillIndex, Guardrail |
//! | 400-599 | Safety | LoopGuard, Guardrail |
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use zclaw_types::{AgentId, Result, SessionId};
use crate::driver::ContentBlock;
// ---------------------------------------------------------------------------
// Decisions returned by middleware hooks
// ---------------------------------------------------------------------------
/// Decision returned by `before_completion`.
#[derive(Debug, Clone)]
pub enum MiddlewareDecision {
/// Continue to the next middleware / proceed with the LLM call.
Continue,
/// Abort the agent loop and return *reason* to the caller.
Stop(String),
}
/// Decision returned by `before_tool_call`.
#[derive(Debug, Clone)]
pub enum ToolCallDecision {
/// Allow the tool call to proceed unchanged.
Allow,
/// Block the call and return *message* as a tool-error to the LLM.
Block(String),
/// Allow the call but replace the tool input with *new_input*.
ReplaceInput(Value),
}
// ---------------------------------------------------------------------------
// Middleware context — shared mutable state passed through the chain
// ---------------------------------------------------------------------------
/// Carries the mutable state that middleware may inspect or modify.
pub struct MiddlewareContext {
/// The agent that owns this loop.
pub agent_id: AgentId,
/// Current session.
pub session_id: SessionId,
/// The raw user input that started this turn.
pub user_input: String,
// -- mutable state -------------------------------------------------------
/// System prompt — middleware may prepend/append context.
pub system_prompt: String,
/// Conversation messages sent to the LLM.
pub messages: Vec<zclaw_types::Message>,
/// Accumulated LLM content blocks from the current response.
pub response_content: Vec<ContentBlock>,
/// Token usage reported by the LLM driver (updated after each call).
pub input_tokens: u32,
pub output_tokens: u32,
}
impl std::fmt::Debug for MiddlewareContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MiddlewareContext")
.field("agent_id", &self.agent_id)
.field("session_id", &self.session_id)
.field("messages", &self.messages.len())
.field("input_tokens", &self.input_tokens)
.field("output_tokens", &self.output_tokens)
.finish()
}
}
// ---------------------------------------------------------------------------
// Core trait
// ---------------------------------------------------------------------------
/// A composable middleware hook for the agent loop.
///
/// Each middleware focuses on one cross-cutting concern and is executed
/// in `priority` order (ascending). All hook methods have default no-op
/// implementations so implementors only override what they need.
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
/// Human-readable name for logging / debugging.
fn name(&self) -> &str;
/// Execution priority — lower values run first.
fn priority(&self) -> i32 {
500
}
/// Hook executed **before** the LLM completion request is sent.
///
/// Use this to inject context (memory, skill index, etc.) or to
/// trigger pre-processing (compaction, summarisation).
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
Ok(MiddlewareDecision::Continue)
}
/// Hook executed **before** each tool call.
///
/// Return `Block` to prevent execution and feed an error back to
/// the LLM, or `ReplaceInput` to sanitise / modify the arguments.
async fn before_tool_call(
&self,
_ctx: &MiddlewareContext,
_tool_name: &str,
_tool_input: &Value,
) -> Result<ToolCallDecision> {
Ok(ToolCallDecision::Allow)
}
/// Hook executed **after** each tool call.
async fn after_tool_call(
&self,
_ctx: &mut MiddlewareContext,
_tool_name: &str,
_result: &Value,
) -> Result<()> {
Ok(())
}
/// Hook executed **after** the entire agent loop turn completes.
///
/// Use this for post-processing (memory extraction, telemetry, etc.).
async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> {
Ok(())
}
}
// ---------------------------------------------------------------------------
// Middleware chain — ordered collection with run methods
// ---------------------------------------------------------------------------
/// An ordered chain of `AgentMiddleware` instances.
pub struct MiddlewareChain {
middlewares: Vec<Arc<dyn AgentMiddleware>>,
}
impl MiddlewareChain {
/// Create an empty chain.
pub fn new() -> Self {
Self { middlewares: Vec::new() }
}
/// Register a middleware. The chain is kept sorted by `priority`
/// (ascending) and by registration order within the same priority.
pub fn register(&mut self, mw: Arc<dyn AgentMiddleware>) {
let p = mw.priority();
let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len());
self.middlewares.insert(pos, mw);
}
/// Run all `before_completion` hooks in order.
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
for mw in &self.middlewares {
match mw.before_completion(ctx).await? {
MiddlewareDecision::Continue => {}
MiddlewareDecision::Stop(reason) => {
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
return Ok(MiddlewareDecision::Stop(reason));
}
}
}
Ok(MiddlewareDecision::Continue)
}
/// Run all `before_tool_call` hooks in order.
pub async fn run_before_tool_call(
&self,
ctx: &MiddlewareContext,
tool_name: &str,
tool_input: &Value,
) -> Result<ToolCallDecision> {
for mw in &self.middlewares {
match mw.before_tool_call(ctx, tool_name, tool_input).await? {
ToolCallDecision::Allow => {}
other => {
tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name);
return Ok(other);
}
}
}
Ok(ToolCallDecision::Allow)
}
/// Run all `after_tool_call` hooks in order.
pub async fn run_after_tool_call(
&self,
ctx: &mut MiddlewareContext,
tool_name: &str,
result: &Value,
) -> Result<()> {
for mw in &self.middlewares {
mw.after_tool_call(ctx, tool_name, result).await?;
}
Ok(())
}
/// Run all `after_completion` hooks in order.
pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
for mw in &self.middlewares {
mw.after_completion(ctx).await?;
}
Ok(())
}
/// Number of registered middlewares.
pub fn len(&self) -> usize {
self.middlewares.len()
}
/// Whether the chain is empty.
pub fn is_empty(&self) -> bool {
self.middlewares.is_empty()
}
}
impl Clone for MiddlewareChain {
fn clone(&self) -> Self {
Self {
middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump
}
}
}
impl Default for MiddlewareChain {
fn default() -> Self {
Self::new()
}
}
// ---------------------------------------------------------------------------
// Sub-modules — concrete middleware implementations
// ---------------------------------------------------------------------------
pub mod compaction;
pub mod guardrail;
pub mod loop_guard;
pub mod memory;
pub mod skill_index;
pub mod token_calibration;

View File

@@ -0,0 +1,61 @@
//! Compaction middleware — wraps the existing compaction module.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::compaction::{self, CompactionConfig};
use crate::growth::GrowthIntegration;
use crate::driver::LlmDriver;
use std::sync::Arc;
/// Middleware that compresses conversation history when it exceeds a token threshold.
pub struct CompactionMiddleware {
threshold: usize,
config: CompactionConfig,
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
driver: Option<Arc<dyn LlmDriver>>,
/// Optional growth integration for memory flushing during compaction.
growth: Option<GrowthIntegration>,
}
impl CompactionMiddleware {
pub fn new(
threshold: usize,
config: CompactionConfig,
driver: Option<Arc<dyn LlmDriver>>,
growth: Option<GrowthIntegration>,
) -> Self {
Self { threshold, config, driver, growth }
}
}
#[async_trait]
impl AgentMiddleware for CompactionMiddleware {
fn name(&self) -> &str { "compaction" }
fn priority(&self) -> i32 { 100 }
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
if self.threshold == 0 {
return Ok(MiddlewareDecision::Continue);
}
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
if needs_async {
let outcome = compaction::maybe_compact_with_config(
ctx.messages.clone(),
self.threshold,
&self.config,
&ctx.agent_id,
&ctx.session_id,
self.driver.as_ref(),
self.growth.as_ref(),
)
.await;
ctx.messages = outcome.messages;
} else {
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
}
Ok(MiddlewareDecision::Continue)
}
}

View File

@@ -0,0 +1,223 @@
//! Guardrail middleware — configurable safety rules for tool call evaluation.
//!
//! This middleware inspects tool calls before execution and can block or
//! modify them based on configurable rules. Inspired by DeerFlow's safety
//! evaluation hooks.
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
/// A single guardrail rule that can inspect and decide on tool calls.
pub trait GuardrailRule: Send + Sync {
/// Human-readable name for logging.
fn name(&self) -> &str;
/// Evaluate a tool call.
fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
}
/// Decision returned by a guardrail rule.
#[derive(Debug, Clone)]
pub enum GuardrailVerdict {
/// Allow the tool call to proceed.
Allow,
/// Block the call and return *message* as an error to the LLM.
Block(String),
}
/// Middleware that evaluates tool calls against a set of configurable safety rules.
///
/// Rules are grouped by tool name. When a tool call is made, all rules for
/// that tool are evaluated in order. If any rule returns `Block`, the call
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
/// a rule explicitly blocks them.
pub struct GuardrailMiddleware {
/// Rules keyed by tool name.
rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
/// Default policy for tools with no specific rules: true = allow, false = block.
fail_open: bool,
}
impl GuardrailMiddleware {
pub fn new(fail_open: bool) -> Self {
Self {
rules: HashMap::new(),
fail_open,
}
}
/// Register a guardrail rule for a specific tool.
pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
self.rules.entry(tool_name.into()).or_default().push(rule);
}
/// Register built-in safety rules (shell_exec, file_write, web_fetch).
pub fn with_builtin_rules(mut self) -> Self {
self.add_rule("shell_exec", Box::new(ShellExecRule));
self.add_rule("file_write", Box::new(FileWriteRule));
self.add_rule("web_fetch", Box::new(WebFetchRule));
self
}
}
#[async_trait]
impl AgentMiddleware for GuardrailMiddleware {
fn name(&self) -> &str { "guardrail" }
fn priority(&self) -> i32 { 400 }
async fn before_tool_call(
&self,
_ctx: &MiddlewareContext,
tool_name: &str,
tool_input: &Value,
) -> Result<ToolCallDecision> {
if let Some(rules) = self.rules.get(tool_name) {
for rule in rules {
match rule.evaluate(tool_name, tool_input) {
GuardrailVerdict::Allow => {}
GuardrailVerdict::Block(msg) => {
tracing::warn!(
"[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
rule.name(),
tool_name,
msg
);
return Ok(ToolCallDecision::Block(msg));
}
}
}
} else if !self.fail_open {
// fail-closed: unknown tools are blocked
tracing::warn!(
"[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
tool_name
);
return Ok(ToolCallDecision::Block(format!(
"工具 '{}' 未注册安全规则fail-closed 策略阻止执行",
tool_name
)));
}
Ok(ToolCallDecision::Allow)
}
}
// ---------------------------------------------------------------------------
// Built-in rules
// ---------------------------------------------------------------------------
/// Rule that blocks dangerous shell commands.
pub struct ShellExecRule;
impl GuardrailRule for ShellExecRule {
fn name(&self) -> &str { "shell_exec_dangerous_commands" }
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
let cmd = tool_input["command"].as_str().unwrap_or("");
let dangerous = [
"rm -rf /",
"rm -rf ~",
"del /s /q C:\\",
"format ",
"mkfs.",
"dd if=",
":(){ :|:& };:", // fork bomb
"> /dev/sda",
"shutdown",
"reboot",
];
let cmd_lower = cmd.to_lowercase();
for pattern in &dangerous {
if cmd_lower.contains(pattern) {
return GuardrailVerdict::Block(format!(
"危险命令被安全护栏拦截: 包含 '{}'",
pattern
));
}
}
GuardrailVerdict::Allow
}
}
/// Rule that blocks writes to critical system directories.
pub struct FileWriteRule;
impl GuardrailRule for FileWriteRule {
fn name(&self) -> &str { "file_write_critical_dirs" }
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
let path = tool_input["path"].as_str().unwrap_or("");
let critical_prefixes = [
"/etc/",
"/usr/",
"/bin/",
"/sbin/",
"/boot/",
"/System/",
"/Library/",
"C:\\Windows\\",
"C:\\Program Files\\",
"C:\\ProgramData\\",
];
let path_lower = path.to_lowercase();
for prefix in &critical_prefixes {
if path_lower.starts_with(&prefix.to_lowercase()) {
return GuardrailVerdict::Block(format!(
"写入系统关键目录被拦截: {}",
path
));
}
}
GuardrailVerdict::Allow
}
}
/// Rule that blocks web requests to internal/private network addresses.
pub struct WebFetchRule;
impl GuardrailRule for WebFetchRule {
fn name(&self) -> &str { "web_fetch_private_network" }
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
let url = tool_input["url"].as_str().unwrap_or("");
let blocked = [
"localhost",
"127.0.0.1",
"0.0.0.0",
"10.",
"172.16.",
"172.17.",
"172.18.",
"172.19.",
"172.20.",
"172.21.",
"172.22.",
"172.23.",
"172.24.",
"172.25.",
"172.26.",
"172.27.",
"172.28.",
"172.29.",
"172.30.",
"172.31.",
"192.168.",
"::1",
"169.254.",
"metadata.google",
"metadata.azure",
];
let url_lower = url.to_lowercase();
for prefix in &blocked {
if url_lower.contains(prefix) {
return GuardrailVerdict::Block(format!(
"请求内网/私有地址被拦截: {}",
url
));
}
}
GuardrailVerdict::Allow
}
}

View File

@@ -0,0 +1,57 @@
//! Loop guard middleware — extracts loop detection into a middleware hook.
use async_trait::async_trait;
use serde_json::Value;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
use std::sync::Mutex;
/// Middleware that detects and blocks repetitive tool-call loops.
pub struct LoopGuardMiddleware {
guard: Mutex<LoopGuard>,
}
impl LoopGuardMiddleware {
pub fn new(config: LoopGuardConfig) -> Self {
Self {
guard: Mutex::new(LoopGuard::new(config)),
}
}
pub fn with_defaults() -> Self {
Self {
guard: Mutex::new(LoopGuard::default()),
}
}
}
#[async_trait]
impl AgentMiddleware for LoopGuardMiddleware {
fn name(&self) -> &str { "loop_guard" }
fn priority(&self) -> i32 { 500 }
async fn before_tool_call(
&self,
_ctx: &MiddlewareContext,
tool_name: &str,
tool_input: &Value,
) -> Result<ToolCallDecision> {
let result = self.guard.lock().unwrap().check(tool_name, tool_input);
match result {
LoopGuardResult::CircuitBreaker => {
tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string()))
}
LoopGuardResult::Blocked => {
tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string()))
}
LoopGuardResult::Warn => {
tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
Ok(ToolCallDecision::Allow)
}
LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow),
}
}
}

View File

@@ -0,0 +1,115 @@
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
//!
//! This middleware unifies the memory lifecycle:
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
//! - `after_completion`: extracts learnings from the conversation and stores them
//!
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
//! `intelligence_hooks` calls in the Tauri desktop layer.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::growth::GrowthIntegration;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
///
/// Wraps `GrowthIntegration` and delegates:
/// - `before_completion` → `enhance_prompt()` for memory injection
/// - `after_completion` → `process_conversation()` for memory extraction
pub struct MemoryMiddleware {
growth: GrowthIntegration,
/// Minimum seconds between extractions for the same agent (debounce).
debounce_secs: u64,
/// Timestamp of last extraction per agent (for debouncing).
last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
}
impl MemoryMiddleware {
pub fn new(growth: GrowthIntegration) -> Self {
Self {
growth,
debounce_secs: 30,
last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
/// Set the debounce interval in seconds.
pub fn with_debounce_secs(mut self, secs: u64) -> Self {
self.debounce_secs = secs;
self
}
/// Check if enough time has passed since the last extraction for this agent.
fn should_extract(&self, agent_id: &str) -> bool {
let now = std::time::Instant::now();
let mut map = self.last_extraction.lock().unwrap();
if let Some(last) = map.get(agent_id) {
if now.duration_since(*last).as_secs() < self.debounce_secs {
return false;
}
}
map.insert(agent_id.to_string(), now);
true
}
}
#[async_trait]
impl AgentMiddleware for MemoryMiddleware {
fn name(&self) -> &str { "memory" }
fn priority(&self) -> i32 { 150 }
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
match self.growth.enhance_prompt(
&ctx.agent_id,
&ctx.system_prompt,
&ctx.user_input,
).await {
Ok(enhanced) => {
ctx.system_prompt = enhanced;
Ok(MiddlewareDecision::Continue)
}
Err(e) => {
// Non-fatal: memory retrieval failure should not block the loop
tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
Ok(MiddlewareDecision::Continue)
}
}
}
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
// Debounce: skip extraction if called too recently for this agent
let agent_key = ctx.agent_id.to_string();
if !self.should_extract(&agent_key) {
tracing::debug!(
"[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
agent_key
);
return Ok(());
}
if ctx.messages.is_empty() {
return Ok(());
}
match self.growth.process_conversation(
&ctx.agent_id,
&ctx.messages,
ctx.session_id.clone(),
).await {
Ok(count) => {
tracing::info!(
"[MemoryMiddleware] Extracted {} memories for agent {}",
count,
agent_key
);
}
Err(e) => {
// Non-fatal: extraction failure should not affect the response
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
}
}
Ok(())
}
}

View File

@@ -0,0 +1,62 @@
//! Skill index middleware — injects a lightweight skill index into the system prompt.
//!
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::tool::{SkillIndexEntry, SkillExecutor};
use std::sync::Arc;
/// Middleware that injects a lightweight skill index into the system prompt.
///
/// The index format is compact:
/// ```text
/// ## Skills (index — use skill_load for details)
/// - finance-tracker: 财务分析、财报解读 [数据分析]
/// - senior-developer: 代码开发、架构设计 [开发工程]
/// ```
pub struct SkillIndexMiddleware {
/// Pre-built skill index entries, constructed at chain creation time.
entries: Vec<SkillIndexEntry>,
}
impl SkillIndexMiddleware {
pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
Self { entries }
}
/// Build index entries from a skill executor that supports listing.
pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
Self {
entries: executor.list_skill_index(),
}
}
}
#[async_trait]
impl AgentMiddleware for SkillIndexMiddleware {
fn name(&self) -> &str { "skill_index" }
fn priority(&self) -> i32 { 200 }
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
if self.entries.is_empty() {
return Ok(MiddlewareDecision::Continue);
}
let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
for entry in &self.entries {
let triggers = if entry.triggers.is_empty() {
String::new()
} else {
format!("{}", entry.triggers.join(", "))
};
index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
}
ctx.system_prompt.push_str(&index);
Ok(MiddlewareDecision::Continue)
}
}

View File

@@ -0,0 +1,52 @@
//! Token calibration middleware — calibrates token estimation after first LLM response.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::compaction;
/// Middleware that calibrates the global token estimation factor based on
/// actual API-returned token counts from the first LLM response.
pub struct TokenCalibrationMiddleware {
/// Whether calibration has already been applied in this session.
calibrated: std::sync::atomic::AtomicBool,
}
impl TokenCalibrationMiddleware {
pub fn new() -> Self {
Self {
calibrated: std::sync::atomic::AtomicBool::new(false),
}
}
}
impl Default for TokenCalibrationMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AgentMiddleware for TokenCalibrationMiddleware {
fn name(&self) -> &str { "token_calibration" }
fn priority(&self) -> i32 { 700 }
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
// Calibration happens in after_completion when we have actual token counts.
// Before-completion is a no-op.
Ok(MiddlewareDecision::Continue)
}
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) {
let estimated = compaction::estimate_messages_tokens(&ctx.messages);
compaction::update_calibration(estimated, ctx.input_tokens);
self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed);
tracing::debug!(
"[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
estimated, ctx.input_tokens
);
}
Ok(())
}
}

View File

@@ -37,6 +37,39 @@ pub trait SkillExecutor: Send + Sync {
session_id: &str,
input: Value,
) -> Result<Value>;
/// Return metadata for on-demand skill loading.
/// Default returns `None` (skill detail not available).
fn get_skill_detail(&self, skill_id: &str) -> Option<SkillDetail> {
let _ = skill_id;
None
}
/// Return lightweight index of all available skills.
/// Default returns empty (no index available).
fn list_skill_index(&self) -> Vec<SkillIndexEntry> {
Vec::new()
}
}
/// Lightweight skill index entry for system prompt injection.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SkillIndexEntry {
pub id: String,
pub description: String,
pub triggers: Vec<String>,
}
/// Full skill detail returned by `skill_load` tool.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SkillDetail {
pub id: String,
pub name: String,
pub description: String,
pub category: Option<String>,
pub input_schema: Option<Value>,
pub triggers: Vec<String>,
pub capabilities: Vec<String>,
}
/// Context provided to tool execution

View File

@@ -5,6 +5,7 @@ mod file_write;
mod shell_exec;
mod web_fetch;
mod execute_skill;
mod skill_load;
mod path_validator;
pub use file_read::FileReadTool;
@@ -12,6 +13,7 @@ pub use file_write::FileWriteTool;
pub use shell_exec::ShellExecTool;
pub use web_fetch::WebFetchTool;
pub use execute_skill::ExecuteSkillTool;
pub use skill_load::SkillLoadTool;
pub use path_validator::{PathValidator, PathValidatorConfig};
use crate::tool::ToolRegistry;
@@ -23,4 +25,5 @@ pub fn register_builtin_tools(registry: &mut ToolRegistry) {
registry.register(Box::new(ShellExecTool::new()));
registry.register(Box::new(WebFetchTool::new()));
registry.register(Box::new(ExecuteSkillTool::new()));
registry.register(Box::new(SkillLoadTool::new()));
}

View File

@@ -0,0 +1,81 @@
//! Skill load tool — on-demand retrieval of full skill details.
//!
//! When the `SkillIndexMiddleware` is active, the system prompt contains only a lightweight
//! skill index. This tool allows the LLM to load full skill details (description, input schema,
//! capabilities) on demand, exactly when the LLM decides a particular skill is relevant.
use async_trait::async_trait;
use serde_json::{json, Value};
use zclaw_types::{Result, ZclawError};
use crate::tool::{Tool, ToolContext};
pub struct SkillLoadTool;
impl SkillLoadTool {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl Tool for SkillLoadTool {
fn name(&self) -> &str {
"skill_load"
}
fn description(&self) -> &str {
"Load full details for a skill by its ID. Use this when you need to understand a skill's \
input parameters, capabilities, or usage instructions before calling execute_skill. \
Returns the skill description, input schema, and trigger conditions."
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"skill_id": {
"type": "string",
"description": "The ID of the skill to load details for"
}
},
"required": ["skill_id"]
})
}
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
let skill_id = input["skill_id"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
let executor = context.skill_executor.as_ref()
.ok_or_else(|| ZclawError::ToolError("Skill executor not available".into()))?;
match executor.get_skill_detail(skill_id) {
Some(detail) => {
let mut result = json!({
"id": detail.id,
"name": detail.name,
"description": detail.description,
"triggers": detail.triggers,
});
if let Some(schema) = &detail.input_schema {
result["input_schema"] = schema.clone();
}
if let Some(cat) = &detail.category {
result["category"] = json!(cat);
}
if !detail.capabilities.is_empty() {
result["capabilities"] = json!(detail.capabilities);
}
Ok(result)
}
None => Err(ZclawError::ToolError(format!("Skill not found: {}", skill_id))),
}
}
}
impl Default for SkillLoadTool {
fn default() -> Self {
Self::new()
}
}

View File

@@ -133,6 +133,14 @@ impl SkillRegistry {
manifests.values().cloned().collect()
}
/// Synchronous snapshot of all manifests.
/// Uses `try_read` — returns empty map if write lock is held (should be rare at steady state).
pub fn manifests_snapshot(&self) -> HashMap<SkillId, SkillManifest> {
self.manifests.try_read()
.map(|guard| guard.clone())
.unwrap_or_default()
}
/// Execute a skill
pub async fn execute(
&self,