feat(runtime): DeerFlow 模式中间件链 Phase 1-4 全部完成

借鉴 DeerFlow 架构,实现完整中间件链系统:

Phase 1 - Agent 中间件链基础设施
- MiddlewareChain Clone 支持
- LoopRunner 双路径集成 (middleware/legacy)
- Kernel create_middleware_chain() 工厂方法

Phase 2 - 技能按需注入
- SkillIndexMiddleware (priority 200)
- SkillLoadTool 工具
- SkillDetail/SkillIndexEntry 结构体
- KernelSkillExecutor trait 扩展

Phase 3 - Guardrail 安全护栏
- GuardrailMiddleware (priority 400, fail_open)
- ShellExecRule / FileWriteRule / WebFetchRule

Phase 4 - 记忆闭环统一
- MemoryMiddleware (priority 150, 30s 防抖)
- after_completion 双路径调用

中间件注册顺序:
100 Compaction | 150 Memory | 200 SkillIndex
400 Guardrail  | 500 LoopGuard | 700 TokenCalibration

向后兼容:Option<MiddlewareChain> 默认 None 走旧路径
This commit is contained in:
iven
2026-03-29 23:19:41 +08:00
parent 7de294375b
commit 04c366fe8b
15 changed files with 1302 additions and 43 deletions

View File

@@ -0,0 +1,61 @@
//! Compaction middleware — wraps the existing compaction module.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::compaction::{self, CompactionConfig};
use crate::growth::GrowthIntegration;
use crate::driver::LlmDriver;
use std::sync::Arc;
/// Middleware that compresses conversation history when it exceeds a token threshold.
///
/// Runs at priority 100 (first in the chain) so later middlewares see the
/// already-compacted message list. A `threshold` of 0 disables compaction.
pub struct CompactionMiddleware {
    /// Token count above which compaction is attempted; 0 disables the middleware.
    threshold: usize,
    /// Tuning options forwarded to the `compaction` module; `use_llm` /
    /// `memory_flush_enabled` select the async compaction path.
    config: CompactionConfig,
    /// Optional LLM driver for async compaction (LLM summarisation, memory flush).
    /// Only consulted on the async path.
    driver: Option<Arc<dyn LlmDriver>>,
    /// Optional growth integration for memory flushing during compaction.
    /// Only consulted on the async path.
    growth: Option<GrowthIntegration>,
}
impl CompactionMiddleware {
    /// Create a new compaction middleware.
    ///
    /// `driver` and `growth` may be `None` when `config` does not enable
    /// LLM-based compaction or memory flushing.
    pub fn new(
        threshold: usize,
        config: CompactionConfig,
        driver: Option<Arc<dyn LlmDriver>>,
        growth: Option<GrowthIntegration>,
    ) -> Self {
        Self { threshold, config, driver, growth }
    }
}
#[async_trait]
impl AgentMiddleware for CompactionMiddleware {
    fn name(&self) -> &str { "compaction" }
    fn priority(&self) -> i32 { 100 }

    /// Compact `ctx.messages` in place when the configured threshold is non-zero.
    ///
    /// Chooses the async path (LLM summarisation / memory flush) when the
    /// config requests it, otherwise falls back to the synchronous
    /// `maybe_compact`. Always continues the chain.
    async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
        // A threshold of 0 means compaction is disabled.
        if self.threshold == 0 {
            return Ok(MiddlewareDecision::Continue);
        }
        // Move the messages out instead of cloning them: both branches replace
        // `ctx.messages` wholesale, so the original `.clone()` was a pure
        // allocation cost paid on every completion.
        let messages = std::mem::take(&mut ctx.messages);
        let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
        ctx.messages = if needs_async {
            compaction::maybe_compact_with_config(
                messages,
                self.threshold,
                &self.config,
                &ctx.agent_id,
                &ctx.session_id,
                self.driver.as_ref(),
                self.growth.as_ref(),
            )
            .await
            .messages
        } else {
            compaction::maybe_compact(messages, self.threshold)
        };
        Ok(MiddlewareDecision::Continue)
    }
}

View File

@@ -0,0 +1,223 @@
//! Guardrail middleware — configurable safety rules for tool call evaluation.
//!
//! This middleware inspects tool calls before execution and can block or
//! modify them based on configurable rules. Inspired by DeerFlow's safety
//! evaluation hooks.
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
/// A single guardrail rule that can inspect and decide on tool calls.
///
/// Implementations must be `Send + Sync` because rules are boxed and shared
/// by the middleware across async hook invocations.
pub trait GuardrailRule: Send + Sync {
    /// Human-readable name for logging.
    fn name(&self) -> &str;
    /// Evaluate a tool call.
    ///
    /// `tool_name` is the registered tool identifier; `tool_input` is the raw
    /// JSON argument object the tool would be invoked with.
    fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
}
/// Decision returned by a guardrail rule.
#[derive(Debug, Clone)]
pub enum GuardrailVerdict {
    /// Allow the tool call to proceed.
    Allow,
    /// Block the call and return *message* as an error to the LLM.
    Block(String),
}
/// Middleware that evaluates tool calls against a set of configurable safety rules.
///
/// Rules are grouped by tool name. When a tool call is made, all rules for
/// that tool are evaluated in order. If any rule returns `Block`, the call
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
/// a rule explicitly blocks them.
pub struct GuardrailMiddleware {
    /// Rules keyed by tool name.
    rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
    /// Default policy for tools with no specific rules: true = allow, false = block.
    fail_open: bool,
}
impl GuardrailMiddleware {
    /// Create an empty middleware. `fail_open` decides the fate of tool calls
    /// that have no registered rules (true = allow, false = block).
    pub fn new(fail_open: bool) -> Self {
        Self {
            rules: HashMap::new(),
            fail_open,
        }
    }
    /// Register a guardrail rule for a specific tool.
    ///
    /// Multiple rules may be registered for the same tool; they are evaluated
    /// in registration order.
    pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
        self.rules.entry(tool_name.into()).or_default().push(rule);
    }
    /// Register built-in safety rules (shell_exec, file_write, web_fetch).
    pub fn with_builtin_rules(mut self) -> Self {
        self.add_rule("shell_exec", Box::new(ShellExecRule));
        self.add_rule("file_write", Box::new(FileWriteRule));
        self.add_rule("web_fetch", Box::new(WebFetchRule));
        self
    }
}
#[async_trait]
impl AgentMiddleware for GuardrailMiddleware {
    fn name(&self) -> &str { "guardrail" }
    fn priority(&self) -> i32 { 400 }

    /// Evaluate a tool call against all registered rules for that tool.
    ///
    /// Rules run in registration order; the first `Block` verdict
    /// short-circuits and blocks the call. Tools with no registered rules are
    /// allowed when `fail_open` is true, blocked otherwise.
    async fn before_tool_call(
        &self,
        _ctx: &MiddlewareContext,
        tool_name: &str,
        tool_input: &Value,
    ) -> Result<ToolCallDecision> {
        if let Some(rules) = self.rules.get(tool_name) {
            for rule in rules {
                match rule.evaluate(tool_name, tool_input) {
                    GuardrailVerdict::Allow => {}
                    GuardrailVerdict::Block(msg) => {
                        tracing::warn!(
                            "[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
                            rule.name(),
                            tool_name,
                            msg
                        );
                        return Ok(ToolCallDecision::Block(msg));
                    }
                }
            }
        } else if !self.fail_open {
            // fail-closed: unknown tools are blocked
            tracing::warn!(
                "[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
                tool_name
            );
            // Fixed: the original message ran the Chinese text and "fail-closed"
            // together with no separator ("…安全规则fail-closed…").
            return Ok(ToolCallDecision::Block(format!(
                "工具 '{}' 未注册安全规则,fail-closed 策略阻止执行",
                tool_name
            )));
        }
        Ok(ToolCallDecision::Allow)
    }
}
// ---------------------------------------------------------------------------
// Built-in rules
// ---------------------------------------------------------------------------
/// Rule that blocks dangerous shell commands.
pub struct ShellExecRule;

impl GuardrailRule for ShellExecRule {
    fn name(&self) -> &str { "shell_exec_dangerous_commands" }

    /// Block the command when it contains any known-destructive substring.
    ///
    /// Matching is case-insensitive: the command is lowercased, so every
    /// pattern below must itself be lowercase. The original listed
    /// `"del /s /q C:\\"` with an uppercase `C`, which could never match the
    /// lowercased command — that pattern was dead code.
    ///
    /// NOTE(review): substring matching also blocks benign commands that
    /// merely mention a pattern (e.g. `echo reboot`) — acceptable for a
    /// guardrail, but worth confirming against expected usage.
    fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
        let cmd = tool_input["command"].as_str().unwrap_or("");
        // All patterns are lowercase so they can match against `cmd_lower`.
        let dangerous = [
            "rm -rf /",
            "rm -rf ~",
            "del /s /q c:\\",
            "format ",
            "mkfs.",
            "dd if=",
            ":(){ :|:& };:", // fork bomb
            "> /dev/sda",
            "shutdown",
            "reboot",
        ];
        let cmd_lower = cmd.to_lowercase();
        for pattern in &dangerous {
            if cmd_lower.contains(pattern) {
                return GuardrailVerdict::Block(format!(
                    "危险命令被安全护栏拦截: 包含 '{}'",
                    pattern
                ));
            }
        }
        GuardrailVerdict::Allow
    }
}
/// Rule that blocks writes to critical system directories.
pub struct FileWriteRule;

impl GuardrailRule for FileWriteRule {
    fn name(&self) -> &str { "file_write_critical_dirs" }

    /// Block the write when the target path begins with any critical system
    /// directory prefix. Comparison is case-insensitive (both the path and
    /// each prefix are lowercased before matching).
    fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
        let path = tool_input["path"].as_str().unwrap_or("");
        // Unix and Windows system locations that tools must never write into.
        let critical_prefixes = [
            "/etc/",
            "/usr/",
            "/bin/",
            "/sbin/",
            "/boot/",
            "/System/",
            "/Library/",
            "C:\\Windows\\",
            "C:\\Program Files\\",
            "C:\\ProgramData\\",
        ];
        let path_lower = path.to_lowercase();
        let is_critical = critical_prefixes
            .iter()
            .any(|prefix| path_lower.starts_with(&prefix.to_lowercase()));
        if is_critical {
            GuardrailVerdict::Block(format!(
                "写入系统关键目录被拦截: {}",
                path
            ))
        } else {
            GuardrailVerdict::Allow
        }
    }
}
/// Rule that blocks web requests to internal/private network addresses.
pub struct WebFetchRule;

impl GuardrailRule for WebFetchRule {
    fn name(&self) -> &str { "web_fetch_private_network" }

    /// Block requests whose *host* is loopback, link-local, an RFC-1918
    /// private range, or a known cloud metadata endpoint.
    ///
    /// Fixed: the original ran `contains` over the entire URL, so any URL
    /// merely containing e.g. "10." in its path ("/v10.2/docs") or
    /// "localhost" in a query string was blocked. We now isolate the host
    /// portion of the URL and prefix-match against that.
    ///
    /// NOTE(review): this is still a lexical check — it does not resolve DNS,
    /// so a public hostname that points at a private IP (DNS rebinding) is
    /// not caught here.
    fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
        let url = tool_input["url"].as_str().unwrap_or("");
        let url_lower = url.to_lowercase();

        // Isolate the host: drop the scheme, cut the authority at the first
        // path/query/fragment delimiter, drop userinfo, then strip the port.
        // IPv6 literals appear bracketed ("[::1]:8080") in URLs.
        let after_scheme = url_lower
            .split_once("://")
            .map(|(_, rest)| rest)
            .unwrap_or(&url_lower);
        let authority = after_scheme
            .split(|c| c == '/' || c == '?' || c == '#')
            .next()
            .unwrap_or("");
        let authority = authority
            .rsplit_once('@')
            .map(|(_, host)| host)
            .unwrap_or(authority);
        let host = if let Some(v6) = authority.strip_prefix('[') {
            v6.split(']').next().unwrap_or("")
        } else {
            authority.split(':').next().unwrap_or("")
        };

        // Prefixes matched against the host only.
        let blocked = [
            "localhost",
            "127.0.0.1",
            "0.0.0.0",
            "10.",
            "172.16.", "172.17.", "172.18.", "172.19.",
            "172.20.", "172.21.", "172.22.", "172.23.",
            "172.24.", "172.25.", "172.26.", "172.27.",
            "172.28.", "172.29.", "172.30.", "172.31.",
            "192.168.",
            "::1",
            "169.254.",
            "metadata.google",
            "metadata.azure",
        ];
        for prefix in &blocked {
            if host.starts_with(prefix) {
                return GuardrailVerdict::Block(format!(
                    "请求内网/私有地址被拦截: {}",
                    url
                ));
            }
        }
        GuardrailVerdict::Allow
    }
}

View File

@@ -0,0 +1,57 @@
//! Loop guard middleware — extracts loop detection into a middleware hook.
use async_trait::async_trait;
use serde_json::Value;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
use std::sync::Mutex;
/// Middleware that detects and blocks repetitive tool-call loops.
///
/// Wraps a `LoopGuard` behind a `Mutex` because the `AgentMiddleware` hooks
/// take `&self` while the guard's state must be updated on every check.
pub struct LoopGuardMiddleware {
    // Interior mutability: `check()` is called through the lock in
    // `before_tool_call`, never across an await point.
    guard: Mutex<LoopGuard>,
}
impl LoopGuardMiddleware {
    /// Create a middleware with an explicit loop-guard configuration.
    pub fn new(config: LoopGuardConfig) -> Self {
        Self {
            guard: Mutex::new(LoopGuard::new(config)),
        }
    }
    /// Create a middleware using `LoopGuard`'s default configuration.
    pub fn with_defaults() -> Self {
        Self {
            guard: Mutex::new(LoopGuard::default()),
        }
    }
}
#[async_trait]
impl AgentMiddleware for LoopGuardMiddleware {
    fn name(&self) -> &str { "loop_guard" }
    fn priority(&self) -> i32 { 500 }

    /// Consult the loop guard before every tool call and translate its
    /// verdict into a middleware decision. `Warn` logs but still allows the
    /// call; `Blocked` and `CircuitBreaker` block it.
    async fn before_tool_call(
        &self,
        _ctx: &MiddlewareContext,
        tool_name: &str,
        tool_input: &Value,
    ) -> Result<ToolCallDecision> {
        // Scope the lock tightly so the guard is released before we build the
        // decision (and never held anywhere near an await point).
        let verdict = {
            let mut guard = self.guard.lock().unwrap();
            guard.check(tool_name, tool_input)
        };
        let decision = match verdict {
            LoopGuardResult::CircuitBreaker => {
                tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
                ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string())
            }
            LoopGuardResult::Blocked => {
                tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
                ToolCallDecision::Block("工具调用被循环防护拦截".to_string())
            }
            LoopGuardResult::Warn => {
                tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
                ToolCallDecision::Allow
            }
            LoopGuardResult::Allowed => ToolCallDecision::Allow,
        };
        Ok(decision)
    }
}

View File

@@ -0,0 +1,115 @@
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
//!
//! This middleware unifies the memory lifecycle:
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
//! - `after_completion`: extracts learnings from the conversation and stores them
//!
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
//! `intelligence_hooks` calls in the Tauri desktop layer.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::growth::GrowthIntegration;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
///
/// Wraps `GrowthIntegration` and delegates:
/// - `before_completion` → `enhance_prompt()` for memory injection
/// - `after_completion` → `process_conversation()` for memory extraction
pub struct MemoryMiddleware {
    growth: GrowthIntegration,
    /// Minimum seconds between extractions for the same agent (debounce).
    debounce_secs: u64,
    /// Timestamp of last extraction per agent (for debouncing).
    /// NOTE(review): entries are never removed, so this map grows with the
    /// number of distinct agent ids seen — fine for a bounded agent set.
    last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
}
impl MemoryMiddleware {
    /// Create a middleware with the default 30-second debounce interval.
    pub fn new(growth: GrowthIntegration) -> Self {
        Self {
            growth,
            debounce_secs: 30,
            last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }
    /// Set the debounce interval in seconds.
    pub fn with_debounce_secs(mut self, secs: u64) -> Self {
        self.debounce_secs = secs;
        self
    }
    /// Check if enough time has passed since the last extraction for this agent.
    ///
    /// Side effect: when this returns `true` it records "now" as the agent's
    /// last-extraction time — the debounce window is consumed even if the
    /// caller subsequently decides not to extract.
    fn should_extract(&self, agent_id: &str) -> bool {
        let now = std::time::Instant::now();
        let mut map = self.last_extraction.lock().unwrap();
        if let Some(last) = map.get(agent_id) {
            if now.duration_since(*last).as_secs() < self.debounce_secs {
                return false;
            }
        }
        map.insert(agent_id.to_string(), now);
        true
    }
}
#[async_trait]
impl AgentMiddleware for MemoryMiddleware {
    fn name(&self) -> &str { "memory" }
    fn priority(&self) -> i32 { 150 }

    /// Retrieve relevant memories and fold them into the system prompt.
    ///
    /// Failures are non-fatal: on error the prompt is left untouched and the
    /// chain continues.
    async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
        match self.growth.enhance_prompt(
            &ctx.agent_id,
            &ctx.system_prompt,
            &ctx.user_input,
        ).await {
            Ok(enhanced) => {
                ctx.system_prompt = enhanced;
            }
            Err(e) => {
                // Non-fatal: memory retrieval failure should not block the loop
                tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
            }
        }
        Ok(MiddlewareDecision::Continue)
    }

    /// Extract memories from the finished conversation, debounced per agent.
    ///
    /// Extraction failures are logged and swallowed — they must not affect
    /// the response already produced.
    async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
        // Fixed: check for an empty conversation BEFORE consuming the debounce
        // window. The original called `should_extract` first, which stamped
        // the per-agent timestamp and then returned early on empty messages —
        // so a no-op call could suppress a real extraction for `debounce_secs`.
        if ctx.messages.is_empty() {
            return Ok(());
        }
        if !self.should_extract(&ctx.agent_id) {
            tracing::debug!(
                "[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
                ctx.agent_id
            );
            return Ok(());
        }
        match self.growth.process_conversation(
            &ctx.agent_id,
            &ctx.messages,
            ctx.session_id.clone(),
        ).await {
            Ok(count) => {
                tracing::info!(
                    "[MemoryMiddleware] Extracted {} memories for agent {}",
                    count,
                    ctx.agent_id
                );
            }
            Err(e) => {
                // Non-fatal: extraction failure should not affect the response
                tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
            }
        }
        Ok(())
    }
}

View File

@@ -0,0 +1,62 @@
//! Skill index middleware — injects a lightweight skill index into the system prompt.
//!
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::tool::{SkillIndexEntry, SkillExecutor};
use std::sync::Arc;
/// Middleware that injects a lightweight skill index into the system prompt.
///
/// The index format is compact:
/// ```text
/// ## Skills (index — use skill_load for details)
/// - finance-tracker: 财务分析、财报解读 [数据分析]
/// - senior-developer: 代码开发、架构设计 [开发工程]
/// ```
pub struct SkillIndexMiddleware {
    /// Pre-built skill index entries, constructed at chain creation time.
    entries: Vec<SkillIndexEntry>,
}
impl SkillIndexMiddleware {
    /// Create a middleware from an already-built list of index entries.
    pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
        Self { entries }
    }
    /// Build index entries from a skill executor that supports listing.
    ///
    /// The index is snapshotted here; skills added to the executor later will
    /// not appear in this middleware's index.
    pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
        Self {
            entries: executor.list_skill_index(),
        }
    }
}
#[async_trait]
impl AgentMiddleware for SkillIndexMiddleware {
    fn name(&self) -> &str { "skill_index" }
    fn priority(&self) -> i32 { 200 }

    /// Append the compact skill index to the system prompt.
    ///
    /// No-op when there are no entries. Each entry renders as
    /// `- **id**: description [trigger, trigger]`.
    async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
        if self.entries.is_empty() {
            return Ok(MiddlewareDecision::Continue);
        }
        let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
        for entry in &self.entries {
            // Fixed: the original concatenated the joined triggers directly
            // onto the description with no separator ("…架构设计数据分析"),
            // contradicting the documented "description [triggers]" format.
            // (It also wrapped the join in a no-op `format!("{}", …)`.)
            let triggers = if entry.triggers.is_empty() {
                String::new()
            } else {
                format!(" [{}]", entry.triggers.join(", "))
            };
            index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
        }
        ctx.system_prompt.push_str(&index);
        Ok(MiddlewareDecision::Continue)
    }
}

View File

@@ -0,0 +1,52 @@
//! Token calibration middleware — calibrates token estimation after first LLM response.
use async_trait::async_trait;
use zclaw_types::Result;
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
use crate::compaction;
/// Middleware that calibrates the global token estimation factor based on
/// actual API-returned token counts from the first LLM response.
pub struct TokenCalibrationMiddleware {
    /// Whether calibration has already been applied in this session.
    /// Atomic because the middleware hooks take `&self`.
    calibrated: std::sync::atomic::AtomicBool,
}
impl TokenCalibrationMiddleware {
    /// Create a middleware that has not yet calibrated.
    pub fn new() -> Self {
        Self {
            calibrated: std::sync::atomic::AtomicBool::new(false),
        }
    }
}
impl Default for TokenCalibrationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl AgentMiddleware for TokenCalibrationMiddleware {
    fn name(&self) -> &str { "token_calibration" }
    fn priority(&self) -> i32 { 700 }

    async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
        // Calibration happens in after_completion when we have actual token counts.
        // Before-completion is a no-op.
        Ok(MiddlewareDecision::Continue)
    }

    /// Calibrate the global token-estimation factor exactly once, using the
    /// first response that reports a real input-token count.
    async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
        // No token count reported by the API — nothing to calibrate against.
        if ctx.input_tokens == 0 {
            return Ok(());
        }
        // Fixed: the original did a separate `load` then `store`, so two
        // concurrent completions could both pass the check and calibrate
        // twice. compare_exchange makes "first caller wins" atomic.
        if self
            .calibrated
            .compare_exchange(
                false,
                true,
                std::sync::atomic::Ordering::SeqCst,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_ok()
        {
            let estimated = compaction::estimate_messages_tokens(&ctx.messages);
            compaction::update_calibration(estimated, ctx.input_tokens);
            tracing::debug!(
                "[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
                estimated, ctx.input_tokens
            );
        }
        Ok(())
    }
}