//! Loop guard middleware — extracts loop detection into a middleware hook. use async_trait::async_trait; use serde_json::Value; use zclaw_types::Result; use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision}; use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult}; use std::sync::Mutex; /// Middleware that detects and blocks repetitive tool-call loops. pub struct LoopGuardMiddleware { guard: Mutex, } impl LoopGuardMiddleware { pub fn new(config: LoopGuardConfig) -> Self { Self { guard: Mutex::new(LoopGuard::new(config)), } } pub fn with_defaults() -> Self { Self { guard: Mutex::new(LoopGuard::default()), } } } #[async_trait] impl AgentMiddleware for LoopGuardMiddleware { fn name(&self) -> &str { "loop_guard" } fn priority(&self) -> i32 { 500 } async fn before_tool_call( &self, _ctx: &MiddlewareContext, tool_name: &str, tool_input: &Value, ) -> Result { let result = self.guard.lock().unwrap().check(tool_name, tool_input); match result { LoopGuardResult::CircuitBreaker => { tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name); Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string())) } LoopGuardResult::Blocked => { tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name); Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string())) } LoopGuardResult::Warn => { tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name); Ok(ToolCallDecision::Allow) } LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow), } } }