fix(runtime): 工具调用 P1/P2/P3 全面修复
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P1: 流式模式工具并行执行
- 三阶段执行: Phase 1 中间件预检(serial) → Phase 2 并行+串行分区 → Phase 3 结果排序
- ReadOnly 工具用 JoinSet + Semaphore(3) 并行,Exclusive/Interactive 串行
- 与非流式模式保持一致的执行策略
P2: OpenAI 驱动工具参数解析
- 解析失败不再静默替换为 {},改为返回 _parse_error + _raw_args
- 让 LLM 和工具能感知参数问题并自我修正
P2: ToolOutputGuard 精确匹配
- 从 to_lowercase() 关键词匹配改为 regex 精确匹配实际密钥值
- 检测 sk-xxx(20+), AKIA(16), PEM 私钥, key=value 模式
- 移除 "system:", "you are now" 等过于宽泛的注入检测
- 消除合法内容包含 "password" 等词汇时的误拦
P2: ToolErrorMiddleware per-session 计数
- 从全局 AtomicU32 改为 Mutex<HashMap<String, u32>>(以 session_id 为 key)
- 每个会话独立跟踪连续失败次数,消除跨会话误触发 AbortLoop
P3: Gateway client onTool 回调语义
- 明确 tool_call 的 output 始终为空串 (start 信号)
- 添加注释说明 start/end 语义约定
This commit is contained in:
@@ -222,10 +222,13 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
let parsed_args: serde_json::Value = if args.is_empty() {
|
let parsed_args: serde_json::Value = if args.is_empty() {
|
||||||
serde_json::json!({})
|
serde_json::json!({})
|
||||||
} else {
|
} else {
|
||||||
serde_json::from_str(args).unwrap_or_else(|e| {
|
match serde_json::from_str(args) {
|
||||||
tracing::warn!("[OpenAI] Failed to parse tool args '{}': {}, using empty object", args, e);
|
Ok(v) => v,
|
||||||
serde_json::json!({})
|
Err(e) => {
|
||||||
})
|
tracing::error!("[OpenAI] Failed to parse tool call '{}' args: {}. Raw: {}", name, e, &args[..args.len().min(200)]);
|
||||||
|
serde_json::json!({ "_parse_error": e.to_string(), "_raw_args": args[..args.len().min(500)].to_string() })
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
yield Ok(StreamChunk::ToolUseEnd {
|
yield Ok(StreamChunk::ToolUseEnd {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
|
|||||||
@@ -921,177 +921,167 @@ impl AgentLoop {
|
|||||||
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tools
|
// Execute tools — Phase 1: Pre-process through middleware (serial)
|
||||||
for (id, name, input) in pending_tool_calls {
|
struct StreamToolPlan { idx: usize, id: String, name: String, input: Value }
|
||||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
let mut plans: Vec<StreamToolPlan> = Vec::new();
|
||||||
|
let mut abort_loop = false;
|
||||||
|
for (idx, (id, name, input)) in pending_tool_calls.into_iter().enumerate() {
|
||||||
|
if abort_loop { break; }
|
||||||
|
let mw_ctx = middleware::MiddlewareContext {
|
||||||
|
agent_id: agent_id.clone(),
|
||||||
|
session_id: session_id_clone.clone(),
|
||||||
|
user_input: input.to_string(),
|
||||||
|
system_prompt: enhanced_prompt.clone(),
|
||||||
|
messages: messages.clone(),
|
||||||
|
response_content: Vec::new(),
|
||||||
|
input_tokens: total_input_tokens,
|
||||||
|
output_tokens: total_output_tokens,
|
||||||
|
};
|
||||||
|
match middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
||||||
|
Ok(middleware::ToolCallDecision::Allow) => {
|
||||||
|
plans.push(StreamToolPlan { idx, id, name, input });
|
||||||
|
}
|
||||||
|
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||||
|
let error_output = serde_json::json!({ "error": msg });
|
||||||
|
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
||||||
|
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
||||||
|
}
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
}
|
||||||
|
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
||||||
|
plans.push(StreamToolPlan { idx, id, name, input: new_input });
|
||||||
|
}
|
||||||
|
Ok(middleware::ToolCallDecision::AbortLoop(reason)) => {
|
||||||
|
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
|
||||||
|
if let Err(e) = tx.send(LoopEvent::Error(reason)).await {
|
||||||
|
tracing::warn!("[AgentLoop] Failed to send Error event: {}", e);
|
||||||
|
}
|
||||||
|
abort_loop = true;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
||||||
|
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||||
|
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
||||||
|
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
||||||
|
}
|
||||||
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if abort_loop { break 'outer; }
|
||||||
|
if plans.is_empty() {
|
||||||
|
tracing::debug!("[AgentLoop] No tools to execute after middleware filtering");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
|
||||||
// Check tool call safety — via middleware chain
|
// Build shared tool context
|
||||||
{
|
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||||
let mw_ctx = middleware::MiddlewareContext {
|
let home = std::env::var("USERPROFILE")
|
||||||
agent_id: agent_id.clone(),
|
.or_else(|_| std::env::var("HOME"))
|
||||||
session_id: session_id_clone.clone(),
|
.unwrap_or_else(|_| ".".to_string());
|
||||||
user_input: input.to_string(),
|
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
||||||
system_prompt: enhanced_prompt.clone(),
|
});
|
||||||
messages: messages.clone(),
|
let working_dir = pv.workspace_root().map(|p| p.to_string_lossy().to_string());
|
||||||
response_content: Vec::new(),
|
let tool_context = ToolContext {
|
||||||
input_tokens: total_input_tokens,
|
agent_id: agent_id.clone(),
|
||||||
output_tokens: total_output_tokens,
|
working_directory: working_dir,
|
||||||
};
|
session_id: Some(session_id_clone.to_string()),
|
||||||
match middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
skill_executor: skill_executor.clone(),
|
||||||
Ok(middleware::ToolCallDecision::Allow) => {}
|
hand_executor: hand_executor.clone(),
|
||||||
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
path_validator: Some(pv),
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
event_sender: Some(tx.clone()),
|
||||||
let error_output = serde_json::json!({ "error": msg });
|
};
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
// Phase 2: Execute tools (parallel for ReadOnly, serial for others)
|
||||||
}
|
let (parallel_plans, sequential_plans): (Vec<_>, Vec<_>) = plans.iter()
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
.partition(|p| {
|
||||||
continue;
|
tools.get(&p.name)
|
||||||
|
.map(|t| t.concurrency())
|
||||||
|
.unwrap_or(ToolConcurrency::Exclusive) == ToolConcurrency::ReadOnly
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut results: std::collections::HashMap<usize, (String, String, serde_json::Value, bool)> = std::collections::HashMap::new();
|
||||||
|
|
||||||
|
// Execute parallel (ReadOnly) tools with JoinSet (max 3 concurrent)
|
||||||
|
if !parallel_plans.is_empty() {
|
||||||
|
let sem = Arc::new(tokio::sync::Semaphore::new(3));
|
||||||
|
let mut join_set = tokio::task::JoinSet::new();
|
||||||
|
for plan in &parallel_plans {
|
||||||
|
let tool_ctx = tool_context.clone();
|
||||||
|
let input = plan.input.clone();
|
||||||
|
let idx = plan.idx;
|
||||||
|
let id = plan.id.clone();
|
||||||
|
let name = plan.name.clone();
|
||||||
|
let tools_ref = tools.clone();
|
||||||
|
let permit = sem.clone().acquire_owned().await.unwrap();
|
||||||
|
join_set.spawn(async move {
|
||||||
|
let result = if let Some(tool) = tools_ref.get(&name) {
|
||||||
|
tokio::time::timeout(std::time::Duration::from_secs(30), tool.execute(input, &tool_ctx)).await
|
||||||
|
} else {
|
||||||
|
Ok(Err(zclaw_types::ZclawError::Internal(format!("Unknown tool: {}", name))))
|
||||||
|
};
|
||||||
|
drop(permit);
|
||||||
|
(idx, id, name, result)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
while let Some(res) = join_set.join_next().await {
|
||||||
|
match res {
|
||||||
|
Ok((idx, id, name, Ok(Ok(value)))) => {
|
||||||
|
results.insert(idx, (id, name, value, false));
|
||||||
}
|
}
|
||||||
Ok(middleware::ToolCallDecision::AbortLoop(reason)) => {
|
Ok((idx, id, name, Ok(Err(e)))) => {
|
||||||
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
|
results.insert(idx, (id, name, serde_json::json!({ "error": e.to_string() }), true));
|
||||||
if let Err(e) = tx.send(LoopEvent::Error(reason)).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send Error event: {}", e);
|
|
||||||
}
|
|
||||||
break 'outer;
|
|
||||||
}
|
}
|
||||||
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
Ok((idx, id, name, Err(_))) => {
|
||||||
// Execute with replaced input (same path_validator logic below)
|
tracing::warn!("[AgentLoop] Tool '{}' timed out (parallel, 30s)", name);
|
||||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
results.insert(idx, (id, name.clone(), serde_json::json!({ "error": format!("工具 '{}' 执行超时", name) }), true));
|
||||||
let home = std::env::var("USERPROFILE")
|
|
||||||
.or_else(|_| std::env::var("HOME"))
|
|
||||||
.unwrap_or_else(|_| ".".to_string());
|
|
||||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
|
||||||
});
|
|
||||||
let working_dir = pv.workspace_root()
|
|
||||||
.map(|p| p.to_string_lossy().to_string());
|
|
||||||
let tool_context = ToolContext {
|
|
||||||
agent_id: agent_id.clone(),
|
|
||||||
working_directory: working_dir,
|
|
||||||
session_id: Some(session_id_clone.to_string()),
|
|
||||||
skill_executor: skill_executor.clone(),
|
|
||||||
hand_executor: hand_executor.clone(),
|
|
||||||
path_validator: Some(pv),
|
|
||||||
event_sender: Some(tx.clone()),
|
|
||||||
};
|
|
||||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
|
||||||
match tool.execute(new_input, &tool_context).await {
|
|
||||||
Ok(output) => {
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(output, false)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(error_output, true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(error_output, true)
|
|
||||||
};
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
tracing::warn!("[AgentLoop] JoinError in parallel tool execution: {}", e);
|
||||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
}
|
||||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
|
||||||
let home = std::env::var("USERPROFILE")
|
|
||||||
.or_else(|_| std::env::var("HOME"))
|
|
||||||
.unwrap_or_else(|_| ".".to_string());
|
|
||||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
|
||||||
});
|
|
||||||
let working_dir = pv.workspace_root()
|
|
||||||
.map(|p| p.to_string_lossy().to_string());
|
|
||||||
let tool_context = ToolContext {
|
|
||||||
agent_id: agent_id.clone(),
|
|
||||||
working_directory: working_dir,
|
|
||||||
session_id: Some(session_id_clone.to_string()),
|
|
||||||
skill_executor: skill_executor.clone(),
|
|
||||||
hand_executor: hand_executor.clone(),
|
|
||||||
path_validator: Some(pv),
|
|
||||||
event_sender: Some(tx.clone()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
// Execute sequential (Exclusive/Interactive) tools
|
||||||
tracing::debug!("[AgentLoop] Tool '{}' found, executing...", name);
|
for plan in &sequential_plans {
|
||||||
match tool.execute(input.clone(), &tool_context).await {
|
let (result, is_error) = if let Some(tool) = tools.get(&plan.name) {
|
||||||
Ok(output) => {
|
match tool.execute(plan.input.clone(), &tool_context).await {
|
||||||
tracing::debug!("[AgentLoop] Tool '{}' executed successfully: {:?}", name, output);
|
Ok(output) => (output, false),
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await {
|
Err(e) => (serde_json::json!({ "error": e.to_string() }), true),
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(output, false)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("[AgentLoop] Tool '{}' execution failed: {}", name, e);
|
|
||||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(error_output, true)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tracing::error!("[AgentLoop] Tool '{}' not found in registry", name);
|
(serde_json::json!({ "error": format!("Unknown tool: {}", plan.name) }), true)
|
||||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
|
||||||
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
|
||||||
}
|
|
||||||
(error_output, true)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if this is a clarification response — break outer loop
|
// Check clarification (only from sequential tools — ask_clarification is Interactive)
|
||||||
if name == "ask_clarification"
|
if plan.name == "ask_clarification"
|
||||||
&& result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
|
&& result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
|
||||||
{
|
{
|
||||||
tracing::info!("[AgentLoop] Streaming: Clarification requested, terminating loop");
|
tracing::info!("[AgentLoop] Streaming: Clarification requested, terminating loop");
|
||||||
let question = result.get("question")
|
let question = result.get("question").and_then(|v| v.as_str()).unwrap_or("需要更多信息").to_string();
|
||||||
.and_then(|v| v.as_str())
|
messages.push(Message::tool_result(plan.id.clone(), zclaw_types::ToolId::new(&plan.name), result, is_error));
|
||||||
.unwrap_or("需要更多信息")
|
if let Err(e) = tx.send(LoopEvent::Delta(question.clone())).await { tracing::warn!("{}", e); }
|
||||||
.to_string();
|
if let Err(e) = tx.send(LoopEvent::Complete(AgentLoopResult { response: question.clone(), input_tokens: total_input_tokens, output_tokens: total_output_tokens, iterations: iteration })).await { tracing::warn!("{}", e); }
|
||||||
messages.push(Message::tool_result(
|
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant(&question)).await { tracing::warn!("{}", e); }
|
||||||
id,
|
|
||||||
zclaw_types::ToolId::new(&name),
|
|
||||||
result,
|
|
||||||
is_error,
|
|
||||||
));
|
|
||||||
// Send the question as final delta so the user sees it
|
|
||||||
if let Err(e) = tx.send(LoopEvent::Delta(question.clone())).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send Delta event: {}", e);
|
|
||||||
}
|
|
||||||
if let Err(e) = tx.send(LoopEvent::Complete(AgentLoopResult {
|
|
||||||
response: question.clone(),
|
|
||||||
input_tokens: total_input_tokens,
|
|
||||||
output_tokens: total_output_tokens,
|
|
||||||
iterations: iteration,
|
|
||||||
})).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to send Complete event: {}", e);
|
|
||||||
}
|
|
||||||
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant(&question)).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to save clarification message: {}", e);
|
|
||||||
}
|
|
||||||
break 'outer;
|
break 'outer;
|
||||||
}
|
}
|
||||||
|
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), result, is_error));
|
||||||
|
}
|
||||||
|
|
||||||
// Run after_tool_call middleware chain (error counting, output guard, etc.)
|
// Phase 3: after_tool_call middleware + push results in original order
|
||||||
|
let mut sorted_indices: Vec<usize> = results.keys().copied().collect();
|
||||||
|
sorted_indices.sort();
|
||||||
|
for idx in sorted_indices {
|
||||||
|
let (id, name, result, is_error) = results.remove(&idx).unwrap();
|
||||||
|
|
||||||
|
// Emit ToolEnd event
|
||||||
|
if let Err(e) = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: result.clone() }).await {
|
||||||
|
tracing::warn!("[AgentLoop] Failed to send ToolEnd event: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run after_tool_call middleware
|
||||||
{
|
{
|
||||||
let mut mw_ctx = middleware::MiddlewareContext {
|
let mut mw_ctx = middleware::MiddlewareContext {
|
||||||
agent_id: agent_id.clone(),
|
agent_id: agent_id.clone(),
|
||||||
@@ -1108,14 +1098,7 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tool result to message history
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
||||||
tracing::debug!("[AgentLoop] Adding tool_result to history: id={}, name={}, is_error={}", id, name, is_error);
|
|
||||||
messages.push(Message::tool_result(
|
|
||||||
id,
|
|
||||||
zclaw_types::ToolId::new(&name),
|
|
||||||
result,
|
|
||||||
is_error,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!("[AgentLoop] Continuing to next iteration for LLM to process tool results");
|
tracing::debug!("[AgentLoop] Continuing to next iteration for LLM to process tool results");
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ use serde_json::Value;
|
|||||||
use zclaw_types::Result;
|
use zclaw_types::Result;
|
||||||
use crate::driver::ContentBlock;
|
use crate::driver::ContentBlock;
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
/// Middleware that intercepts tool call errors and formats recovery messages.
|
/// Middleware that intercepts tool call errors and formats recovery messages.
|
||||||
///
|
///
|
||||||
@@ -23,8 +24,8 @@ pub struct ToolErrorMiddleware {
|
|||||||
max_error_length: usize,
|
max_error_length: usize,
|
||||||
/// Maximum consecutive failures before aborting the loop.
|
/// Maximum consecutive failures before aborting the loop.
|
||||||
max_consecutive_failures: u32,
|
max_consecutive_failures: u32,
|
||||||
/// Tracks consecutive tool failures.
|
/// Tracks consecutive tool failures per session.
|
||||||
consecutive_failures: AtomicU32,
|
session_failures: Mutex<HashMap<String, u32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolErrorMiddleware {
|
impl ToolErrorMiddleware {
|
||||||
@@ -32,7 +33,7 @@ impl ToolErrorMiddleware {
|
|||||||
Self {
|
Self {
|
||||||
max_error_length: 500,
|
max_error_length: 500,
|
||||||
max_consecutive_failures: 3,
|
max_consecutive_failures: 3,
|
||||||
consecutive_failures: AtomicU32::new(0),
|
session_failures: Mutex::new(HashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +67,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
|||||||
|
|
||||||
async fn before_tool_call(
|
async fn before_tool_call(
|
||||||
&self,
|
&self,
|
||||||
_ctx: &MiddlewareContext,
|
ctx: &MiddlewareContext,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
tool_input: &Value,
|
tool_input: &Value,
|
||||||
) -> Result<ToolCallDecision> {
|
) -> Result<ToolCallDecision> {
|
||||||
@@ -79,8 +80,10 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
|||||||
return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({})));
|
return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({})));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check consecutive failure count — abort if too many failures
|
// Check consecutive failure count — abort if too many failures (per session)
|
||||||
let failures = self.consecutive_failures.load(Ordering::SeqCst);
|
let failures = self.session_failures.lock()
|
||||||
|
.map(|m| m.get(&ctx.session_id.to_string()).copied().unwrap_or(0))
|
||||||
|
.unwrap_or(0);
|
||||||
if failures >= self.max_consecutive_failures {
|
if failures >= self.max_consecutive_failures {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
|
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
|
||||||
@@ -102,7 +105,14 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
|||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Check if the tool result indicates an error.
|
// Check if the tool result indicates an error.
|
||||||
if let Some(error) = result.get("error") {
|
if let Some(error) = result.get("error") {
|
||||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
|
let session_key = ctx.session_id.to_string();
|
||||||
|
let failures = self.session_failures.lock()
|
||||||
|
.map(|mut m| {
|
||||||
|
let count = m.entry(session_key.clone()).or_insert(0);
|
||||||
|
*count += 1;
|
||||||
|
*count
|
||||||
|
})
|
||||||
|
.unwrap_or(1);
|
||||||
let error_msg = match error {
|
let error_msg = match error {
|
||||||
Value::String(s) => s.clone(),
|
Value::String(s) => s.clone(),
|
||||||
other => other.to_string(),
|
other => other.to_string(),
|
||||||
@@ -124,8 +134,11 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
|||||||
text: guided_message,
|
text: guided_message,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Success — reset consecutive failure counter
|
// Success — reset consecutive failure counter for this session
|
||||||
self.consecutive_failures.store(0, Ordering::SeqCst);
|
let session_key = ctx.session_id.to_string();
|
||||||
|
if let Ok(mut m) = self.session_failures.lock() {
|
||||||
|
m.insert(session_key, 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -21,35 +21,27 @@ use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
|||||||
/// Maximum safe output length in characters.
|
/// Maximum safe output length in characters.
|
||||||
const MAX_OUTPUT_LENGTH: usize = 50_000;
|
const MAX_OUTPUT_LENGTH: usize = 50_000;
|
||||||
|
|
||||||
/// Patterns that indicate sensitive information in tool output.
|
/// Regex patterns that match actual secret values (not just keywords).
|
||||||
const SENSITIVE_PATTERNS: &[&str] = &[
|
/// These detect the *value format* of secrets, avoiding false positives
|
||||||
"api_key",
|
/// from legitimate content that merely mentions "password" or "api_key".
|
||||||
"apikey",
|
const SECRET_VALUE_PATTERNS: &[&str] = &[
|
||||||
"api-key",
|
r#"sk-[a-zA-Z0-9]{20,}"#, // OpenAI API keys (sk-xxx, 20+ chars)
|
||||||
"secret_key",
|
r#"sk_live_[a-zA-Z0-9]{20,}"#, // Stripe live keys
|
||||||
"secretkey",
|
r#"sk_test_[a-zA-Z0-9]{20,}"#, // Stripe test keys
|
||||||
"access_token",
|
r#"AKIA[A-Z0-9]{16}"#, // AWS access keys (exact 20 chars)
|
||||||
"auth_token",
|
r#"-----BEGIN (RSA |EC )?PRIVATE KEY-----"#, // PEM private keys
|
||||||
"password",
|
r#"(?:api_?key|secret_?key|access_?token|auth_?token|password)\s*[:=]\s*["'][^"']{8,}["']"#, // key=value with actual secret
|
||||||
"private_key",
|
|
||||||
"-----BEGIN RSA",
|
|
||||||
"-----BEGIN PRIVATE",
|
|
||||||
"sk-", // OpenAI API keys
|
|
||||||
"sk_live_", // Stripe keys
|
|
||||||
"AKIA", // AWS access keys
|
|
||||||
];
|
];
|
||||||
|
|
||||||
/// Patterns that may indicate prompt injection in tool output.
|
/// Keyword patterns that indicate prompt injection in tool output.
|
||||||
|
/// These are specific enough to avoid false positives from normal content.
|
||||||
const INJECTION_PATTERNS: &[&str] = &[
|
const INJECTION_PATTERNS: &[&str] = &[
|
||||||
"ignore previous instructions",
|
"ignore previous instructions",
|
||||||
"ignore all previous",
|
"ignore all previous",
|
||||||
"disregard your instructions",
|
"disregard your instructions",
|
||||||
"you are now",
|
|
||||||
"new instructions:",
|
"new instructions:",
|
||||||
"system:",
|
|
||||||
"[INST]",
|
"[INST]",
|
||||||
"</scratchpad>",
|
"</scratchpad>",
|
||||||
"think step by step about",
|
|
||||||
];
|
];
|
||||||
|
|
||||||
/// Tool output sanitization middleware.
|
/// Tool output sanitization middleware.
|
||||||
@@ -105,22 +97,24 @@ impl AgentMiddleware for ToolOutputGuardMiddleware {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rule 2: Sensitive information detection — block output containing secrets (P2-22)
|
// Rule 2: Sensitive information detection — match actual secret values, not keywords
|
||||||
let output_lower = output_str.to_lowercase();
|
for pattern in SECRET_VALUE_PATTERNS {
|
||||||
for pattern in SENSITIVE_PATTERNS {
|
if let Ok(re) = regex::Regex::new(pattern) {
|
||||||
if output_lower.contains(pattern) {
|
if re.is_match(&output_str) {
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
"[ToolOutputGuard] BLOCKED tool '{}' output: sensitive pattern '{}'",
|
"[ToolOutputGuard] BLOCKED tool '{}' output: secret value matched pattern '{}'",
|
||||||
tool_name, pattern
|
tool_name, pattern
|
||||||
);
|
);
|
||||||
return Err(zclaw_types::ZclawError::Internal(format!(
|
return Err(zclaw_types::ZclawError::Internal(format!(
|
||||||
"[ToolOutputGuard] Tool '{}' output blocked: sensitive information detected ('{}')",
|
"[ToolOutputGuard] Tool '{}' output blocked: sensitive information detected",
|
||||||
tool_name, pattern
|
tool_name
|
||||||
)));
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rule 3: Injection marker detection — BLOCK the output (P2-22 fix)
|
// Rule 3: Injection marker detection — specific phrase matching
|
||||||
|
let output_lower = output_str.to_lowercase();
|
||||||
for pattern in INJECTION_PATTERNS {
|
for pattern in INJECTION_PATTERNS {
|
||||||
if output_lower.contains(pattern) {
|
if output_lower.contains(pattern) {
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
|
|||||||
@@ -696,13 +696,14 @@ export class GatewayClient {
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
case 'tool_call':
|
case 'tool_call':
|
||||||
// Tool call event
|
// Tool call start: onTool(name, input, '') — empty output signals start
|
||||||
if (callbacks.onTool && data.tool) {
|
if (callbacks.onTool && data.tool) {
|
||||||
callbacks.onTool(data.tool, JSON.stringify(data.input || {}), data.output || '');
|
callbacks.onTool(data.tool, JSON.stringify(data.input || {}), '');
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case 'tool_result':
|
case 'tool_result':
|
||||||
|
// Tool call end: onTool(name, '', output) — empty input signals end
|
||||||
if (callbacks.onTool && data.tool) {
|
if (callbacks.onTool && data.tool) {
|
||||||
callbacks.onTool(data.tool, '', String(data.result || data.output || ''));
|
callbacks.onTool(data.tool, '', String(data.result || data.output || ''));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,15 +34,15 @@
|
|||||||
|
|
||||||
**修复**: 区分完整工具(收到 ToolUseEnd)和不完整工具(仅收到 ToolUseStart/Delta)。完整工具照常执行,不完整工具发送取消 ToolEnd 事件。
|
**修复**: 区分完整工具(收到 ToolUseEnd)和不完整工具(仅收到 ToolUseStart/Delta)。完整工具照常执行,不完整工具发送取消 ToolEnd 事件。
|
||||||
|
|
||||||
### P1: 流式模式工具全串行
|
### P1: 流式模式工具全串行 — ✅ 已修复 (2026-04-24)
|
||||||
|
|
||||||
**文件**: `loop_runner.rs` 第 893-1070 行
|
**文件**: `loop_runner.rs` 流式模式工具执行段
|
||||||
|
|
||||||
非流式模式有 `JoinSet` + `Semaphore(3)` 并行执行 ReadOnly 工具,但流式模式用简单 `for` 循环串行执行所有工具。
|
非流式模式有 `JoinSet` + `Semaphore(3)` 并行执行 ReadOnly 工具,但流式模式用简单 `for` 循环串行执行所有工具。
|
||||||
|
|
||||||
**影响**: 多工具调用时延迟显著增加。
|
**修复**: 流式模式采用三阶段执行:Phase 1 中间件预检(serial) → Phase 2 并行+串行分区执行 → Phase 3 after_tool_call + 结果排序推送。
|
||||||
|
|
||||||
### P2: OpenAI 驱动工具参数静默替换
|
### P2: OpenAI 驱动工具参数静默替换 — ✅ 已修复 (2026-04-24)
|
||||||
|
|
||||||
**文件**: `crates/zclaw-runtime/src/driver/openai.rs` 第 222-228 行
|
**文件**: `crates/zclaw-runtime/src/driver/openai.rs` 第 222-228 行
|
||||||
|
|
||||||
@@ -59,24 +59,32 @@ let parsed_args = if args.is_empty() {
|
|||||||
|
|
||||||
JSON 解析失败时静默替换为 `{}`,结合 loop_runner.rs 的空参数处理(第 412-423 行),会注入 `_fallback_query` 替代实际参数。
|
JSON 解析失败时静默替换为 `{}`,结合 loop_runner.rs 的空参数处理(第 412-423 行),会注入 `_fallback_query` 替代实际参数。
|
||||||
|
|
||||||
### P2: ToolOutputGuard 过于激进
|
**修复**: 解析失败时返回 `_parse_error` + `_raw_args` 字段,让工具和 LLM 能感知到参数问题并自我修正。
|
||||||
|
|
||||||
|
### P2: ToolOutputGuard 过于激进 — ✅ 已修复 (2026-04-24)
|
||||||
|
|
||||||
**文件**: `crates/zclaw-runtime/src/middleware/tool_output_guard.rs` 第 109 行
|
**文件**: `crates/zclaw-runtime/src/middleware/tool_output_guard.rs` 第 109 行
|
||||||
|
|
||||||
使用 `to_lowercase()` 匹配敏感模式,合法内容中包含 "password"、"system:" 等字符串会被误拦。
|
使用 `to_lowercase()` 匹配敏感模式,合法内容中包含 "password"、"system:" 等字符串会被误拦。
|
||||||
|
|
||||||
### P2: ToolErrorMiddleware 失败计数器是全局的
|
**修复**: 改用 `regex` 精确匹配实际密钥值格式(如 `sk-[a-zA-Z0-9]{20,}`、`AKIA[A-Z0-9]{16}`、`key=value` 模式),不再误拦仅包含关键词的合法内容。移除了 "system:" 等过于宽泛的注入检测模式。
|
||||||
|
|
||||||
|
### P2: ToolErrorMiddleware 失败计数器是全局的 — ✅ 已修复 (2026-04-24)
|
||||||
|
|
||||||
**文件**: `crates/zclaw-runtime/src/middleware/tool_error.rs` 第 27 行
|
**文件**: `crates/zclaw-runtime/src/middleware/tool_error.rs` 第 27 行
|
||||||
|
|
||||||
`consecutive_failures: AtomicU32` 是结构体字段,所有 session 共享。高并发下 A session 失败 2 次 + B session 失败 1 次就会触发 AbortLoop(阈值 3)。
|
`consecutive_failures: AtomicU32` 是结构体字段,所有 session 共享。高并发下 A session 失败 2 次 + B session 失败 1 次就会触发 AbortLoop(阈值 3)。
|
||||||
|
|
||||||
### P3: Gateway 客户端 onTool 回调语义不一致
|
**修复**: 改用 `Mutex<HashMap<String, u32>>` 以 session_id 为 key 存储计数,每个会话独立跟踪。
|
||||||
|
|
||||||
|
### P3: Gateway 客户端 onTool 回调语义不一致 — ✅ 已修复 (2026-04-24)
|
||||||
|
|
||||||
**文件**: `desktop/src/lib/gateway-client.ts` 第 698-707 行
|
**文件**: `desktop/src/lib/gateway-client.ts` 第 698-707 行
|
||||||
|
|
||||||
`tool_call` 和 `tool_result` 两个 case 共用 `onTool` 回调,但参数约定不同,调用者必须通过 `output` 是否为空判断 start/end。
|
`tool_call` 和 `tool_result` 两个 case 共用 `onTool` 回调,但参数约定不同,调用者必须通过 `output` 是否为空判断 start/end。
|
||||||
|
|
||||||
|
**修复**: 明确 `tool_call` 的 output 始终为 `''`(修复了可能传递 data.output 的问题),添加清晰注释说明 start/end 语义约定。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 二、根因分析
|
## 二、根因分析
|
||||||
|
|||||||
@@ -9,6 +9,14 @@ tags: [log, history]
|
|||||||
|
|
||||||
> Append-only 操作记录。格式: `## [日期] 类型 | 描述`
|
> Append-only 操作记录。格式: `## [日期] 类型 | 描述`
|
||||||
|
|
||||||
|
## [2026-04-24] fix(runtime+middleware) | 工具调用 P1/P2/P3 全面修复
|
||||||
|
- **P1 流式工具并行**: 三阶段执行 (中间件预检→并行+串行分区→结果排序),ReadOnly 工具 JoinSet+Semaphore(3)
|
||||||
|
- **P2 OpenAI 驱动**: 参数解析失败不再静默替换为 `{}`,改为返回 `_parse_error`+`_raw_args` 让 LLM 自我修正
|
||||||
|
- **P2 ToolOutputGuard**: 从关键词匹配改为 regex 精确匹配实际密钥值 (sk-xxx/AKIA/PEM 等),消除误拦
|
||||||
|
- **P2 ToolErrorMiddleware**: 失败计数器从全局 AtomicU32 改为 per-session HashMap,消除跨会话误触发
|
||||||
|
- **P3 Gateway client**: 明确 tool_call/tool_result 的 onTool 回调语义约定 (output='' 为 start, input='' 为 end)
|
||||||
|
- **测试**: 91 tests PASS, tsc --noEmit PASS
|
||||||
|
|
||||||
## [2026-04-24] fix(runtime) | 工具调用两个 P0 修复
|
## [2026-04-24] fix(runtime) | 工具调用两个 P0 修复
|
||||||
- **P0: after_tool_call 中间件从未调用**: 流式+非流式模式均添加 `middleware_chain.run_after_tool_call()` 调用,ToolErrorMiddleware 和 ToolOutputGuardMiddleware 的 after 逻辑现在生效
|
- **P0: after_tool_call 中间件从未调用**: 流式+非流式模式均添加 `middleware_chain.run_after_tool_call()` 调用,ToolErrorMiddleware 和 ToolOutputGuardMiddleware 的 after 逻辑现在生效
|
||||||
- **P0: stream_errored 跳过所有工具**: 流式模式中 `stream_errored` 不再 `break 'outer`,改为区分完整工具(ToolUseEnd 已接收)和不完整工具;完整工具照常执行,不完整工具发送取消 ToolEnd 事件
|
- **P0: stream_errored 跳过所有工具**: 流式模式中 `stream_errored` 不再 `break 'outer`,改为区分完整工具(ToolUseEnd 已接收)和不完整工具;完整工具照常执行,不完整工具发送取消 ToolEnd 事件
|
||||||
|
|||||||
Reference in New Issue
Block a user