Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P0 修复: - B-MEM-2: 跨会话记忆丢失 — 添加 IdentityRecall 查询意图检测, 身份类查询绕过 FTS5/LIKE 文本搜索,直接按 scope 检索全部偏好+知识记忆; 缓存 GrowthIntegration 到 Kernel 避免每次请求重建空 scorer - B-HAND-1: Hands 未触发 — 创建 HandTool wrapper 实现 Tool trait, 在 create_tool_registry() 中注册所有已启用 Hands 为 LLM 可调用工具 P1 修复: - B-SCHED-4: 一次性定时未拦截 — 添加 RE_ONE_SHOT_TODAY 正则匹配 "下午3点半提醒我..."类无日期前缀的同日触发模式 - B-CHAT-2: 工具调用循环 — ToolErrorMiddleware 添加连续失败计数器, 3 次连续失败后自动 AbortLoop 防止无限重试 - B-CHAT-5: Stream 竞态 — cancelStream 后添加 500ms cancelCooldown, 防止后端 active-stream 检查竞态
136 lines
4.7 KiB
Rust
136 lines
4.7 KiB
Rust
//! Tool error middleware — catches tool execution errors and converts them
//! into well-formed tool-result messages for the LLM to recover from.
//!
//! Inspired by DeerFlow's ToolErrorMiddleware: instead of propagating raw errors
//! that crash the agent loop, this middleware wraps tool errors into a structured
//! format that the LLM can use to self-correct.
//!
//! Also tracks consecutive tool failures across different tools — if N consecutive
//! tool calls all fail, the loop is aborted to prevent infinite retry cycles.
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::Value;
|
|
use zclaw_types::Result;
|
|
use crate::driver::ContentBlock;
|
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
|
use std::sync::Mutex;
|
|
|
|
/// Middleware that intercepts tool call errors and formats recovery messages.
///
/// Priority 350 — runs after dangling tool repair (300) and before guardrail (400).
///
/// Besides turning tool errors into guidance text for the LLM, it counts tool
/// calls that fail back-to-back and requests a loop abort once
/// `max_consecutive_failures` is reached.
#[derive(Debug)]
pub struct ToolErrorMiddleware {
    /// Maximum error message length (in bytes) before truncation.
    max_error_length: usize,
    /// Maximum consecutive failures before aborting the loop.
    max_consecutive_failures: u32,
    /// Number of tool calls that have failed in a row; a successful call
    /// resets it to zero. Kept behind a `Mutex` because the middleware hooks
    /// only receive `&self`.
    consecutive_failures: Mutex<u32>,
}
|
|
|
|
impl ToolErrorMiddleware {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
max_error_length: 500,
|
|
max_consecutive_failures: 3,
|
|
consecutive_failures: Mutex::new(0),
|
|
}
|
|
}
|
|
|
|
/// Create with a custom max error length.
|
|
pub fn with_max_error_length(mut self, len: usize) -> Self {
|
|
self.max_error_length = len;
|
|
self
|
|
}
|
|
|
|
/// Format a tool error into a guided recovery message for the LLM.
|
|
///
|
|
/// The caller is responsible for truncation before passing `error`.
|
|
fn format_tool_error(&self, tool_name: &str, error: &str) -> String {
|
|
format!(
|
|
"工具 '{}' 执行失败。错误信息: {}\n请分析错误原因,尝试修正参数后重试,或使用其他方法完成任务。",
|
|
tool_name, error
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Default for ToolErrorMiddleware {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl AgentMiddleware for ToolErrorMiddleware {
|
|
fn name(&self) -> &str { "tool_error" }
|
|
fn priority(&self) -> i32 { 350 }
|
|
|
|
async fn before_tool_call(
|
|
&self,
|
|
_ctx: &MiddlewareContext,
|
|
tool_name: &str,
|
|
tool_input: &Value,
|
|
) -> Result<ToolCallDecision> {
|
|
// Pre-validate tool input structure for common issues.
|
|
if tool_input.is_null() {
|
|
tracing::warn!(
|
|
"[ToolErrorMiddleware] Tool '{}' received null input — replacing with empty object",
|
|
tool_name
|
|
);
|
|
return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({})));
|
|
}
|
|
|
|
// Check consecutive failure count — abort if too many failures
|
|
let failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner());
|
|
if *failures >= self.max_consecutive_failures {
|
|
tracing::warn!(
|
|
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
|
|
*failures
|
|
);
|
|
return Ok(ToolCallDecision::AbortLoop(
|
|
format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", *failures)
|
|
));
|
|
}
|
|
|
|
Ok(ToolCallDecision::Allow)
|
|
}
|
|
|
|
async fn after_tool_call(
|
|
&self,
|
|
ctx: &mut MiddlewareContext,
|
|
tool_name: &str,
|
|
result: &Value,
|
|
) -> Result<()> {
|
|
let mut failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner());
|
|
|
|
// Check if the tool result indicates an error.
|
|
if let Some(error) = result.get("error") {
|
|
*failures += 1;
|
|
let error_msg = match error {
|
|
Value::String(s) => s.clone(),
|
|
other => other.to_string(),
|
|
};
|
|
let truncated = if error_msg.len() > self.max_error_length {
|
|
let end = error_msg.floor_char_boundary(self.max_error_length);
|
|
format!("{}...(truncated)", &error_msg[..end])
|
|
} else {
|
|
error_msg.clone()
|
|
};
|
|
|
|
tracing::warn!(
|
|
"[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}",
|
|
tool_name, *failures, self.max_consecutive_failures, truncated
|
|
);
|
|
|
|
let guided_message = self.format_tool_error(tool_name, &truncated);
|
|
ctx.response_content.push(ContentBlock::Text {
|
|
text: guided_message,
|
|
});
|
|
} else {
|
|
// Success — reset consecutive failure counter
|
|
*failures = 0;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|