//! Tool error middleware — catches tool execution errors and converts them //! into well-formed tool-result messages for the LLM to recover from. //! //! Inspired by DeerFlow's ToolErrorMiddleware: instead of propagating raw errors //! that crash the agent loop, this middleware wraps tool errors into a structured //! format that the LLM can use to self-correct. //! //! Also tracks consecutive tool failures across different tools — if N consecutive //! tool calls all fail, the loop is aborted to prevent infinite retry cycles. use async_trait::async_trait; use serde_json::Value; use zclaw_types::Result; use crate::driver::ContentBlock; use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision}; use std::sync::Mutex; /// Middleware that intercepts tool call errors and formats recovery messages. /// /// Priority 350 — runs after dangling tool repair (300) and before guardrail (400). pub struct ToolErrorMiddleware { /// Maximum error message length before truncation. max_error_length: usize, /// Maximum consecutive failures before aborting the loop. max_consecutive_failures: u32, /// Tracks consecutive tool failures. consecutive_failures: Mutex, } impl ToolErrorMiddleware { pub fn new() -> Self { Self { max_error_length: 500, max_consecutive_failures: 3, consecutive_failures: Mutex::new(0), } } /// Create with a custom max error length. pub fn with_max_error_length(mut self, len: usize) -> Self { self.max_error_length = len; self } /// Format a tool error into a guided recovery message for the LLM. /// /// The caller is responsible for truncation before passing `error`. fn format_tool_error(&self, tool_name: &str, error: &str) -> String { format!( "工具 '{}' 执行失败。错误信息: {}\n请分析错误原因,尝试修正参数后重试,或使用其他方法完成任务。", tool_name, error ) } } impl Default for ToolErrorMiddleware { fn default() -> Self { Self::new() } } #[async_trait] impl AgentMiddleware for ToolErrorMiddleware { fn name(&self) -> &str { "tool_error" } fn priority(&self) -> i32 { 350 } async fn before_tool_call( &self, _ctx: &MiddlewareContext, tool_name: &str, tool_input: &Value, ) -> Result { // Pre-validate tool input structure for common issues. if tool_input.is_null() { tracing::warn!( "[ToolErrorMiddleware] Tool '{}' received null input — replacing with empty object", tool_name ); return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({}))); } // Check consecutive failure count — abort if too many failures let failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner()); if *failures >= self.max_consecutive_failures { tracing::warn!( "[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures", *failures ); return Ok(ToolCallDecision::AbortLoop( format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", *failures) )); } Ok(ToolCallDecision::Allow) } async fn after_tool_call( &self, ctx: &mut MiddlewareContext, tool_name: &str, result: &Value, ) -> Result<()> { let mut failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner()); // Check if the tool result indicates an error. if let Some(error) = result.get("error") { *failures += 1; let error_msg = match error { Value::String(s) => s.clone(), other => other.to_string(), }; let truncated = if error_msg.len() > self.max_error_length { let end = error_msg.floor_char_boundary(self.max_error_length); format!("{}...(truncated)", &error_msg[..end]) } else { error_msg.clone() }; tracing::warn!( "[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}", tool_name, *failures, self.max_consecutive_failures, truncated ); let guided_message = self.format_tool_error(tool_name, &truncated); ctx.response_content.push(ContentBlock::Text { text: guided_message, }); } else { // Success — reset consecutive failure counter *failures = 0; } Ok(()) } }