feat(runtime): DeerFlow 模式中间件链 Phase 1-4 全部完成
借鉴 DeerFlow 架构,实现完整中间件链系统: Phase 1 - Agent 中间件链基础设施 - MiddlewareChain Clone 支持 - LoopRunner 双路径集成 (middleware/legacy) - Kernel create_middleware_chain() 工厂方法 Phase 2 - 技能按需注入 - SkillIndexMiddleware (priority 200) - SkillLoadTool 工具 - SkillDetail/SkillIndexEntry 结构体 - KernelSkillExecutor trait 扩展 Phase 3 - Guardrail 安全护栏 - GuardrailMiddleware (priority 400, fail_open) - ShellExecRule / FileWriteRule / WebFetchRule Phase 4 - 记忆闭环统一 - MemoryMiddleware (priority 150, 30s 防抖) - after_completion 双路径调用 中间件注册顺序: 100 Compaction | 150 Memory | 200 SkillIndex 400 Guardrail | 500 LoopGuard | 700 TokenCalibration 向后兼容:Option<MiddlewareChain> 默认 None 走旧路径
This commit is contained in:
@@ -13,6 +13,7 @@ use crate::tool::builtin::PathValidator;
|
||||
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::middleware::{self, MiddlewareChain};
|
||||
use zclaw_memory::MemoryStore;
|
||||
|
||||
/// Agent loop runner
|
||||
@@ -34,6 +35,10 @@ pub struct AgentLoop {
|
||||
compaction_threshold: usize,
|
||||
/// Compaction behavior configuration
|
||||
compaction_config: CompactionConfig,
|
||||
/// Optional middleware chain — when `Some`, cross-cutting logic is
|
||||
/// delegated to the chain instead of the inline code below.
|
||||
/// When `None`, the legacy inline path is used (100% backward compatible).
|
||||
middleware_chain: Option<MiddlewareChain>,
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
@@ -58,6 +63,7 @@ impl AgentLoop {
|
||||
growth: None,
|
||||
compaction_threshold: 0,
|
||||
compaction_config: CompactionConfig::default(),
|
||||
middleware_chain: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,6 +130,14 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
/// Inject a middleware chain. When set, cross-cutting concerns (compaction,
|
||||
/// loop guard, token calibration, etc.) are delegated to the chain instead
|
||||
/// of the inline logic.
|
||||
pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
|
||||
self.middleware_chain = Some(chain);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get growth integration reference
|
||||
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
||||
self.growth.as_ref()
|
||||
@@ -175,8 +189,10 @@ impl AgentLoop {
|
||||
// Get all messages for context
|
||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Apply compaction if threshold is configured
|
||||
if self.compaction_threshold > 0 {
|
||||
let use_middleware = self.middleware_chain.is_some();
|
||||
|
||||
// Apply compaction — skip inline path when middleware chain handles it
|
||||
if !use_middleware && self.compaction_threshold > 0 {
|
||||
let needs_async =
|
||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
@@ -196,14 +212,44 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
// Enhance system prompt with growth memories
|
||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||
// Enhance system prompt — skip when middleware chain handles it
|
||||
let mut enhanced_prompt = if use_middleware {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||
} else {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mut mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages,
|
||||
response_content: Vec::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||
middleware::MiddlewareDecision::Continue => {
|
||||
messages = mw_ctx.messages;
|
||||
enhanced_prompt = mw_ctx.system_prompt;
|
||||
}
|
||||
middleware::MiddlewareDecision::Stop(reason) => {
|
||||
return Ok(AgentLoopResult {
|
||||
response: reason,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
iterations: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let max_iterations = 10;
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
@@ -307,24 +353,56 @@ impl AgentLoop {
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
let mut circuit_breaker_triggered = false;
|
||||
for (id, name, input) in tool_calls {
|
||||
// Check loop guard before executing tool
|
||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||
circuit_breaker_triggered = true;
|
||||
break;
|
||||
// Check tool call safety — via middleware chain or inline loop guard
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mw_ctx_ref = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {}
|
||||
middleware::ToolCallDecision::Block(msg) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||
// Execute with replaced input
|
||||
let tool_result = match self.execute_tool(&name, new_input, &tool_context).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => serde_json::json!({ "error": e.to_string() }),
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
} else {
|
||||
// Legacy inline path
|
||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||
circuit_breaker_triggered = true;
|
||||
break;
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
|
||||
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
||||
@@ -356,8 +434,23 @@ impl AgentLoop {
|
||||
}
|
||||
};
|
||||
|
||||
// Process conversation for memory extraction (post-conversation)
|
||||
if let Some(ref growth) = self.growth {
|
||||
// Post-completion processing — middleware chain or inline growth
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: self.memory.get_messages(&session_id).await.unwrap_or_default(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||
tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e);
|
||||
}
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
// Legacy inline path
|
||||
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
||||
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
||||
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
||||
@@ -384,8 +477,10 @@ impl AgentLoop {
|
||||
// Get all messages for context
|
||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Apply compaction if threshold is configured
|
||||
if self.compaction_threshold > 0 {
|
||||
let use_middleware = self.middleware_chain.is_some();
|
||||
|
||||
// Apply compaction — skip inline path when middleware chain handles it
|
||||
if !use_middleware && self.compaction_threshold > 0 {
|
||||
let needs_async =
|
||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
@@ -405,20 +500,52 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
// Enhance system prompt with growth memories
|
||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||
// Enhance system prompt — skip when middleware chain handles it
|
||||
let mut enhanced_prompt = if use_middleware {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||
} else {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mut mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages,
|
||||
response_content: Vec::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||
middleware::MiddlewareDecision::Continue => {
|
||||
messages = mw_ctx.messages;
|
||||
enhanced_prompt = mw_ctx.system_prompt;
|
||||
}
|
||||
middleware::MiddlewareDecision::Stop(reason) => {
|
||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||
response: reason,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
iterations: 1,
|
||||
})).await;
|
||||
return Ok(rx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone necessary data for the async task
|
||||
let session_id_clone = session_id.clone();
|
||||
let memory = self.memory.clone();
|
||||
let driver = self.driver.clone();
|
||||
let tools = self.tools.clone();
|
||||
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
||||
let middleware_chain = self.middleware_chain.clone();
|
||||
let skill_executor = self.skill_executor.clone();
|
||||
let path_validator = self.path_validator.clone();
|
||||
let agent_id = self.agent_id.clone();
|
||||
@@ -558,6 +685,24 @@ impl AgentLoop {
|
||||
output_tokens: total_output_tokens,
|
||||
iterations: iteration,
|
||||
})).await;
|
||||
|
||||
// Post-completion: middleware after_completion (memory extraction, etc.)
|
||||
if let Some(ref chain) = middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: agent_id.clone(),
|
||||
session_id: session_id_clone.clone(),
|
||||
user_input: String::new(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||
tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
@@ -579,24 +724,92 @@ impl AgentLoop {
|
||||
for (id, name, input) in pending_tool_calls {
|
||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
||||
|
||||
// Check loop guard before executing tool
|
||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||
break 'outer;
|
||||
// Check tool call safety — via middleware chain or inline loop guard
|
||||
if let Some(ref chain) = middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: agent_id.clone(),
|
||||
session_id: session_id_clone.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
||||
Ok(middleware::ToolCallDecision::Allow) => {}
|
||||
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
||||
// Execute with replaced input (same path_validator logic below)
|
||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||
let home = std::env::var("USERPROFILE")
|
||||
.or_else(|_| std::env::var("HOME"))
|
||||
.unwrap_or_else(|_| ".".to_string());
|
||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
||||
});
|
||||
let working_dir = pv.workspace_root()
|
||||
.map(|p| p.to_string_lossy().to_string());
|
||||
let tool_context = ToolContext {
|
||||
agent_id: agent_id.clone(),
|
||||
working_directory: working_dir,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
path_validator: Some(pv),
|
||||
};
|
||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||
match tool.execute(new_input, &tool_context).await {
|
||||
Ok(output) => {
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
|
||||
(output, false)
|
||||
}
|
||||
Err(e) => {
|
||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
} else {
|
||||
// Legacy inline loop guard path
|
||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||
break 'outer;
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||
|
||||
Reference in New Issue
Block a user