fix(runtime): 工具调用 P1/P2/P3 全面修复
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

P1: 流式模式工具并行执行
- 三阶段执行: Phase 1 中间件预检(serial) → Phase 2 并行+串行分区 → Phase 3 结果排序
- ReadOnly 工具用 JoinSet + Semaphore(3) 并行,Exclusive/Interactive 串行
- 与非流式模式保持一致的执行策略

P2: OpenAI 驱动工具参数解析
- 解析失败不再静默替换为 {},改为返回 _parse_error + _raw_args
- 让 LLM 和工具能感知参数问题并自我修正

P2: ToolOutputGuard 精确匹配
- 从 to_lowercase() 关键词匹配改为 regex 精确匹配实际密钥值
- 检测 sk-xxx(20+), AKIA(16), PEM 私钥, key=value 模式
- 移除 "system:", "you are now" 等过于宽泛的注入检测
- 消除合法内容包含 "password" 等词汇时的误拦

P2: ToolErrorMiddleware per-session 计数
- 从全局 AtomicU32 改为 Mutex<HashMap<session_id, u32>>
- 每个会话独立跟踪连续失败次数,消除跨会话误触发 AbortLoop

P3: Gateway client onTool 回调语义
- 明确 tool_call 的 output 始终为空串 (start 信号)
- 添加注释说明 start/end 语义约定
This commit is contained in:
iven
2026-04-24 12:56:07 +08:00
parent c12b64150b
commit 3eb098f020
7 changed files with 226 additions and 216 deletions

View File

@@ -222,10 +222,13 @@ impl LlmDriver for OpenAiDriver {
let parsed_args: serde_json::Value = if args.is_empty() {
serde_json::json!({})
} else {
serde_json::from_str(args).unwrap_or_else(|e| {
tracing::warn!("[OpenAI] Failed to parse tool args '{}': {}, using empty object", args, e);
serde_json::json!({})
})
match serde_json::from_str(args) {
Ok(v) => v,
Err(e) => {
tracing::error!("[OpenAI] Failed to parse tool call '{}' args: {}. Raw: {}", name, e, &args[..args.len().min(200)]);
serde_json::json!({ "_parse_error": e.to_string(), "_raw_args": args[..args.len().min(500)].to_string() })
}
}
};
yield Ok(StreamChunk::ToolUseEnd {
id: id.clone(),

View File

@@ -921,177 +921,167 @@ impl AgentLoop {
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
}
// Execute tools
for (id, name, input) in pending_tool_calls {
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
// Execute tools — Phase 1: Pre-process through middleware (serial)
struct StreamToolPlan { idx: usize, id: String, name: String, input: Value }
let mut plans: Vec<StreamToolPlan> = Vec::new();
let mut abort_loop = false;
for (idx, (id, name, input)) in pending_tool_calls.into_iter().enumerate() {
if abort_loop { break; }
let mw_ctx = middleware::MiddlewareContext {
agent_id: agent_id.clone(),
session_id: session_id_clone.clone(),
user_input: input.to_string(),
system_prompt: enhanced_prompt.clone(),
messages: messages.clone(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
match middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await {
Ok(middleware::ToolCallDecision::Allow) => {
plans.push(StreamToolPlan { idx, id, name, input });
}
Ok(middleware::ToolCallDecision::Block(msg)) => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
let error_output = serde_json::json!({ "error": msg });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
}
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
plans.push(StreamToolPlan { idx, id, name, input: new_input });
}
Ok(middleware::ToolCallDecision::AbortLoop(reason)) => {
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
if let Err(e) = tx.send(LoopEvent::Error(reason)).await {
tracing::warn!("[AgentLoop] Failed to send Error event: {}", e);
}
abort_loop = true;
}
Err(e) => {
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
let error_output = serde_json::json!({ "error": e.to_string() });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
}
}
}
if abort_loop { break 'outer; }
if plans.is_empty() {
tracing::debug!("[AgentLoop] No tools to execute after middleware filtering");
break 'outer;
}
// Check tool call safety — via middleware chain
{
let mw_ctx = middleware::MiddlewareContext {
agent_id: agent_id.clone(),
session_id: session_id_clone.clone(),
user_input: input.to_string(),
system_prompt: enhanced_prompt.clone(),
messages: messages.clone(),
response_content: Vec::new(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
};
match middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await {
Ok(middleware::ToolCallDecision::Allow) => {}
Ok(middleware::ToolCallDecision::Block(msg)) => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
let error_output = serde_json::json!({ "error": msg });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
// Build shared tool context
let pv = path_validator.clone().unwrap_or_else(|| {
let home = std::env::var("USERPROFILE")
.or_else(|_| std::env::var("HOME"))
.unwrap_or_else(|_| ".".to_string());
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
});
let working_dir = pv.workspace_root().map(|p| p.to_string_lossy().to_string());
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: working_dir,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
hand_executor: hand_executor.clone(),
path_validator: Some(pv),
event_sender: Some(tx.clone()),
};
// Phase 2: Execute tools (parallel for ReadOnly, serial for others)
let (parallel_plans, sequential_plans): (Vec<_>, Vec<_>) = plans.iter()
.partition(|p| {
tools.get(&p.name)
.map(|t| t.concurrency())
.unwrap_or(ToolConcurrency::Exclusive) == ToolConcurrency::ReadOnly
});
let mut results: std::collections::HashMap<usize, (String, String, serde_json::Value, bool)> = std::collections::HashMap::new();
// Execute parallel (ReadOnly) tools with JoinSet (max 3 concurrent)
if !parallel_plans.is_empty() {
let sem = Arc::new(tokio::sync::Semaphore::new(3));
let mut join_set = tokio::task::JoinSet::new();
for plan in &parallel_plans {
let tool_ctx = tool_context.clone();
let input = plan.input.clone();
let idx = plan.idx;
let id = plan.id.clone();
let name = plan.name.clone();
let tools_ref = tools.clone();
let permit = sem.clone().acquire_owned().await.unwrap();
join_set.spawn(async move {
let result = if let Some(tool) = tools_ref.get(&name) {
tokio::time::timeout(std::time::Duration::from_secs(30), tool.execute(input, &tool_ctx)).await
} else {
Ok(Err(zclaw_types::ZclawError::Internal(format!("Unknown tool: {}", name))))
};
drop(permit);
(idx, id, name, result)
});
}
while let Some(res) = join_set.join_next().await {
match res {
Ok((idx, id, name, Ok(Ok(value)))) => {
results.insert(idx, (id, name, value, false));
}
Ok(middleware::ToolCallDecision::AbortLoop(reason)) => {
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
if let Err(e) = tx.send(LoopEvent::Error(reason)).await {
tracing::warn!("[AgentLoop] Failed to send Error event: {}", e);
}
break 'outer;
Ok((idx, id, name, Ok(Err(e)))) => {
results.insert(idx, (id, name, serde_json::json!({ "error": e.to_string() }), true));
}
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
// Execute with replaced input (same path_validator logic below)
let pv = path_validator.clone().unwrap_or_else(|| {
let home = std::env::var("USERPROFILE")
.or_else(|_| std::env::var("HOME"))
.unwrap_or_else(|_| ".".to_string());
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
});
let working_dir = pv.workspace_root()
.map(|p| p.to_string_lossy().to_string());
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: working_dir,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
hand_executor: hand_executor.clone(),
path_validator: Some(pv),
event_sender: Some(tx.clone()),
};
let (result, is_error) = if let Some(tool) = tools.get(&name) {
match tool.execute(new_input, &tool_context).await {
Ok(output) => {
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(output, false)
}
Err(e) => {
let error_output = serde_json::json!({ "error": e.to_string() });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(error_output, true)
}
}
} else {
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(error_output, true)
};
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
continue;
Ok((idx, id, name, Err(_))) => {
tracing::warn!("[AgentLoop] Tool '{}' timed out (parallel, 30s)", name);
results.insert(idx, (id, name.clone(), serde_json::json!({ "error": format!("工具 '{}' 执行超时", name) }), true));
}
Err(e) => {
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
let error_output = serde_json::json!({ "error": e.to_string() });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
tracing::warn!("[AgentLoop] JoinError in parallel tool execution: {}", e);
}
}
}
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
let pv = path_validator.clone().unwrap_or_else(|| {
let home = std::env::var("USERPROFILE")
.or_else(|_| std::env::var("HOME"))
.unwrap_or_else(|_| ".".to_string());
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
});
let working_dir = pv.workspace_root()
.map(|p| p.to_string_lossy().to_string());
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: working_dir,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
hand_executor: hand_executor.clone(),
path_validator: Some(pv),
event_sender: Some(tx.clone()),
};
}
let (result, is_error) = if let Some(tool) = tools.get(&name) {
tracing::debug!("[AgentLoop] Tool '{}' found, executing...", name);
match tool.execute(input.clone(), &tool_context).await {
Ok(output) => {
tracing::debug!("[AgentLoop] Tool '{}' executed successfully: {:?}", name, output);
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(output, false)
}
Err(e) => {
tracing::error!("[AgentLoop] Tool '{}' execution failed: {}", name, e);
let error_output = serde_json::json!({ "error": e.to_string() });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(error_output, true)
}
// Execute sequential (Exclusive/Interactive) tools
for plan in &sequential_plans {
let (result, is_error) = if let Some(tool) = tools.get(&plan.name) {
match tool.execute(plan.input.clone(), &tool_context).await {
Ok(output) => (output, false),
Err(e) => (serde_json::json!({ "error": e.to_string() }), true),
}
} else {
tracing::error!("[AgentLoop] Tool '{}' not found in registry", name);
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
(error_output, true)
(serde_json::json!({ "error": format!("Unknown tool: {}", plan.name) }), true)
};
// Check if this is a clarification response — break outer loop
if name == "ask_clarification"
// Check clarification (only from sequential tools — ask_clarification is Interactive)
if plan.name == "ask_clarification"
&& result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
{
tracing::info!("[AgentLoop] Streaming: Clarification requested, terminating loop");
let question = result.get("question")
.and_then(|v| v.as_str())
.unwrap_or("需要更多信息")
.to_string();
messages.push(Message::tool_result(
id,
zclaw_types::ToolId::new(&name),
result,
is_error,
));
// Send the question as final delta so the user sees it
if let Err(e) = tx.send(LoopEvent::Delta(question.clone())).await {
tracing::warn!("[AgentLoop] Failed to send Delta event: {}", e);
}
if let Err(e) = tx.send(LoopEvent::Complete(AgentLoopResult {
response: question.clone(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations: iteration,
})).await {
tracing::warn!("[AgentLoop] Failed to send Complete event: {}", e);
}
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant(&question)).await {
tracing::warn!("[AgentLoop] Failed to save clarification message: {}", e);
}
let question = result.get("question").and_then(|v| v.as_str()).unwrap_or("需要更多信息").to_string();
messages.push(Message::tool_result(plan.id.clone(), zclaw_types::ToolId::new(&plan.name), result, is_error));
if let Err(e) = tx.send(LoopEvent::Delta(question.clone())).await { tracing::warn!("{}", e); }
if let Err(e) = tx.send(LoopEvent::Complete(AgentLoopResult { response: question.clone(), input_tokens: total_input_tokens, output_tokens: total_output_tokens, iterations: iteration })).await { tracing::warn!("{}", e); }
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant(&question)).await { tracing::warn!("{}", e); }
break 'outer;
}
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), result, is_error));
}
// Run after_tool_call middleware chain (error counting, output guard, etc.)
// Phase 3: after_tool_call middleware + push results in original order
let mut sorted_indices: Vec<usize> = results.keys().copied().collect();
sorted_indices.sort();
for idx in sorted_indices {
let (id, name, result, is_error) = results.remove(&idx).unwrap();
// Emit ToolEnd event
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: result.clone() }).await {
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
}
// Run after_tool_call middleware
{
let mut mw_ctx = middleware::MiddlewareContext {
agent_id: agent_id.clone(),
@@ -1108,14 +1098,7 @@ impl AgentLoop {
}
}
// Add tool result to message history
tracing::debug!("[AgentLoop] Adding tool_result to history: id={}, name={}, is_error={}", id, name, is_error);
messages.push(Message::tool_result(
id,
zclaw_types::ToolId::new(&name),
result,
is_error,
));
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
}
tracing::debug!("[AgentLoop] Continuing to next iteration for LLM to process tool results");

View File

@@ -13,7 +13,8 @@ use serde_json::Value;
use zclaw_types::Result;
use crate::driver::ContentBlock;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
use std::sync::atomic::{AtomicU32, Ordering};
use std::collections::HashMap;
use std::sync::Mutex;
/// Middleware that intercepts tool call errors and formats recovery messages.
///
@@ -23,8 +24,8 @@ pub struct ToolErrorMiddleware {
max_error_length: usize,
/// Maximum consecutive failures before aborting the loop.
max_consecutive_failures: u32,
/// Tracks consecutive tool failures.
consecutive_failures: AtomicU32,
/// Tracks consecutive tool failures per session.
session_failures: Mutex<HashMap<String, u32>>,
}
impl ToolErrorMiddleware {
@@ -32,7 +33,7 @@ impl ToolErrorMiddleware {
Self {
max_error_length: 500,
max_consecutive_failures: 3,
consecutive_failures: AtomicU32::new(0),
session_failures: Mutex::new(HashMap::new()),
}
}
@@ -66,7 +67,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
async fn before_tool_call(
&self,
_ctx: &MiddlewareContext,
ctx: &MiddlewareContext,
tool_name: &str,
tool_input: &Value,
) -> Result<ToolCallDecision> {
@@ -79,8 +80,10 @@ impl AgentMiddleware for ToolErrorMiddleware {
return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({})));
}
// Check consecutive failure count — abort if too many failures
let failures = self.consecutive_failures.load(Ordering::SeqCst);
// Check consecutive failure count — abort if too many failures (per session)
let failures = self.session_failures.lock()
.map(|m| m.get(&ctx.session_id.to_string()).copied().unwrap_or(0))
.unwrap_or(0);
if failures >= self.max_consecutive_failures {
tracing::warn!(
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
@@ -102,7 +105,14 @@ impl AgentMiddleware for ToolErrorMiddleware {
) -> Result<()> {
// Check if the tool result indicates an error.
if let Some(error) = result.get("error") {
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
let session_key = ctx.session_id.to_string();
let failures = self.session_failures.lock()
.map(|mut m| {
let count = m.entry(session_key.clone()).or_insert(0);
*count += 1;
*count
})
.unwrap_or(1);
let error_msg = match error {
Value::String(s) => s.clone(),
other => other.to_string(),
@@ -124,8 +134,11 @@ impl AgentMiddleware for ToolErrorMiddleware {
text: guided_message,
});
} else {
// Success — reset consecutive failure counter
self.consecutive_failures.store(0, Ordering::SeqCst);
// Success — reset consecutive failure counter for this session
let session_key = ctx.session_id.to_string();
if let Ok(mut m) = self.session_failures.lock() {
m.insert(session_key, 0);
}
}
Ok(())

View File

@@ -21,35 +21,27 @@ use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
/// Maximum safe output length in characters.
const MAX_OUTPUT_LENGTH: usize = 50_000;
/// Patterns that indicate sensitive information in tool output.
const SENSITIVE_PATTERNS: &[&str] = &[
"api_key",
"apikey",
"api-key",
"secret_key",
"secretkey",
"access_token",
"auth_token",
"password",
"private_key",
"-----BEGIN RSA",
"-----BEGIN PRIVATE",
"sk-", // OpenAI API keys
"sk_live_", // Stripe keys
"AKIA", // AWS access keys
/// Regex patterns that match actual secret values (not just keywords).
/// These detect the *value format* of secrets, avoiding false positives
/// from legitimate content that merely mentions "password" or "api_key".
const SECRET_VALUE_PATTERNS: &[&str] = &[
r#"sk-[a-zA-Z0-9]{20,}"#, // OpenAI API keys (sk-xxx, 20+ chars)
r#"sk_live_[a-zA-Z0-9]{20,}"#, // Stripe live keys
r#"sk_test_[a-zA-Z0-9]{20,}"#, // Stripe test keys
r#"AKIA[A-Z0-9]{16}"#, // AWS access keys (exact 20 chars)
r#"-----BEGIN (RSA |EC )?PRIVATE KEY-----"#, // PEM private keys
r#"(?:api_?key|secret_?key|access_?token|auth_?token|password)\s*[:=]\s*["'][^"']{8,}["']"#, // key=value with actual secret
];
/// Patterns that may indicate prompt injection in tool output.
/// Keyword patterns that indicate prompt injection in tool output.
/// These are specific enough to avoid false positives from normal content.
const INJECTION_PATTERNS: &[&str] = &[
"ignore previous instructions",
"ignore all previous",
"disregard your instructions",
"you are now",
"new instructions:",
"system:",
"[INST]",
"</scratchpad>",
"think step by step about",
];
/// Tool output sanitization middleware.
@@ -105,22 +97,24 @@ impl AgentMiddleware for ToolOutputGuardMiddleware {
);
}
// Rule 2: Sensitive information detection — block output containing secrets (P2-22)
let output_lower = output_str.to_lowercase();
for pattern in SENSITIVE_PATTERNS {
if output_lower.contains(pattern) {
tracing::error!(
"[ToolOutputGuard] BLOCKED tool '{}' output: sensitive pattern '{}'",
tool_name, pattern
);
return Err(zclaw_types::ZclawError::Internal(format!(
"[ToolOutputGuard] Tool '{}' output blocked: sensitive information detected ('{}')",
tool_name, pattern
)));
// Rule 2: Sensitive information detection — match actual secret values, not keywords
for pattern in SECRET_VALUE_PATTERNS {
if let Ok(re) = regex::Regex::new(pattern) {
if re.is_match(&output_str) {
tracing::error!(
"[ToolOutputGuard] BLOCKED tool '{}' output: secret value matched pattern '{}'",
tool_name, pattern
);
return Err(zclaw_types::ZclawError::Internal(format!(
"[ToolOutputGuard] Tool '{}' output blocked: sensitive information detected",
tool_name
)));
}
}
}
// Rule 3: Injection marker detection — BLOCK the output (P2-22 fix)
// Rule 3: Injection marker detection — specific phrase matching
let output_lower = output_str.to_lowercase();
for pattern in INJECTION_PATTERNS {
if output_lower.contains(pattern) {
tracing::error!(