feat(runtime): DeerFlow 模式中间件链 Phase 1-4 全部完成
借鉴 DeerFlow 架构,实现完整中间件链系统: Phase 1 - Agent 中间件链基础设施 - MiddlewareChain Clone 支持 - LoopRunner 双路径集成 (middleware/legacy) - Kernel create_middleware_chain() 工厂方法 Phase 2 - 技能按需注入 - SkillIndexMiddleware (priority 200) - SkillLoadTool 工具 - SkillDetail/SkillIndexEntry 结构体 - KernelSkillExecutor trait 扩展 Phase 3 - Guardrail 安全护栏 - GuardrailMiddleware (priority 400, fail_open) - ShellExecRule / FileWriteRule / WebFetchRule Phase 4 - 记忆闭环统一 - MemoryMiddleware (priority 150, 30s 防抖) - after_completion 双路径调用 中间件注册顺序: 100 Compaction | 150 Memory | 200 SkillIndex 400 Guardrail | 500 LoopGuard | 700 TokenCalibration 向后兼容:Option<MiddlewareChain> 默认 None 走旧路径
This commit is contained in:
@@ -27,7 +27,7 @@ pub struct SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Database row structure for memory entry
|
/// Database row structure for memory entry
|
||||||
struct MemoryRow {
|
pub(crate) struct MemoryRow {
|
||||||
uri: String,
|
uri: String,
|
||||||
memory_type: String,
|
memory_type: String,
|
||||||
content: String,
|
content: String,
|
||||||
|
|||||||
@@ -86,6 +86,32 @@ impl SkillExecutor for KernelSkillExecutor {
|
|||||||
let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?;
|
let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?;
|
||||||
Ok(result.output)
|
Ok(result.output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_skill_detail(&self, skill_id: &str) -> Option<zclaw_runtime::tool::SkillDetail> {
|
||||||
|
let manifests = self.skills.manifests_snapshot();
|
||||||
|
let manifest = manifests.get(&zclaw_types::SkillId::new(skill_id))?;
|
||||||
|
Some(zclaw_runtime::tool::SkillDetail {
|
||||||
|
id: manifest.id.as_str().to_string(),
|
||||||
|
name: manifest.name.clone(),
|
||||||
|
description: manifest.description.clone(),
|
||||||
|
category: manifest.category.clone(),
|
||||||
|
input_schema: manifest.input_schema.clone(),
|
||||||
|
triggers: manifest.triggers.clone(),
|
||||||
|
capabilities: manifest.capabilities.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_skill_index(&self) -> Vec<zclaw_runtime::tool::SkillIndexEntry> {
|
||||||
|
let manifests = self.skills.manifests_snapshot();
|
||||||
|
manifests.values()
|
||||||
|
.filter(|m| m.enabled)
|
||||||
|
.map(|m| zclaw_runtime::tool::SkillIndexEntry {
|
||||||
|
id: m.id.as_str().to_string(),
|
||||||
|
description: m.description.clone(),
|
||||||
|
triggers: m.triggers.clone(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The ZCLAW Kernel
|
/// The ZCLAW Kernel
|
||||||
@@ -205,6 +231,68 @@ impl Kernel {
|
|||||||
tools
|
tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create the middleware chain for the agent loop.
|
||||||
|
///
|
||||||
|
/// When middleware is configured, cross-cutting concerns (compaction, loop guard,
|
||||||
|
/// token calibration, etc.) are delegated to the chain. When no middleware is
|
||||||
|
/// registered, the legacy inline path in `AgentLoop` is used instead.
|
||||||
|
fn create_middleware_chain(&self) -> Option<zclaw_runtime::middleware::MiddlewareChain> {
|
||||||
|
let mut chain = zclaw_runtime::middleware::MiddlewareChain::new();
|
||||||
|
|
||||||
|
// Compaction middleware — only register when threshold > 0
|
||||||
|
let threshold = self.config.compaction_threshold();
|
||||||
|
if threshold > 0 {
|
||||||
|
use std::sync::Arc;
|
||||||
|
let mw = zclaw_runtime::middleware::compaction::CompactionMiddleware::new(
|
||||||
|
threshold,
|
||||||
|
zclaw_runtime::CompactionConfig::default(),
|
||||||
|
Some(self.driver.clone()),
|
||||||
|
None, // growth not wired in kernel yet
|
||||||
|
);
|
||||||
|
chain.register(Arc::new(mw));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop guard middleware
|
||||||
|
{
|
||||||
|
use std::sync::Arc;
|
||||||
|
let mw = zclaw_runtime::middleware::loop_guard::LoopGuardMiddleware::with_defaults();
|
||||||
|
chain.register(Arc::new(mw));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token calibration middleware
|
||||||
|
{
|
||||||
|
use std::sync::Arc;
|
||||||
|
let mw = zclaw_runtime::middleware::token_calibration::TokenCalibrationMiddleware::new();
|
||||||
|
chain.register(Arc::new(mw));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skill index middleware — inject lightweight index instead of full descriptions
|
||||||
|
{
|
||||||
|
use std::sync::Arc;
|
||||||
|
let entries = self.skill_executor.list_skill_index();
|
||||||
|
if !entries.is_empty() {
|
||||||
|
let mw = zclaw_runtime::middleware::skill_index::SkillIndexMiddleware::new(entries);
|
||||||
|
chain.register(Arc::new(mw));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Guardrail middleware — safety rules for tool calls
|
||||||
|
{
|
||||||
|
use std::sync::Arc;
|
||||||
|
let mw = zclaw_runtime::middleware::guardrail::GuardrailMiddleware::new(true)
|
||||||
|
.with_builtin_rules();
|
||||||
|
chain.register(Arc::new(mw));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return Some if we actually registered middleware
|
||||||
|
if chain.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
tracing::info!("[Kernel] Middleware chain created with {} middlewares", chain.len());
|
||||||
|
Some(chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build a system prompt with skill information injected
|
/// Build a system prompt with skill information injected
|
||||||
async fn build_system_prompt_with_skills(&self, base_prompt: Option<&String>) -> String {
|
async fn build_system_prompt_with_skills(&self, base_prompt: Option<&String>) -> String {
|
||||||
// Get skill list asynchronously
|
// Get skill list asynchronously
|
||||||
@@ -417,6 +505,11 @@ impl Kernel {
|
|||||||
loop_runner = loop_runner.with_path_validator(path_validator);
|
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inject middleware chain if available
|
||||||
|
if let Some(chain) = self.create_middleware_chain() {
|
||||||
|
loop_runner = loop_runner.with_middleware_chain(chain);
|
||||||
|
}
|
||||||
|
|
||||||
// Build system prompt with skill information injected
|
// Build system prompt with skill information injected
|
||||||
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
|
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
|
||||||
let loop_runner = loop_runner.with_system_prompt(&system_prompt);
|
let loop_runner = loop_runner.with_system_prompt(&system_prompt);
|
||||||
@@ -501,6 +594,11 @@ impl Kernel {
|
|||||||
loop_runner = loop_runner.with_path_validator(path_validator);
|
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inject middleware chain if available
|
||||||
|
if let Some(chain) = self.create_middleware_chain() {
|
||||||
|
loop_runner = loop_runner.with_middleware_chain(chain);
|
||||||
|
}
|
||||||
|
|
||||||
// Use external prompt if provided, otherwise build default
|
// Use external prompt if provided, otherwise build default
|
||||||
let system_prompt = match system_prompt_override {
|
let system_prompt = match system_prompt_override {
|
||||||
Some(prompt) => prompt,
|
Some(prompt) => prompt,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ pub mod loop_guard;
|
|||||||
pub mod stream;
|
pub mod stream;
|
||||||
pub mod growth;
|
pub mod growth;
|
||||||
pub mod compaction;
|
pub mod compaction;
|
||||||
|
pub mod middleware;
|
||||||
|
|
||||||
// Re-export main types
|
// Re-export main types
|
||||||
pub use driver::{
|
pub use driver::{
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use crate::tool::builtin::PathValidator;
|
|||||||
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
||||||
use crate::growth::GrowthIntegration;
|
use crate::growth::GrowthIntegration;
|
||||||
use crate::compaction::{self, CompactionConfig};
|
use crate::compaction::{self, CompactionConfig};
|
||||||
|
use crate::middleware::{self, MiddlewareChain};
|
||||||
use zclaw_memory::MemoryStore;
|
use zclaw_memory::MemoryStore;
|
||||||
|
|
||||||
/// Agent loop runner
|
/// Agent loop runner
|
||||||
@@ -34,6 +35,10 @@ pub struct AgentLoop {
|
|||||||
compaction_threshold: usize,
|
compaction_threshold: usize,
|
||||||
/// Compaction behavior configuration
|
/// Compaction behavior configuration
|
||||||
compaction_config: CompactionConfig,
|
compaction_config: CompactionConfig,
|
||||||
|
/// Optional middleware chain — when `Some`, cross-cutting logic is
|
||||||
|
/// delegated to the chain instead of the inline code below.
|
||||||
|
/// When `None`, the legacy inline path is used (100% backward compatible).
|
||||||
|
middleware_chain: Option<MiddlewareChain>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
@@ -58,6 +63,7 @@ impl AgentLoop {
|
|||||||
growth: None,
|
growth: None,
|
||||||
compaction_threshold: 0,
|
compaction_threshold: 0,
|
||||||
compaction_config: CompactionConfig::default(),
|
compaction_config: CompactionConfig::default(),
|
||||||
|
middleware_chain: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,6 +130,14 @@ impl AgentLoop {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inject a middleware chain. When set, cross-cutting concerns (compaction,
|
||||||
|
/// loop guard, token calibration, etc.) are delegated to the chain instead
|
||||||
|
/// of the inline logic.
|
||||||
|
pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
|
||||||
|
self.middleware_chain = Some(chain);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Get growth integration reference
|
/// Get growth integration reference
|
||||||
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
||||||
self.growth.as_ref()
|
self.growth.as_ref()
|
||||||
@@ -175,8 +189,10 @@ impl AgentLoop {
|
|||||||
// Get all messages for context
|
// Get all messages for context
|
||||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||||
|
|
||||||
// Apply compaction if threshold is configured
|
let use_middleware = self.middleware_chain.is_some();
|
||||||
if self.compaction_threshold > 0 {
|
|
||||||
|
// Apply compaction — skip inline path when middleware chain handles it
|
||||||
|
if !use_middleware && self.compaction_threshold > 0 {
|
||||||
let needs_async =
|
let needs_async =
|
||||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||||
if needs_async {
|
if needs_async {
|
||||||
@@ -196,14 +212,44 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enhance system prompt with growth memories
|
// Enhance system prompt — skip when middleware chain handles it
|
||||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
let mut enhanced_prompt = if use_middleware {
|
||||||
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
|
} else if let Some(ref growth) = self.growth {
|
||||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||||
} else {
|
} else {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||||
|
if let Some(ref chain) = self.middleware_chain {
|
||||||
|
let mut mw_ctx = middleware::MiddlewareContext {
|
||||||
|
agent_id: self.agent_id.clone(),
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
user_input: input.clone(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages,
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
};
|
||||||
|
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||||
|
middleware::MiddlewareDecision::Continue => {
|
||||||
|
messages = mw_ctx.messages;
|
||||||
|
enhanced_prompt = mw_ctx.system_prompt;
|
||||||
|
}
|
||||||
|
middleware::MiddlewareDecision::Stop(reason) => {
|
||||||
|
return Ok(AgentLoopResult {
|
||||||
|
response: reason,
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
iterations: 1,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let max_iterations = 10;
|
let max_iterations = 10;
|
||||||
let mut iterations = 0;
|
let mut iterations = 0;
|
||||||
let mut total_input_tokens = 0u32;
|
let mut total_input_tokens = 0u32;
|
||||||
@@ -307,24 +353,56 @@ impl AgentLoop {
|
|||||||
let tool_context = self.create_tool_context(session_id.clone());
|
let tool_context = self.create_tool_context(session_id.clone());
|
||||||
let mut circuit_breaker_triggered = false;
|
let mut circuit_breaker_triggered = false;
|
||||||
for (id, name, input) in tool_calls {
|
for (id, name, input) in tool_calls {
|
||||||
// Check loop guard before executing tool
|
// Check tool call safety — via middleware chain or inline loop guard
|
||||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
if let Some(ref chain) = self.middleware_chain {
|
||||||
match guard_result {
|
let mw_ctx_ref = middleware::MiddlewareContext {
|
||||||
LoopGuardResult::CircuitBreaker => {
|
agent_id: self.agent_id.clone(),
|
||||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
session_id: session_id.clone(),
|
||||||
circuit_breaker_triggered = true;
|
user_input: input.to_string(),
|
||||||
break;
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages: messages.clone(),
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: total_input_tokens,
|
||||||
|
output_tokens: total_output_tokens,
|
||||||
|
};
|
||||||
|
match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
||||||
|
middleware::ToolCallDecision::Allow => {}
|
||||||
|
middleware::ToolCallDecision::Block(msg) => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||||
|
let error_output = serde_json::json!({ "error": msg });
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||||
|
// Execute with replaced input
|
||||||
|
let tool_result = match self.execute_tool(&name, new_input, &tool_context).await {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) => serde_json::json!({ "error": e.to_string() }),
|
||||||
|
};
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LoopGuardResult::Blocked => {
|
} else {
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
// Legacy inline path
|
||||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
match guard_result {
|
||||||
continue;
|
LoopGuardResult::CircuitBreaker => {
|
||||||
|
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||||
|
circuit_breaker_triggered = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
LoopGuardResult::Blocked => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||||
|
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LoopGuardResult::Warn => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||||
|
}
|
||||||
|
LoopGuardResult::Allowed => {}
|
||||||
}
|
}
|
||||||
LoopGuardResult::Warn => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
|
||||||
}
|
|
||||||
LoopGuardResult::Allowed => {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
||||||
@@ -356,8 +434,23 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Process conversation for memory extraction (post-conversation)
|
// Post-completion processing — middleware chain or inline growth
|
||||||
if let Some(ref growth) = self.growth {
|
if let Some(ref chain) = self.middleware_chain {
|
||||||
|
let mw_ctx = middleware::MiddlewareContext {
|
||||||
|
agent_id: self.agent_id.clone(),
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
user_input: input.clone(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages: self.memory.get_messages(&session_id).await.unwrap_or_default(),
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: total_input_tokens,
|
||||||
|
output_tokens: total_output_tokens,
|
||||||
|
};
|
||||||
|
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||||
|
tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e);
|
||||||
|
}
|
||||||
|
} else if let Some(ref growth) = self.growth {
|
||||||
|
// Legacy inline path
|
||||||
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
||||||
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
||||||
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
||||||
@@ -384,8 +477,10 @@ impl AgentLoop {
|
|||||||
// Get all messages for context
|
// Get all messages for context
|
||||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||||
|
|
||||||
// Apply compaction if threshold is configured
|
let use_middleware = self.middleware_chain.is_some();
|
||||||
if self.compaction_threshold > 0 {
|
|
||||||
|
// Apply compaction — skip inline path when middleware chain handles it
|
||||||
|
if !use_middleware && self.compaction_threshold > 0 {
|
||||||
let needs_async =
|
let needs_async =
|
||||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||||
if needs_async {
|
if needs_async {
|
||||||
@@ -405,20 +500,52 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enhance system prompt with growth memories
|
// Enhance system prompt — skip when middleware chain handles it
|
||||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
let mut enhanced_prompt = if use_middleware {
|
||||||
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
|
} else if let Some(ref growth) = self.growth {
|
||||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||||
} else {
|
} else {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||||
|
if let Some(ref chain) = self.middleware_chain {
|
||||||
|
let mut mw_ctx = middleware::MiddlewareContext {
|
||||||
|
agent_id: self.agent_id.clone(),
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
user_input: input.clone(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages,
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
};
|
||||||
|
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||||
|
middleware::MiddlewareDecision::Continue => {
|
||||||
|
messages = mw_ctx.messages;
|
||||||
|
enhanced_prompt = mw_ctx.system_prompt;
|
||||||
|
}
|
||||||
|
middleware::MiddlewareDecision::Stop(reason) => {
|
||||||
|
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||||
|
response: reason,
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
iterations: 1,
|
||||||
|
})).await;
|
||||||
|
return Ok(rx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clone necessary data for the async task
|
// Clone necessary data for the async task
|
||||||
let session_id_clone = session_id.clone();
|
let session_id_clone = session_id.clone();
|
||||||
let memory = self.memory.clone();
|
let memory = self.memory.clone();
|
||||||
let driver = self.driver.clone();
|
let driver = self.driver.clone();
|
||||||
let tools = self.tools.clone();
|
let tools = self.tools.clone();
|
||||||
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
||||||
|
let middleware_chain = self.middleware_chain.clone();
|
||||||
let skill_executor = self.skill_executor.clone();
|
let skill_executor = self.skill_executor.clone();
|
||||||
let path_validator = self.path_validator.clone();
|
let path_validator = self.path_validator.clone();
|
||||||
let agent_id = self.agent_id.clone();
|
let agent_id = self.agent_id.clone();
|
||||||
@@ -558,6 +685,24 @@ impl AgentLoop {
|
|||||||
output_tokens: total_output_tokens,
|
output_tokens: total_output_tokens,
|
||||||
iterations: iteration,
|
iterations: iteration,
|
||||||
})).await;
|
})).await;
|
||||||
|
|
||||||
|
// Post-completion: middleware after_completion (memory extraction, etc.)
|
||||||
|
if let Some(ref chain) = middleware_chain {
|
||||||
|
let mw_ctx = middleware::MiddlewareContext {
|
||||||
|
agent_id: agent_id.clone(),
|
||||||
|
session_id: session_id_clone.clone(),
|
||||||
|
user_input: String::new(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(),
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: total_input_tokens,
|
||||||
|
output_tokens: total_output_tokens,
|
||||||
|
};
|
||||||
|
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||||
|
tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
break 'outer;
|
break 'outer;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -579,24 +724,92 @@ impl AgentLoop {
|
|||||||
for (id, name, input) in pending_tool_calls {
|
for (id, name, input) in pending_tool_calls {
|
||||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
||||||
|
|
||||||
// Check loop guard before executing tool
|
// Check tool call safety — via middleware chain or inline loop guard
|
||||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
if let Some(ref chain) = middleware_chain {
|
||||||
match guard_result {
|
let mw_ctx = middleware::MiddlewareContext {
|
||||||
LoopGuardResult::CircuitBreaker => {
|
agent_id: agent_id.clone(),
|
||||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
session_id: session_id_clone.clone(),
|
||||||
break 'outer;
|
user_input: input.to_string(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages: messages.clone(),
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: total_input_tokens,
|
||||||
|
output_tokens: total_output_tokens,
|
||||||
|
};
|
||||||
|
match chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
||||||
|
Ok(middleware::ToolCallDecision::Allow) => {}
|
||||||
|
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||||
|
let error_output = serde_json::json!({ "error": msg });
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
||||||
|
// Execute with replaced input (same path_validator logic below)
|
||||||
|
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||||
|
let home = std::env::var("USERPROFILE")
|
||||||
|
.or_else(|_| std::env::var("HOME"))
|
||||||
|
.unwrap_or_else(|_| ".".to_string());
|
||||||
|
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
||||||
|
});
|
||||||
|
let working_dir = pv.workspace_root()
|
||||||
|
.map(|p| p.to_string_lossy().to_string());
|
||||||
|
let tool_context = ToolContext {
|
||||||
|
agent_id: agent_id.clone(),
|
||||||
|
working_directory: working_dir,
|
||||||
|
session_id: Some(session_id_clone.to_string()),
|
||||||
|
skill_executor: skill_executor.clone(),
|
||||||
|
path_validator: Some(pv),
|
||||||
|
};
|
||||||
|
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||||
|
match tool.execute(new_input, &tool_context).await {
|
||||||
|
Ok(output) => {
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
|
||||||
|
(output, false)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
|
(error_output, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
|
(error_output, true)
|
||||||
|
};
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
||||||
|
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LoopGuardResult::Blocked => {
|
} else {
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
// Legacy inline loop guard path
|
||||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
match guard_result {
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
LoopGuardResult::CircuitBreaker => {
|
||||||
continue;
|
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
LoopGuardResult::Blocked => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||||
|
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||||
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LoopGuardResult::Warn => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||||
|
}
|
||||||
|
LoopGuardResult::Allowed => {}
|
||||||
}
|
}
|
||||||
LoopGuardResult::Warn => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
|
||||||
}
|
|
||||||
LoopGuardResult::Allowed => {}
|
|
||||||
}
|
}
|
||||||
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
||||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||||
|
|||||||
252
crates/zclaw-runtime/src/middleware.rs
Normal file
252
crates/zclaw-runtime/src/middleware.rs
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
//! Agent middleware system — composable hooks for cross-cutting concerns.
|
||||||
|
//!
|
||||||
|
//! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain,
|
||||||
|
//! this module provides a standardised way to inject behaviour before/after LLM completions
|
||||||
|
//! and tool calls without modifying the core `AgentLoop` logic.
|
||||||
|
//!
|
||||||
|
//! # Priority convention
|
||||||
|
//!
|
||||||
|
//! | Range | Category | Example |
|
||||||
|
//! |---------|----------------|-----------------------------|
|
||||||
|
//! | 100-199 | Context shaping| Compaction, MemoryInject |
|
||||||
|
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
||||||
|
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
||||||
|
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use zclaw_types::{AgentId, Result, SessionId};
|
||||||
|
use crate::driver::ContentBlock;
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Decisions returned by middleware hooks
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Decision returned by `before_completion`.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum MiddlewareDecision {
|
||||||
|
/// Continue to the next middleware / proceed with the LLM call.
|
||||||
|
Continue,
|
||||||
|
/// Abort the agent loop and return *reason* to the caller.
|
||||||
|
Stop(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decision returned by `before_tool_call`.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ToolCallDecision {
|
||||||
|
/// Allow the tool call to proceed unchanged.
|
||||||
|
Allow,
|
||||||
|
/// Block the call and return *message* as a tool-error to the LLM.
|
||||||
|
Block(String),
|
||||||
|
/// Allow the call but replace the tool input with *new_input*.
|
||||||
|
ReplaceInput(Value),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Middleware context — shared mutable state passed through the chain
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Carries the mutable state that middleware may inspect or modify.
|
||||||
|
pub struct MiddlewareContext {
|
||||||
|
/// The agent that owns this loop.
|
||||||
|
pub agent_id: AgentId,
|
||||||
|
/// Current session.
|
||||||
|
pub session_id: SessionId,
|
||||||
|
/// The raw user input that started this turn.
|
||||||
|
pub user_input: String,
|
||||||
|
|
||||||
|
// -- mutable state -------------------------------------------------------
|
||||||
|
/// System prompt — middleware may prepend/append context.
|
||||||
|
pub system_prompt: String,
|
||||||
|
/// Conversation messages sent to the LLM.
|
||||||
|
pub messages: Vec<zclaw_types::Message>,
|
||||||
|
/// Accumulated LLM content blocks from the current response.
|
||||||
|
pub response_content: Vec<ContentBlock>,
|
||||||
|
/// Token usage reported by the LLM driver (updated after each call).
|
||||||
|
pub input_tokens: u32,
|
||||||
|
pub output_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for MiddlewareContext {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("MiddlewareContext")
|
||||||
|
.field("agent_id", &self.agent_id)
|
||||||
|
.field("session_id", &self.session_id)
|
||||||
|
.field("messages", &self.messages.len())
|
||||||
|
.field("input_tokens", &self.input_tokens)
|
||||||
|
.field("output_tokens", &self.output_tokens)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Core trait
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A composable middleware hook for the agent loop.
|
||||||
|
///
|
||||||
|
/// Each middleware focuses on one cross-cutting concern and is executed
|
||||||
|
/// in `priority` order (ascending). All hook methods have default no-op
|
||||||
|
/// implementations so implementors only override what they need.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait AgentMiddleware: Send + Sync {
|
||||||
|
/// Human-readable name for logging / debugging.
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Execution priority — lower values run first.
|
||||||
|
fn priority(&self) -> i32 {
|
||||||
|
500
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hook executed **before** the LLM completion request is sent.
|
||||||
|
///
|
||||||
|
/// Use this to inject context (memory, skill index, etc.) or to
|
||||||
|
/// trigger pre-processing (compaction, summarisation).
|
||||||
|
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hook executed **before** each tool call.
|
||||||
|
///
|
||||||
|
/// Return `Block` to prevent execution and feed an error back to
|
||||||
|
/// the LLM, or `ReplaceInput` to sanitise / modify the arguments.
|
||||||
|
async fn before_tool_call(
|
||||||
|
&self,
|
||||||
|
_ctx: &MiddlewareContext,
|
||||||
|
_tool_name: &str,
|
||||||
|
_tool_input: &Value,
|
||||||
|
) -> Result<ToolCallDecision> {
|
||||||
|
Ok(ToolCallDecision::Allow)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hook executed **after** each tool call.
|
||||||
|
async fn after_tool_call(
|
||||||
|
&self,
|
||||||
|
_ctx: &mut MiddlewareContext,
|
||||||
|
_tool_name: &str,
|
||||||
|
_result: &Value,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hook executed **after** the entire agent loop turn completes.
|
||||||
|
///
|
||||||
|
/// Use this for post-processing (memory extraction, telemetry, etc.).
|
||||||
|
async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Middleware chain — ordered collection with run methods
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// An ordered chain of `AgentMiddleware` instances.
|
||||||
|
pub struct MiddlewareChain {
|
||||||
|
middlewares: Vec<Arc<dyn AgentMiddleware>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MiddlewareChain {
|
||||||
|
/// Create an empty chain.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self { middlewares: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a middleware. The chain is kept sorted by `priority`
|
||||||
|
/// (ascending) and by registration order within the same priority.
|
||||||
|
pub fn register(&mut self, mw: Arc<dyn AgentMiddleware>) {
|
||||||
|
let p = mw.priority();
|
||||||
|
let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len());
|
||||||
|
self.middlewares.insert(pos, mw);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run all `before_completion` hooks in order.
|
||||||
|
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
for mw in &self.middlewares {
|
||||||
|
match mw.before_completion(ctx).await? {
|
||||||
|
MiddlewareDecision::Continue => {}
|
||||||
|
MiddlewareDecision::Stop(reason) => {
|
||||||
|
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||||
|
return Ok(MiddlewareDecision::Stop(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run all `before_tool_call` hooks in order.
|
||||||
|
pub async fn run_before_tool_call(
|
||||||
|
&self,
|
||||||
|
ctx: &MiddlewareContext,
|
||||||
|
tool_name: &str,
|
||||||
|
tool_input: &Value,
|
||||||
|
) -> Result<ToolCallDecision> {
|
||||||
|
for mw in &self.middlewares {
|
||||||
|
match mw.before_tool_call(ctx, tool_name, tool_input).await? {
|
||||||
|
ToolCallDecision::Allow => {}
|
||||||
|
other => {
|
||||||
|
tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name);
|
||||||
|
return Ok(other);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ToolCallDecision::Allow)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run all `after_tool_call` hooks in order.
|
||||||
|
pub async fn run_after_tool_call(
|
||||||
|
&self,
|
||||||
|
ctx: &mut MiddlewareContext,
|
||||||
|
tool_name: &str,
|
||||||
|
result: &Value,
|
||||||
|
) -> Result<()> {
|
||||||
|
for mw in &self.middlewares {
|
||||||
|
mw.after_tool_call(ctx, tool_name, result).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run all `after_completion` hooks in order.
|
||||||
|
pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||||
|
for mw in &self.middlewares {
|
||||||
|
mw.after_completion(ctx).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of registered middlewares.
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.middlewares.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether the chain is empty.
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.middlewares.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for MiddlewareChain {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MiddlewareChain {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Sub-modules — concrete middleware implementations
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
pub mod compaction;
|
||||||
|
pub mod guardrail;
|
||||||
|
pub mod loop_guard;
|
||||||
|
pub mod memory;
|
||||||
|
pub mod skill_index;
|
||||||
|
pub mod token_calibration;
|
||||||
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
//! Compaction middleware — wraps the existing compaction module.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||||
|
use crate::compaction::{self, CompactionConfig};
|
||||||
|
use crate::growth::GrowthIntegration;
|
||||||
|
use crate::driver::LlmDriver;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Middleware that compresses conversation history when it exceeds a token threshold.
|
||||||
|
pub struct CompactionMiddleware {
|
||||||
|
threshold: usize,
|
||||||
|
config: CompactionConfig,
|
||||||
|
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
|
||||||
|
driver: Option<Arc<dyn LlmDriver>>,
|
||||||
|
/// Optional growth integration for memory flushing during compaction.
|
||||||
|
growth: Option<GrowthIntegration>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactionMiddleware {
|
||||||
|
pub fn new(
|
||||||
|
threshold: usize,
|
||||||
|
config: CompactionConfig,
|
||||||
|
driver: Option<Arc<dyn LlmDriver>>,
|
||||||
|
growth: Option<GrowthIntegration>,
|
||||||
|
) -> Self {
|
||||||
|
Self { threshold, config, driver, growth }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for CompactionMiddleware {
|
||||||
|
fn name(&self) -> &str { "compaction" }
|
||||||
|
fn priority(&self) -> i32 { 100 }
|
||||||
|
|
||||||
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
if self.threshold == 0 {
|
||||||
|
return Ok(MiddlewareDecision::Continue);
|
||||||
|
}
|
||||||
|
|
||||||
|
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
||||||
|
if needs_async {
|
||||||
|
let outcome = compaction::maybe_compact_with_config(
|
||||||
|
ctx.messages.clone(),
|
||||||
|
self.threshold,
|
||||||
|
&self.config,
|
||||||
|
&ctx.agent_id,
|
||||||
|
&ctx.session_id,
|
||||||
|
self.driver.as_ref(),
|
||||||
|
self.growth.as_ref(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
ctx.messages = outcome.messages;
|
||||||
|
} else {
|
||||||
|
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
}
|
||||||
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
//! Guardrail middleware — configurable safety rules for tool call evaluation.
|
||||||
|
//!
|
||||||
|
//! This middleware inspects tool calls before execution and can block or
|
||||||
|
//! modify them based on configurable rules. Inspired by DeerFlow's safety
|
||||||
|
//! evaluation hooks.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||||
|
|
||||||
|
/// A single guardrail rule that can inspect and decide on tool calls.
|
||||||
|
pub trait GuardrailRule: Send + Sync {
|
||||||
|
/// Human-readable name for logging.
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Evaluate a tool call.
|
||||||
|
fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decision returned by a guardrail rule.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum GuardrailVerdict {
|
||||||
|
/// Allow the tool call to proceed.
|
||||||
|
Allow,
|
||||||
|
/// Block the call and return *message* as an error to the LLM.
|
||||||
|
Block(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware that evaluates tool calls against a set of configurable safety rules.
|
||||||
|
///
|
||||||
|
/// Rules are grouped by tool name. When a tool call is made, all rules for
|
||||||
|
/// that tool are evaluated in order. If any rule returns `Block`, the call
|
||||||
|
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
|
||||||
|
/// a rule explicitly blocks them.
|
||||||
|
pub struct GuardrailMiddleware {
|
||||||
|
/// Rules keyed by tool name.
|
||||||
|
rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
|
||||||
|
/// Default policy for tools with no specific rules: true = allow, false = block.
|
||||||
|
fail_open: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GuardrailMiddleware {
|
||||||
|
pub fn new(fail_open: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
rules: HashMap::new(),
|
||||||
|
fail_open,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a guardrail rule for a specific tool.
|
||||||
|
pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
|
||||||
|
self.rules.entry(tool_name.into()).or_default().push(rule);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register built-in safety rules (shell_exec, file_write, web_fetch).
|
||||||
|
pub fn with_builtin_rules(mut self) -> Self {
|
||||||
|
self.add_rule("shell_exec", Box::new(ShellExecRule));
|
||||||
|
self.add_rule("file_write", Box::new(FileWriteRule));
|
||||||
|
self.add_rule("web_fetch", Box::new(WebFetchRule));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for GuardrailMiddleware {
|
||||||
|
fn name(&self) -> &str { "guardrail" }
|
||||||
|
fn priority(&self) -> i32 { 400 }
|
||||||
|
|
||||||
|
async fn before_tool_call(
|
||||||
|
&self,
|
||||||
|
_ctx: &MiddlewareContext,
|
||||||
|
tool_name: &str,
|
||||||
|
tool_input: &Value,
|
||||||
|
) -> Result<ToolCallDecision> {
|
||||||
|
if let Some(rules) = self.rules.get(tool_name) {
|
||||||
|
for rule in rules {
|
||||||
|
match rule.evaluate(tool_name, tool_input) {
|
||||||
|
GuardrailVerdict::Allow => {}
|
||||||
|
GuardrailVerdict::Block(msg) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
|
||||||
|
rule.name(),
|
||||||
|
tool_name,
|
||||||
|
msg
|
||||||
|
);
|
||||||
|
return Ok(ToolCallDecision::Block(msg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !self.fail_open {
|
||||||
|
// fail-closed: unknown tools are blocked
|
||||||
|
tracing::warn!(
|
||||||
|
"[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
|
||||||
|
tool_name
|
||||||
|
);
|
||||||
|
return Ok(ToolCallDecision::Block(format!(
|
||||||
|
"工具 '{}' 未注册安全规则,fail-closed 策略阻止执行",
|
||||||
|
tool_name
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(ToolCallDecision::Allow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Built-in rules
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Rule that blocks dangerous shell commands.
|
||||||
|
pub struct ShellExecRule;
|
||||||
|
|
||||||
|
impl GuardrailRule for ShellExecRule {
|
||||||
|
fn name(&self) -> &str { "shell_exec_dangerous_commands" }
|
||||||
|
|
||||||
|
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||||
|
let cmd = tool_input["command"].as_str().unwrap_or("");
|
||||||
|
let dangerous = [
|
||||||
|
"rm -rf /",
|
||||||
|
"rm -rf ~",
|
||||||
|
"del /s /q C:\\",
|
||||||
|
"format ",
|
||||||
|
"mkfs.",
|
||||||
|
"dd if=",
|
||||||
|
":(){ :|:& };:", // fork bomb
|
||||||
|
"> /dev/sda",
|
||||||
|
"shutdown",
|
||||||
|
"reboot",
|
||||||
|
];
|
||||||
|
let cmd_lower = cmd.to_lowercase();
|
||||||
|
for pattern in &dangerous {
|
||||||
|
if cmd_lower.contains(pattern) {
|
||||||
|
return GuardrailVerdict::Block(format!(
|
||||||
|
"危险命令被安全护栏拦截: 包含 '{}'",
|
||||||
|
pattern
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GuardrailVerdict::Allow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rule that blocks writes to critical system directories.
|
||||||
|
pub struct FileWriteRule;
|
||||||
|
|
||||||
|
impl GuardrailRule for FileWriteRule {
|
||||||
|
fn name(&self) -> &str { "file_write_critical_dirs" }
|
||||||
|
|
||||||
|
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||||
|
let path = tool_input["path"].as_str().unwrap_or("");
|
||||||
|
let critical_prefixes = [
|
||||||
|
"/etc/",
|
||||||
|
"/usr/",
|
||||||
|
"/bin/",
|
||||||
|
"/sbin/",
|
||||||
|
"/boot/",
|
||||||
|
"/System/",
|
||||||
|
"/Library/",
|
||||||
|
"C:\\Windows\\",
|
||||||
|
"C:\\Program Files\\",
|
||||||
|
"C:\\ProgramData\\",
|
||||||
|
];
|
||||||
|
let path_lower = path.to_lowercase();
|
||||||
|
for prefix in &critical_prefixes {
|
||||||
|
if path_lower.starts_with(&prefix.to_lowercase()) {
|
||||||
|
return GuardrailVerdict::Block(format!(
|
||||||
|
"写入系统关键目录被拦截: {}",
|
||||||
|
path
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GuardrailVerdict::Allow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rule that blocks web requests to internal/private network addresses.
|
||||||
|
pub struct WebFetchRule;
|
||||||
|
|
||||||
|
impl GuardrailRule for WebFetchRule {
|
||||||
|
fn name(&self) -> &str { "web_fetch_private_network" }
|
||||||
|
|
||||||
|
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||||
|
let url = tool_input["url"].as_str().unwrap_or("");
|
||||||
|
let blocked = [
|
||||||
|
"localhost",
|
||||||
|
"127.0.0.1",
|
||||||
|
"0.0.0.0",
|
||||||
|
"10.",
|
||||||
|
"172.16.",
|
||||||
|
"172.17.",
|
||||||
|
"172.18.",
|
||||||
|
"172.19.",
|
||||||
|
"172.20.",
|
||||||
|
"172.21.",
|
||||||
|
"172.22.",
|
||||||
|
"172.23.",
|
||||||
|
"172.24.",
|
||||||
|
"172.25.",
|
||||||
|
"172.26.",
|
||||||
|
"172.27.",
|
||||||
|
"172.28.",
|
||||||
|
"172.29.",
|
||||||
|
"172.30.",
|
||||||
|
"172.31.",
|
||||||
|
"192.168.",
|
||||||
|
"::1",
|
||||||
|
"169.254.",
|
||||||
|
"metadata.google",
|
||||||
|
"metadata.azure",
|
||||||
|
];
|
||||||
|
let url_lower = url.to_lowercase();
|
||||||
|
for prefix in &blocked {
|
||||||
|
if url_lower.contains(prefix) {
|
||||||
|
return GuardrailVerdict::Block(format!(
|
||||||
|
"请求内网/私有地址被拦截: {}",
|
||||||
|
url
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GuardrailVerdict::Allow
|
||||||
|
}
|
||||||
|
}
|
||||||
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
//! Loop guard middleware — extracts loop detection into a middleware hook.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||||
|
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
/// Middleware that detects and blocks repetitive tool-call loops.
|
||||||
|
pub struct LoopGuardMiddleware {
|
||||||
|
guard: Mutex<LoopGuard>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoopGuardMiddleware {
|
||||||
|
pub fn new(config: LoopGuardConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
guard: Mutex::new(LoopGuard::new(config)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_defaults() -> Self {
|
||||||
|
Self {
|
||||||
|
guard: Mutex::new(LoopGuard::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for LoopGuardMiddleware {
|
||||||
|
fn name(&self) -> &str { "loop_guard" }
|
||||||
|
fn priority(&self) -> i32 { 500 }
|
||||||
|
|
||||||
|
async fn before_tool_call(
|
||||||
|
&self,
|
||||||
|
_ctx: &MiddlewareContext,
|
||||||
|
tool_name: &str,
|
||||||
|
tool_input: &Value,
|
||||||
|
) -> Result<ToolCallDecision> {
|
||||||
|
let result = self.guard.lock().unwrap().check(tool_name, tool_input);
|
||||||
|
match result {
|
||||||
|
LoopGuardResult::CircuitBreaker => {
|
||||||
|
tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
|
||||||
|
Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string()))
|
||||||
|
}
|
||||||
|
LoopGuardResult::Blocked => {
|
||||||
|
tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
|
||||||
|
Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string()))
|
||||||
|
}
|
||||||
|
LoopGuardResult::Warn => {
|
||||||
|
tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
|
||||||
|
Ok(ToolCallDecision::Allow)
|
||||||
|
}
|
||||||
|
LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
|
||||||
|
//!
|
||||||
|
//! This middleware unifies the memory lifecycle:
|
||||||
|
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
|
||||||
|
//! - `after_completion`: extracts learnings from the conversation and stores them
|
||||||
|
//!
|
||||||
|
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
|
||||||
|
//! `intelligence_hooks` calls in the Tauri desktop layer.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::growth::GrowthIntegration;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||||
|
|
||||||
|
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
|
||||||
|
///
|
||||||
|
/// Wraps `GrowthIntegration` and delegates:
|
||||||
|
/// - `before_completion` → `enhance_prompt()` for memory injection
|
||||||
|
/// - `after_completion` → `process_conversation()` for memory extraction
|
||||||
|
pub struct MemoryMiddleware {
|
||||||
|
growth: GrowthIntegration,
|
||||||
|
/// Minimum seconds between extractions for the same agent (debounce).
|
||||||
|
debounce_secs: u64,
|
||||||
|
/// Timestamp of last extraction per agent (for debouncing).
|
||||||
|
last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryMiddleware {
|
||||||
|
pub fn new(growth: GrowthIntegration) -> Self {
|
||||||
|
Self {
|
||||||
|
growth,
|
||||||
|
debounce_secs: 30,
|
||||||
|
last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the debounce interval in seconds.
|
||||||
|
pub fn with_debounce_secs(mut self, secs: u64) -> Self {
|
||||||
|
self.debounce_secs = secs;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if enough time has passed since the last extraction for this agent.
|
||||||
|
fn should_extract(&self, agent_id: &str) -> bool {
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let mut map = self.last_extraction.lock().unwrap();
|
||||||
|
if let Some(last) = map.get(agent_id) {
|
||||||
|
if now.duration_since(*last).as_secs() < self.debounce_secs {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map.insert(agent_id.to_string(), now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for MemoryMiddleware {
|
||||||
|
fn name(&self) -> &str { "memory" }
|
||||||
|
fn priority(&self) -> i32 { 150 }
|
||||||
|
|
||||||
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
match self.growth.enhance_prompt(
|
||||||
|
&ctx.agent_id,
|
||||||
|
&ctx.system_prompt,
|
||||||
|
&ctx.user_input,
|
||||||
|
).await {
|
||||||
|
Ok(enhanced) => {
|
||||||
|
ctx.system_prompt = enhanced;
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Non-fatal: memory retrieval failure should not block the loop
|
||||||
|
tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||||
|
// Debounce: skip extraction if called too recently for this agent
|
||||||
|
let agent_key = ctx.agent_id.to_string();
|
||||||
|
if !self.should_extract(&agent_key) {
|
||||||
|
tracing::debug!(
|
||||||
|
"[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
|
||||||
|
agent_key
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.messages.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.growth.process_conversation(
|
||||||
|
&ctx.agent_id,
|
||||||
|
&ctx.messages,
|
||||||
|
ctx.session_id.clone(),
|
||||||
|
).await {
|
||||||
|
Ok(count) => {
|
||||||
|
tracing::info!(
|
||||||
|
"[MemoryMiddleware] Extracted {} memories for agent {}",
|
||||||
|
count,
|
||||||
|
agent_key
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Non-fatal: extraction failure should not affect the response
|
||||||
|
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
//! Skill index middleware — injects a lightweight skill index into the system prompt.
|
||||||
|
//!
|
||||||
|
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
|
||||||
|
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
|
||||||
|
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||||
|
use crate::tool::{SkillIndexEntry, SkillExecutor};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Middleware that injects a lightweight skill index into the system prompt.
|
||||||
|
///
|
||||||
|
/// The index format is compact:
|
||||||
|
/// ```text
|
||||||
|
/// ## Skills (index — use skill_load for details)
|
||||||
|
/// - finance-tracker: 财务分析、财报解读 [数据分析]
|
||||||
|
/// - senior-developer: 代码开发、架构设计 [开发工程]
|
||||||
|
/// ```
|
||||||
|
pub struct SkillIndexMiddleware {
|
||||||
|
/// Pre-built skill index entries, constructed at chain creation time.
|
||||||
|
entries: Vec<SkillIndexEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SkillIndexMiddleware {
|
||||||
|
pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
|
||||||
|
Self { entries }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build index entries from a skill executor that supports listing.
|
||||||
|
pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
|
||||||
|
Self {
|
||||||
|
entries: executor.list_skill_index(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for SkillIndexMiddleware {
|
||||||
|
fn name(&self) -> &str { "skill_index" }
|
||||||
|
fn priority(&self) -> i32 { 200 }
|
||||||
|
|
||||||
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
if self.entries.is_empty() {
|
||||||
|
return Ok(MiddlewareDecision::Continue);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
|
||||||
|
for entry in &self.entries {
|
||||||
|
let triggers = if entry.triggers.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
format!(" — {}", entry.triggers.join(", "))
|
||||||
|
};
|
||||||
|
index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.system_prompt.push_str(&index);
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
}
|
||||||
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
//! Token calibration middleware — calibrates token estimation after first LLM response.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||||
|
use crate::compaction;
|
||||||
|
|
||||||
|
/// Middleware that calibrates the global token estimation factor based on
|
||||||
|
/// actual API-returned token counts from the first LLM response.
|
||||||
|
pub struct TokenCalibrationMiddleware {
|
||||||
|
/// Whether calibration has already been applied in this session.
|
||||||
|
calibrated: std::sync::atomic::AtomicBool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenCalibrationMiddleware {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
calibrated: std::sync::atomic::AtomicBool::new(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TokenCalibrationMiddleware {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AgentMiddleware for TokenCalibrationMiddleware {
|
||||||
|
fn name(&self) -> &str { "token_calibration" }
|
||||||
|
fn priority(&self) -> i32 { 700 }
|
||||||
|
|
||||||
|
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
|
// Calibration happens in after_completion when we have actual token counts.
|
||||||
|
// Before-completion is a no-op.
|
||||||
|
Ok(MiddlewareDecision::Continue)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||||
|
if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
let estimated = compaction::estimate_messages_tokens(&ctx.messages);
|
||||||
|
compaction::update_calibration(estimated, ctx.input_tokens);
|
||||||
|
self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
tracing::debug!(
|
||||||
|
"[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
|
||||||
|
estimated, ctx.input_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,39 @@ pub trait SkillExecutor: Send + Sync {
|
|||||||
session_id: &str,
|
session_id: &str,
|
||||||
input: Value,
|
input: Value,
|
||||||
) -> Result<Value>;
|
) -> Result<Value>;
|
||||||
|
|
||||||
|
/// Return metadata for on-demand skill loading.
|
||||||
|
/// Default returns `None` (skill detail not available).
|
||||||
|
fn get_skill_detail(&self, skill_id: &str) -> Option<SkillDetail> {
|
||||||
|
let _ = skill_id;
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return lightweight index of all available skills.
|
||||||
|
/// Default returns empty (no index available).
|
||||||
|
fn list_skill_index(&self) -> Vec<SkillIndexEntry> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lightweight skill index entry for system prompt injection.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct SkillIndexEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub description: String,
|
||||||
|
pub triggers: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Full skill detail returned by `skill_load` tool.
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct SkillDetail {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub category: Option<String>,
|
||||||
|
pub input_schema: Option<Value>,
|
||||||
|
pub triggers: Vec<String>,
|
||||||
|
pub capabilities: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Context provided to tool execution
|
/// Context provided to tool execution
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ mod file_write;
|
|||||||
mod shell_exec;
|
mod shell_exec;
|
||||||
mod web_fetch;
|
mod web_fetch;
|
||||||
mod execute_skill;
|
mod execute_skill;
|
||||||
|
mod skill_load;
|
||||||
mod path_validator;
|
mod path_validator;
|
||||||
|
|
||||||
pub use file_read::FileReadTool;
|
pub use file_read::FileReadTool;
|
||||||
@@ -12,6 +13,7 @@ pub use file_write::FileWriteTool;
|
|||||||
pub use shell_exec::ShellExecTool;
|
pub use shell_exec::ShellExecTool;
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
pub use execute_skill::ExecuteSkillTool;
|
pub use execute_skill::ExecuteSkillTool;
|
||||||
|
pub use skill_load::SkillLoadTool;
|
||||||
pub use path_validator::{PathValidator, PathValidatorConfig};
|
pub use path_validator::{PathValidator, PathValidatorConfig};
|
||||||
|
|
||||||
use crate::tool::ToolRegistry;
|
use crate::tool::ToolRegistry;
|
||||||
@@ -23,4 +25,5 @@ pub fn register_builtin_tools(registry: &mut ToolRegistry) {
|
|||||||
registry.register(Box::new(ShellExecTool::new()));
|
registry.register(Box::new(ShellExecTool::new()));
|
||||||
registry.register(Box::new(WebFetchTool::new()));
|
registry.register(Box::new(WebFetchTool::new()));
|
||||||
registry.register(Box::new(ExecuteSkillTool::new()));
|
registry.register(Box::new(ExecuteSkillTool::new()));
|
||||||
|
registry.register(Box::new(SkillLoadTool::new()));
|
||||||
}
|
}
|
||||||
|
|||||||
81
crates/zclaw-runtime/src/tool/builtin/skill_load.rs
Normal file
81
crates/zclaw-runtime/src/tool/builtin/skill_load.rs
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
//! Skill load tool — on-demand retrieval of full skill details.
|
||||||
|
//!
|
||||||
|
//! When the `SkillIndexMiddleware` is active, the system prompt contains only a lightweight
|
||||||
|
//! skill index. This tool allows the LLM to load full skill details (description, input schema,
|
||||||
|
//! capabilities) on demand, exactly when the LLM decides a particular skill is relevant.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use zclaw_types::{Result, ZclawError};
|
||||||
|
|
||||||
|
use crate::tool::{Tool, ToolContext};
|
||||||
|
|
||||||
|
pub struct SkillLoadTool;
|
||||||
|
|
||||||
|
impl SkillLoadTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for SkillLoadTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"skill_load"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Load full details for a skill by its ID. Use this when you need to understand a skill's \
|
||||||
|
input parameters, capabilities, or usage instructions before calling execute_skill. \
|
||||||
|
Returns the skill description, input schema, and trigger conditions."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"skill_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the skill to load details for"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["skill_id"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||||
|
let skill_id = input["skill_id"].as_str()
|
||||||
|
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
|
||||||
|
|
||||||
|
let executor = context.skill_executor.as_ref()
|
||||||
|
.ok_or_else(|| ZclawError::ToolError("Skill executor not available".into()))?;
|
||||||
|
|
||||||
|
match executor.get_skill_detail(skill_id) {
|
||||||
|
Some(detail) => {
|
||||||
|
let mut result = json!({
|
||||||
|
"id": detail.id,
|
||||||
|
"name": detail.name,
|
||||||
|
"description": detail.description,
|
||||||
|
"triggers": detail.triggers,
|
||||||
|
});
|
||||||
|
if let Some(schema) = &detail.input_schema {
|
||||||
|
result["input_schema"] = schema.clone();
|
||||||
|
}
|
||||||
|
if let Some(cat) = &detail.category {
|
||||||
|
result["category"] = json!(cat);
|
||||||
|
}
|
||||||
|
if !detail.capabilities.is_empty() {
|
||||||
|
result["capabilities"] = json!(detail.capabilities);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
None => Err(ZclawError::ToolError(format!("Skill not found: {}", skill_id))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SkillLoadTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -133,6 +133,14 @@ impl SkillRegistry {
|
|||||||
manifests.values().cloned().collect()
|
manifests.values().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Synchronous snapshot of all manifests.
|
||||||
|
/// Uses `try_read` — returns empty map if write lock is held (should be rare at steady state).
|
||||||
|
pub fn manifests_snapshot(&self) -> HashMap<SkillId, SkillManifest> {
|
||||||
|
self.manifests.try_read()
|
||||||
|
.map(|guard| guard.clone())
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
/// Execute a skill
|
/// Execute a skill
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
Reference in New Issue
Block a user