//! Agent middleware system — composable hooks for cross-cutting concerns. //! //! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain, //! this module provides a standardised way to inject behaviour before/after LLM completions //! and tool calls without modifying the core `AgentLoop` logic. //! //! # Priority convention //! //! | Range | Category | Example | //! |---------|----------------|-----------------------------| //! | 100-199 | Context shaping| Compaction, MemoryInject | //! | 200-399 | Capability | SkillIndex, Guardrail | //! | 400-599 | Safety | LoopGuard, Guardrail | //! | 600-799 | Telemetry | TokenCalibration, Tracking | use std::sync::Arc; use async_trait::async_trait; use serde_json::Value; use zclaw_types::{AgentId, Result, SessionId}; use crate::driver::ContentBlock; // --------------------------------------------------------------------------- // Decisions returned by middleware hooks // --------------------------------------------------------------------------- /// Decision returned by `before_completion`. #[derive(Debug, Clone)] pub enum MiddlewareDecision { /// Continue to the next middleware / proceed with the LLM call. Continue, /// Abort the agent loop and return *reason* to the caller. Stop(String), } /// Decision returned by `before_tool_call`. #[derive(Debug, Clone)] pub enum ToolCallDecision { /// Allow the tool call to proceed unchanged. Allow, /// Block the call and return *message* as a tool-error to the LLM. Block(String), /// Allow the call but replace the tool input with *new_input*. ReplaceInput(Value), /// Terminate the entire agent loop immediately (e.g. circuit breaker). AbortLoop(String), } // --------------------------------------------------------------------------- // Middleware context — shared mutable state passed through the chain // --------------------------------------------------------------------------- /// Carries the mutable state that middleware may inspect or modify. pub struct MiddlewareContext { /// The agent that owns this loop. pub agent_id: AgentId, /// Current session. pub session_id: SessionId, /// The raw user input that started this turn. pub user_input: String, // -- mutable state ------------------------------------------------------- /// System prompt — middleware may prepend/append context. pub system_prompt: String, /// Conversation messages sent to the LLM. pub messages: Vec, /// Accumulated LLM content blocks from the current response. pub response_content: Vec, /// Token usage reported by the LLM driver (updated after each call). pub input_tokens: u32, pub output_tokens: u32, } impl std::fmt::Debug for MiddlewareContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MiddlewareContext") .field("agent_id", &self.agent_id) .field("session_id", &self.session_id) .field("messages", &self.messages.len()) .field("input_tokens", &self.input_tokens) .field("output_tokens", &self.output_tokens) .finish() } } // --------------------------------------------------------------------------- // Core trait // --------------------------------------------------------------------------- /// A composable middleware hook for the agent loop. /// /// Each middleware focuses on one cross-cutting concern and is executed /// in `priority` order (ascending). All hook methods have default no-op /// implementations so implementors only override what they need. #[async_trait] pub trait AgentMiddleware: Send + Sync { /// Human-readable name for logging / debugging. fn name(&self) -> &str; /// Execution priority — lower values run first. fn priority(&self) -> i32 { 500 } /// Hook executed **before** the LLM completion request is sent. /// /// Use this to inject context (memory, skill index, etc.) or to /// trigger pre-processing (compaction, summarisation). async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result { Ok(MiddlewareDecision::Continue) } /// Hook executed **before** each tool call. /// /// Return `Block` to prevent execution and feed an error back to /// the LLM, or `ReplaceInput` to sanitise / modify the arguments. async fn before_tool_call( &self, _ctx: &MiddlewareContext, _tool_name: &str, _tool_input: &Value, ) -> Result { Ok(ToolCallDecision::Allow) } /// Hook executed **after** each tool call. async fn after_tool_call( &self, _ctx: &mut MiddlewareContext, _tool_name: &str, _result: &Value, ) -> Result<()> { Ok(()) } /// Hook executed **after** the entire agent loop turn completes. /// /// Use this for post-processing (memory extraction, telemetry, etc.). async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> { Ok(()) } } // --------------------------------------------------------------------------- // Middleware chain — ordered collection with run methods // --------------------------------------------------------------------------- /// An ordered chain of `AgentMiddleware` instances. pub struct MiddlewareChain { middlewares: Vec>, } impl MiddlewareChain { /// Create an empty chain. pub fn new() -> Self { Self { middlewares: Vec::new() } } /// Register a middleware. The chain is kept sorted by `priority` /// (ascending) and by registration order within the same priority. pub fn register(&mut self, mw: Arc) { let p = mw.priority(); let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len()); self.middlewares.insert(pos, mw); } /// Run all `before_completion` hooks in order. pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result { for mw in &self.middlewares { match mw.before_completion(ctx).await? { MiddlewareDecision::Continue => {} MiddlewareDecision::Stop(reason) => { tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason); return Ok(MiddlewareDecision::Stop(reason)); } } } Ok(MiddlewareDecision::Continue) } /// Run all `before_tool_call` hooks in order. pub async fn run_before_tool_call( &self, ctx: &MiddlewareContext, tool_name: &str, tool_input: &Value, ) -> Result { for mw in &self.middlewares { match mw.before_tool_call(ctx, tool_name, tool_input).await? { ToolCallDecision::Allow => {} other => { tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name); return Ok(other); } } } Ok(ToolCallDecision::Allow) } /// Run all `before_tool_call` hooks with mutable context. pub async fn run_before_tool_call_mut( &self, ctx: &mut MiddlewareContext, tool_name: &str, tool_input: &Value, ) -> Result { for mw in &self.middlewares { match mw.before_tool_call(ctx, tool_name, tool_input).await? { ToolCallDecision::Allow => {} other => { tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name); return Ok(other); } } } Ok(ToolCallDecision::Allow) } /// Run all `after_tool_call` hooks in order. pub async fn run_after_tool_call( &self, ctx: &mut MiddlewareContext, tool_name: &str, result: &Value, ) -> Result<()> { for mw in &self.middlewares { mw.after_tool_call(ctx, tool_name, result).await?; } Ok(()) } /// Run all `after_completion` hooks in order. pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> { for mw in &self.middlewares { mw.after_completion(ctx).await?; } Ok(()) } /// Number of registered middlewares. pub fn len(&self) -> usize { self.middlewares.len() } /// Whether the chain is empty. pub fn is_empty(&self) -> bool { self.middlewares.is_empty() } } impl Clone for MiddlewareChain { fn clone(&self) -> Self { Self { middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump } } } impl Default for MiddlewareChain { fn default() -> Self { Self::new() } } // --------------------------------------------------------------------------- // Sub-modules — concrete middleware implementations // --------------------------------------------------------------------------- pub mod butler_router; pub mod compaction; pub mod dangling_tool; pub mod data_masking; pub mod guardrail; pub mod loop_guard; pub mod memory; pub mod skill_index; pub mod subagent_limit; pub mod title; pub mod token_calibration; pub mod tool_error; pub mod tool_output_guard; pub mod trajectory_recorder;